package tsgen
import (
"errors"
"fmt"
"go/ast"
"go/parser"
"go/token"
"regexp"
"strings"
"github.com/projectdiscovery/gologger"
sliceutil "github.com/projectdiscovery/utils/slice"
"golang.org/x/tools/go/packages"
)
// EntityParser walks the syntax trees of a Go package and accumulates the
// entities (functions, classes, interfaces and constants) found in it.
type EntityParser struct {
// syntax holds the parsed files of the loaded package (set from pkg.Syntax).
syntax []*ast.File
// structTypes maps a struct type name to the Entity built for it.
structTypes map[string]Entity
// imports maps an import name to the package loaded for that import.
imports map[string]*packages.Package
// newObjects maps an external type name to a placeholder interface Entity.
newObjects map[string]*Entity
// vars holds the entities produced by extractVarsNConstants.
vars []Entity
// entities is the result list built by Parse and returned by GetEntities.
entities []Entity
}
// NewEntityParser loads the Go package located in dir (pattern ".") together
// with its syntax trees and type information, and returns a parser primed with
// the syntax of the first package reported by the loader.
//
// Comments are retained while parsing so that doc text is available on the AST
// nodes. An error is returned when loading fails or no package is found.
func NewEntityParser(dir string) (*EntityParser, error) {
	cfg := &packages.Config{
		// Each load-mode bit is listed exactly once.
		Mode: packages.NeedName | packages.NeedFiles | packages.NeedImports |
			packages.NeedTypes | packages.NeedSyntax | packages.NeedModule |
			packages.NeedTypesInfo,
		Tests: false,
		Dir:   dir,
		ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) {
			return parser.ParseFile(fset, filename, src, parser.ParseComments)
		},
	}
	pkgs, err := packages.Load(cfg, ".")
	if err != nil {
		return nil, fmt.Errorf("could not load package in %q: %w", dir, err)
	}
	if len(pkgs) == 0 {
		return nil, errors.New("no packages found")
	}
	// Only the first matched package is parsed.
	pkg := pkgs[0]
	return &EntityParser{
		syntax:      pkg.Syntax,
		structTypes: map[string]Entity{},
		imports:     map[string]*packages.Package{},
		newObjects:  map[string]*Entity{},
	}, nil
}
// GetEntities returns the parser's current entity list. The list is built by
// Parse, so it is empty until Parse has been called.
func (p *EntityParser) GetEntities() []Entity {
return p.entities
}
// Parse walks every syntax tree of the package and fills the parser's entity
// list, which can afterwards be read with GetEntities.
//
// The work is done in consecutive passes:
//  1. collect constants/vars, struct type names and imported packages;
//  2. turn function declarations accepted by isExported into function
//     entities, attaching methods to the struct registered for their receiver;
//  3. build a class entity (with properties) for every accepted struct type;
//  4. emit structs with methods as classes and structs without methods as
//     interfaces;
//  5. resolve the external types recorded in newObjects via scrapeAndCreate;
//  6. make methods that return a known interface nullable;
//  7. attach constructor functions to their class and drop them from the list;
//  8. append the collected constants.
//
// An error is returned when imported packages cannot be loaded or when
// scrapeAndCreate fails; the latter is wrapped so callers can unwrap it.
func (p *EntityParser) Parse() error {
	p.extractVarsNConstants()
	p.extractStructTypes()
	if err := p.loadImportedPackages(); err != nil {
		return err
	}

	// Functions and methods.
	for _, file := range p.syntax {
		ast.Inspect(file, func(n ast.Node) bool {
			fn, ok := n.(*ast.FuncDecl)
			if !ok {
				return true
			}
			if !isExported(fn.Name.Name) {
				return false
			}
			entity, err := p.extractFunctionFromNode(fn)
			if err != nil {
				gologger.Error().Msgf("Could not extract function %s: %s\n", fn.Name.Name, err)
				return false
			}
			if entity.IsConstructor {
				// Constructors are kept as plain entities for now and merged
				// into their class further below.
				p.entities = append(p.entities, entity)
				return false
			}
			if fn.Recv != nil {
				receiverName := exprToString(fn.Recv.List[0].Type)
				if owner, ok := p.structTypes[receiverName]; ok {
					method := Method{
						Name:        entity.Name,
						Description: strings.ReplaceAll(entity.Description, "Function", "Method"),
						Parameters:  entity.Function.Parameters,
						Returns:     entity.Function.Returns,
						CanFail:     entity.Function.CanFail,
						ReturnStmt:  entity.Function.ReturnStmt,
					}
					// Map values are not addressable: update a copy and store it back.
					owner.Class.Methods = append(owner.Class.Methods, method)
					p.structTypes[receiverName] = owner
					return false
				}
			}
			p.entities = append(p.entities, entity)
			return false
		})
	}

	// Struct types become class entities carrying their properties.
	for _, file := range p.syntax {
		ast.Inspect(file, func(n ast.Node) bool {
			typeSpec, ok := n.(*ast.TypeSpec)
			if !ok {
				return true
			}
			name := typeSpec.Name.Name
			if !isExported(name) {
				return false
			}
			structType, ok := typeSpec.Type.(*ast.StructType)
			if !ok {
				return false
			}
			entity := Entity{
				Name:        name,
				Type:        "class",
				Description: Ternary(strings.TrimSpace(typeSpec.Doc.Text()) != "", typeSpec.Doc.Text(), name+" Class"),
				Class: Class{
					Properties: p.extractClassProperties(structType),
				},
			}
			if known, ok := p.structTypes[name]; ok {
				// Keep the methods gathered in the previous pass. The description
				// already registered for the struct takes precedence over the one
				// computed above.
				entity.Class.Methods = known.Class.Methods
				entity.Description = known.Description
			}
			p.structTypes[name] = entity
			return false
		})
	}

	// Classes with methods stay classes; method-less ones are emitted as
	// interfaces. Map iteration order is unspecified, so the relative order of
	// the entities appended here can differ between runs.
	for k, v := range p.structTypes {
		if v.Type != "class" {
			continue
		}
		if len(v.Class.Methods) > 0 {
			p.entities = append(p.entities, v)
			continue
		}
		if k == "Object" {
			continue
		}
		p.entities = append(p.entities, Entity{
			Name:        k,
			Type:        "interface",
			Description: strings.TrimSpace(strings.ReplaceAll(v.Description, "Class", "interface")),
			Object: Interface{
				Properties: v.Class.Properties,
			},
		})
	}

	// Resolve the external types recorded while extracting properties/returns.
	for k := range p.newObjects {
		if err := p.scrapeAndCreate(k); err != nil {
			return fmt.Errorf("could not scrape and create new object: %w", err)
		}
	}

	// A method returning one of the known interfaces may also return null.
	interfaceList := map[string]struct{}{}
	for _, v := range p.entities {
		if v.Type == "interface" {
			interfaceList[v.Name] = struct{}{}
		}
	}
	for index, v := range p.entities {
		for i, method := range v.Class.Methods {
			if strings.Contains(method.Returns, "null") {
				continue
			}
			x := strings.TrimSpace(method.Returns)
			if _, ok := interfaceList[x]; ok {
				method.Returns = x + " | null"
				method.ReturnStmt = "return null;"
				p.entities[index].Class.Methods[i] = method
			}
		}
	}

	// Attach each constructor to the first class whose name is contained in the
	// constructor's name.
	for _, v := range p.entities {
		if !v.IsConstructor {
			continue
		}
		for i, class := range p.entities {
			if class.Type != "class" {
				continue
			}
			if strings.Contains(v.Name, class.Name) {
				p.entities[i].Class.Constructor = v.Function
				break
			}
		}
	}

	// Constructors now live on their classes; drop the standalone entries and
	// append the constants.
	filtered := []Entity{}
	for _, v := range p.entities {
		if !v.IsConstructor {
			filtered = append(filtered, v)
		}
	}
	filtered = append(filtered, p.vars...)
	p.entities = filtered
	return nil
}
// extractClassProperties converts the fields of a struct type into properties.
//
// Named fields are included when their identifier is exported; a field group
// whose first name is unexported is skipped entirely. Embedded fields are
// included only when their type is a plain identifier, using that identifier
// as the property name. Named fields whose type string contains a "." are
// routed through handleExternalStruct and use the name it returns.
func (p *EntityParser) extractClassProperties(node *ast.StructType) []Property {
	var props []Property
	for _, f := range node.Fields.List {
		embedded := len(f.Names) == 0
		if !embedded && !f.Names[0].IsExported() {
			continue
		}

		fieldType := exprToString(f.Type)
		doc := f.Doc.Text()

		if embedded {
			ident, isIdent := f.Type.(*ast.Ident)
			if isIdent {
				props = append(props, Property{
					Name:        ident.Name,
					Type:        fieldType,
					Description: doc,
				})
			}
			continue
		}

		for _, name := range f.Names {
			if !name.IsExported() {
				continue
			}
			propType := fieldType
			if strings.Contains(propType, ".") {
				propType = p.handleExternalStruct(propType)
			}
			props = append(props, Property{
				Name:        name.Name,
				Type:        propType,
				Description: doc,
			})
		}
	}
	return props
}
// constructorRe is the pattern matching a "constructor(...)" signature whose
// argument list contains no closing parenthesis.
var constructorRe = `(constructor\([^)]*\))`

// constructorReCompiled is the compiled form of constructorRe.
var constructorReCompiled = regexp.MustCompile(constructorRe)
// extractFunctionFromNode builds a "function" entity from a function
// declaration.
//
// The description is the declaration's doc text, or "<name> Function" when the
// doc text is blank. A declaration whose return type contains "Object" and
// that has exactly two parameters is flagged as a constructor; its function
// data is then rebuilt from the "constructor(...)" signature found in the
// description. For all other declarations a return statement matching the
// return type is chosen. The returned error is always nil.
func (p *EntityParser) extractFunctionFromNode(fn *ast.FuncDecl) (Entity, error) {
	name := fn.Name.Name
	entity := Entity{
		Name:        name,
		Type:        "function",
		Description: Ternary(strings.TrimSpace(fn.Doc.Text()) != "", fn.Doc.Text(), name+" Function"),
	}
	entity.Function = Function{
		Parameters: p.extractParameters(fn),
		Returns:    p.extractReturnType(fn),
		CanFail:    checkCanFail(fn),
	}

	returns := entity.Function.Returns
	if strings.Contains(returns, "Object") && len(entity.Function.Parameters) == 2 {
		entity.IsConstructor = true
		signature := constructorReCompiled.FindString(entity.Description)
		entity.Function = updateFuncWithConstructorSig(signature, entity.Function)
		return entity, nil
	}

	switch {
	case returns == "void":
		entity.Function.ReturnStmt = "return;"
	case strings.Contains(returns, "null"):
		entity.Function.ReturnStmt = "return null;"
	case fn.Recv != nil && exprToString(fn.Recv.List[0].Type) == returns:
		// The method returns its own receiver type.
		entity.Function.ReturnStmt = "return this;"
	default:
		entity.Function.ReturnStmt = "return " + TsDefaultValue(returns) + ";"
	}
	return entity, nil
}
// extractReturnType derives the return type string for a function declaration.
//
// Rules, as implemented below:
//   - no results, or a single "error" result, yields "void";
//   - a single non-error result yields that type with a leading "*" removed;
//   - several results that include "error" yield the first non-error type
//     followed by " | null";
//   - several results without "error" yield "void";
//   - a result whose type string contains "." and does not start with "goja."
//     is registered through handleExternalStruct and made nullable.
//
// In every case an empty result becomes "void" and "interface{}" is rewritten
// to "any" before returning.
func (p *EntityParser) extractReturnType(fn *ast.FuncDecl) (out string) {
	defer func() {
		if out == "" {
			out = "void"
		}
		out = strings.ReplaceAll(out, "interface{}", "any")
	}()

	var returns []string
	if fn.Type.Results != nil {
		for _, result := range fn.Type.Results.List {
			tmp := exprToString(result.Type)
			if strings.Contains(tmp, ".") && !strings.HasPrefix(tmp, "goja.") {
				tmp = p.handleExternalStruct(tmp) + " | null"
			}
			returns = append(returns, tmp)
		}
	}

	if len(returns) == 1 {
		val := strings.TrimPrefix(returns[0], "*")
		if val == "error" {
			return "void"
		}
		return val
	}

	if len(returns) > 1 {
		for _, val := range returns {
			val = strings.TrimPrefix(val, "*")
			if val != "error" {
				out = val
				break
			}
		}
		if sliceutil.Contains(returns, "error") {
			// External struct types were already made nullable above; do not
			// append a second " | null" to them.
			if !strings.HasSuffix(out, " | null") {
				out += " | null"
			}
			return out
		}
	}
	return "void"
}
// convertMaptoRecord rewrites a map type written as "Map[<key>]<value>" into
// "Record<K, V>", passing the key and value parts through toTsTypes.
//
// When the input has no closing "]" the key and value cannot be separated;
// "Record<any, any>" is returned instead of slicing with a negative index.
func convertMaptoRecord(input string) (out string) {
	input = strings.TrimPrefix(input, "Map[")
	end := strings.Index(input, "]")
	if end < 0 {
		return "Record<any, any>"
	}
	key, value := input[:end], input[end+1:]
	return "Record<" + toTsTypes(key) + ", " + toTsTypes(value) + ">"
}
// extractParameters lists the parameters of a function declaration.
//
// Every name of a grouped parameter (e.g. "a, b string") yields its own entry.
// Unnamed parameters get a positional placeholder name "argN", where N is the
// parameter's zero-based position. A type string containing "." is replaced
// by "any"; the type is then passed through toTsTypes.
func (p *EntityParser) extractParameters(fn *ast.FuncDecl) []Parameter {
	var parameters []Parameter
	if fn.Type.Params == nil {
		return parameters
	}
	for _, param := range fn.Type.Params.List {
		typ := exprToString(param.Type)
		if strings.Contains(typ, ".") {
			typ = "any"
		}
		tsType := toTsTypes(typ)

		if len(param.Names) == 0 {
			parameters = append(parameters, Parameter{
				Name: fmt.Sprintf("arg%d", len(parameters)),
				Type: tsType,
			})
			continue
		}
		for _, name := range param.Names {
			parameters = append(parameters, Parameter{
				Name: name.Name,
				Type: tsType,
			})
		}
	}
	return parameters
}
// handleExternalStruct records typeName in newObjects as a placeholder
// interface entity and returns the part of typeName after its last "."
// (the whole name when it contains no ".").
func (p *EntityParser) handleExternalStruct(typeName string) string {
	shortName := typeName
	if dot := strings.LastIndex(typeName, "."); dot >= 0 {
		shortName = typeName[dot+1:]
	}

	placeholder := Entity{
		Name:        shortName,
		Type:        "interface",
		Description: shortName + " Object",
	}
	p.newObjects[typeName] = &placeholder
	return shortName
}
// extractStructTypes registers every struct type declaration found in the
// package syntax in structTypes, keyed by type name, storing the name and the
// doc text of the type spec.
func (p *EntityParser) extractStructTypes() {
	register := func(n ast.Node) bool {
		spec, isTypeSpec := n.(*ast.TypeSpec)
		if !isTypeSpec {
			return true
		}
		if _, isStruct := spec.Type.(*ast.StructType); !isStruct {
			return true
		}
		name := spec.Name.Name
		p.structTypes[name] = Entity{
			Name:        name,
			Description: spec.Doc.Text(),
		}
		// Keep descending after a match as well.
		return true
	}

	for _, file := range p.syntax {
		ast.Inspect(file, register)
	}
}
// extractVarsNConstants collects value declarations (const/var) whose first
// name is exported and whose first value is a basic literal, storing them in
// p.vars as "const" entities.
//
// Declarations without a value, or whose value is not a literal (a function
// call, composite literal, identifier, ...), are skipped: only a literal has
// text that can be copied into Value.
func (p *EntityParser) extractVarsNConstants() {
	p.vars = []Entity{}
	for _, file := range p.syntax {
		ast.Inspect(file, func(n ast.Node) bool {
			gen, ok := n.(*ast.GenDecl)
			if !ok {
				return true
			}
			for _, s := range gen.Specs {
				spec, ok := s.(*ast.ValueSpec)
				if !ok {
					continue
				}
				if len(spec.Names) == 0 || !spec.Names[0].IsExported() {
					continue
				}
				if len(spec.Values) == 0 {
					continue
				}
				// Checked assertion: a non-literal value must not panic.
				lit, ok := spec.Values[0].(*ast.BasicLit)
				if !ok {
					continue
				}
				p.vars = append(p.vars, Entity{
					Name:        spec.Names[0].Name,
					Type:        "const",
					Description: strings.TrimSpace(spec.Comment.Text()),
					Value:       lit.Value,
				})
			}
			return true
		})
	}
}
// loadImportedPackages loads every package imported by the parsed files and
// stores it in p.imports under the name the import is referred to by.
//
// The name is the explicit alias when one is given; otherwise the last path
// element, unless the quoted import path does not end with the package's
// declared name, in which case the declared name is used. The first package
// registered under a name wins. Each distinct import path is loaded only once
// per call, even when several files import it.
func (p *EntityParser) loadImportedPackages() error {
	loaded := map[string]*packages.Package{}
	for _, file := range p.syntax {
		for _, imp := range file.Imports {
			// Path.Value is the quoted literal; strip the surrounding quotes.
			path := imp.Path.Value
			path = path[1 : len(path)-1]

			pkg, ok := loaded[path]
			if !ok {
				var err error
				pkg, err = loadPackage(path)
				if err != nil {
					return fmt.Errorf("could not load imported package %q: %w", path, err)
				}
				loaded[path] = pkg
			}

			importName := path[strings.LastIndex(path, "/")+1:]
			if imp.Name != nil {
				importName = imp.Name.Name
			} else if !strings.HasSuffix(imp.Path.Value, pkg.Types.Name()+`"`) {
				importName = pkg.Types.Name()
			}
			if _, ok := p.imports[importName]; !ok {
				p.imports[importName] = pkg
			}
		}
	}
	return nil
}
// loadPackage loads pkgPath with types, syntax and type information and
// returns the first package reported by the loader. It fails when the loader
// returns an error or reports no package at all.
func loadPackage(pkgPath string) (*packages.Package, error) {
	const mode = packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesInfo

	loaded, err := packages.Load(&packages.Config{Mode: mode}, pkgPath)
	switch {
	case err != nil:
		return nil, err
	case len(loaded) == 0:
		return nil, errors.New("no packages found")
	}
	return loaded[0], nil
}
// updateFuncWithConstructorSig turns f into constructor data based on a
// signature of the form "constructor(name: type, ...)".
//
// Parameters are replaced by the ones listed in sig, CanFail is set, and the
// return type and return statement are cleared. An empty sig, or one with an
// empty argument list ("constructor()"), yields no parameters. Each argument
// is split at its first ":" so the type part may itself contain colons.
//
// It panics with "invalid constructor signature" when an argument has no ":".
func updateFuncWithConstructorSig(sig string, f Function) Function {
	sig = strings.TrimSpace(sig)
	f.Parameters = []Parameter{}
	f.CanFail = true
	f.ReturnStmt = ""
	f.Returns = ""
	if sig == "" {
		return f
	}
	sig = strings.TrimPrefix(sig, "constructor(")
	sig = strings.TrimSuffix(sig, ")")
	if strings.TrimSpace(sig) == "" {
		// "constructor()" declares no parameters.
		return f
	}
	for _, arg := range strings.Split(sig, ",") {
		arg = strings.TrimSpace(arg)
		typeData := strings.SplitN(arg, ":", 2)
		if len(typeData) != 2 {
			panic("invalid constructor signature")
		}
		f.Parameters = append(f.Parameters, Parameter{
			Name: strings.TrimSpace(typeData[0]),
			Type: strings.TrimSpace(typeData[1]),
		})
	}
	return f
}