github.com/unionj-cloud/go-doudou/v2@v2.3.5/toolkit/gormgen/internal/parser/parser.go (about)

     1  package parser
     2  
     3  import (
     4  	"fmt"
     5  	"go/ast"
     6  	"go/parser"
     7  	"go/token"
     8  	"log"
     9  	"path"
    10  	"path/filepath"
    11  	"strconv"
    12  	"strings"
    13  )
    14  
    15  // InterfaceSet ...
    16  type InterfaceSet struct {
    17  	Interfaces []InterfaceInfo
    18  	imports    map[string]string // package name -> quoted "package path"
    19  }
    20  
    21  // InterfaceInfo ...
    22  type InterfaceInfo struct {
    23  	Name        string
    24  	Doc         string
    25  	Methods     []*Method
    26  	Package     string
    27  	ApplyStruct []string
    28  }
    29  
    30  // MatchStruct ...
    31  func (i *InterfaceInfo) MatchStruct(name string) bool {
    32  	for _, s := range i.ApplyStruct {
    33  		if s == name {
    34  			return true
    35  		}
    36  	}
    37  	return false
    38  }
    39  
    40  // ParseFile get interface's info from source file
    41  func (i *InterfaceSet) ParseFile(paths []*InterfacePath, structNames []string) error {
    42  	for _, path := range paths {
    43  		for _, file := range path.Files {
    44  			absFilePath, err := filepath.Abs(file)
    45  			if err != nil {
    46  				return fmt.Errorf("file not found: %s", file)
    47  			}
    48  
    49  			err = i.getInterfaceFromFile(absFilePath, path.Name, path.FullName, structNames)
    50  			if err != nil {
    51  				return fmt.Errorf("can't get interface from %s:%s", path.FullName, err)
    52  			}
    53  		}
    54  	}
    55  	return nil
    56  }
    57  
    58  // Visit ast visit function
    59  func (i *InterfaceSet) Visit(n ast.Node) (w ast.Visitor) {
    60  	switch n := n.(type) {
    61  	case *ast.ImportSpec:
    62  		importName, _ := strconv.Unquote(n.Path.Value)
    63  		importName = path.Base(importName)
    64  		if n.Name != nil {
    65  			name := n.Name.Name
    66  			// ignore dummy imports
    67  			// TODO: full support for dot imports requires type checking the whole package
    68  			if name == "_" || name == "." {
    69  				return i
    70  			}
    71  			importName = name
    72  		}
    73  		i.imports[importName] = n.Path.Value
    74  	case *ast.TypeSpec:
    75  		if data, ok := n.Type.(*ast.InterfaceType); ok {
    76  			r := InterfaceInfo{
    77  				Methods: []*Method{},
    78  			}
    79  			methods := data.Methods.List
    80  			r.Name = n.Name.Name
    81  			r.Doc = n.Doc.Text()
    82  
    83  			for _, m := range methods {
    84  				for _, name := range m.Names {
    85  					method := &Method{
    86  						MethodName: name.Name,
    87  						Doc:        m.Doc.Text(),
    88  						Params:     getParamList(m.Type.(*ast.FuncType).Params),
    89  						Result:     getParamList(m.Type.(*ast.FuncType).Results),
    90  					}
    91  					fixParamPackagePath(i.imports, method.Params)
    92  					r.Methods = append(r.Methods, method)
    93  				}
    94  			}
    95  			i.Interfaces = append(i.Interfaces, r)
    96  		}
    97  	}
    98  	return i
    99  }
   100  
   101  // getInterfaceFromFile get interfaces
   102  // get all interfaces from file and compare with specified name
   103  func (i *InterfaceSet) getInterfaceFromFile(filename string, name, Package string, structNames []string) error {
   104  	fileset := token.NewFileSet()
   105  	f, err := parser.ParseFile(fileset, filename, nil, parser.ParseComments)
   106  	if err != nil {
   107  		return fmt.Errorf("can't parse file %q: %s", filename, err)
   108  	}
   109  
   110  	astResult := &InterfaceSet{imports: make(map[string]string)}
   111  	ast.Walk(astResult, f)
   112  
   113  	for _, info := range astResult.Interfaces {
   114  		if name == info.Name {
   115  			info.Package = Package
   116  			info.ApplyStruct = structNames
   117  			i.Interfaces = append(i.Interfaces, info)
   118  		}
   119  	}
   120  
   121  	return nil
   122  }
   123  
   124  // Param parameters in method
   125  type Param struct { // (user model.User)
   126  	PkgPath   string // package's path: internal/model
   127  	Package   string // package's name: model
   128  	Name      string // param's name: user
   129  	Type      string // param's type: User
   130  	IsArray   bool   // is array or not
   131  	IsPointer bool   // is pointer or not
   132  }
   133  
   134  // Eq if param equal to another
   135  func (p *Param) Eq(q Param) bool {
   136  	return p.Package == q.Package && p.Type == q.Type
   137  }
   138  
   139  // IsError ...
   140  func (p *Param) IsError() bool {
   141  	return p.Type == "error"
   142  }
   143  
   144  // IsGenM ...
   145  func (p *Param) IsGenM() bool {
   146  	return p.Package == "gen" && p.Type == "M"
   147  }
   148  
   149  // IsGenRowsAffected ...
   150  func (p *Param) IsGenRowsAffected() bool {
   151  	return p.Package == "gen" && p.Type == "RowsAffected"
   152  }
   153  
   154  // IsMap ...
   155  func (p *Param) IsMap() bool {
   156  	return strings.HasPrefix(p.Type, "map[")
   157  }
   158  
   159  // IsGenT ...
   160  func (p *Param) IsGenT() bool {
   161  	return p.Package == "gen" && p.Type == "T"
   162  }
   163  
   164  // IsInterface ...
   165  func (p *Param) IsInterface() bool {
   166  	return p.Type == "interface{}"
   167  }
   168  
   169  // IsNull ...
   170  func (p *Param) IsNull() bool {
   171  	return p.Package == "" && p.Type == "" && p.Name == ""
   172  }
   173  
   174  // InMainPkg ...
   175  func (p *Param) InMainPkg() bool {
   176  	return p.Package == "main"
   177  }
   178  
   179  // IsTime ...
   180  func (p *Param) IsTime() bool {
   181  	return p.Package == "time" && p.Type == "Time"
   182  }
   183  
   184  // IsSQLResult ...
   185  func (p *Param) IsSQLResult() bool {
   186  	return (p.Package == "sql" && p.Type == "Result") || (p.Package == "gen" && p.Type == "SQLResult")
   187  }
   188  
   189  // IsSQLRow ...
   190  func (p *Param) IsSQLRow() bool {
   191  	return (p.Package == "sql" && p.Type == "Row") || (p.Package == "gen" && p.Type == "SQLRow")
   192  }
   193  
   194  // IsSQLRows ...
   195  func (p *Param) IsSQLRows() bool {
   196  	return (p.Package == "sql" && p.Type == "Rows") || (p.Package == "gen" && p.Type == "SQLRows")
   197  }
   198  
   199  // SetName ...
   200  func (p *Param) SetName(name string) {
   201  	p.Name = name
   202  }
   203  
   204  // TypeName ...
   205  func (p *Param) TypeName() string {
   206  	if p.IsArray {
   207  		return "[]" + p.Type
   208  	}
   209  	return p.Type
   210  }
   211  
   212  // TmplString param to string in tmpl
   213  func (p *Param) TmplString() string {
   214  	var res strings.Builder
   215  	if p.Name != "" {
   216  		res.WriteString(p.Name)
   217  		res.WriteString(" ")
   218  	}
   219  
   220  	if p.IsArray {
   221  		res.WriteString("[]")
   222  	}
   223  	if p.IsPointer {
   224  		res.WriteString("*")
   225  	}
   226  	if p.Package != "" {
   227  		res.WriteString(p.Package)
   228  		res.WriteString(".")
   229  	}
   230  	res.WriteString(p.Type)
   231  	return res.String()
   232  }
   233  
   234  // IsBaseType judge whether the param type is basic type
   235  func (p *Param) IsBaseType() bool {
   236  	switch p.Type {
   237  	case "string", "byte":
   238  		return true
   239  	case "int", "int8", "int16", "int32", "int64", "uint", "uint8", "uint16", "uint32", "uint64":
   240  		return true
   241  	case "float64", "float32":
   242  		return true
   243  	case "bool":
   244  		return true
   245  	case "time.Time":
   246  		return true
   247  	default:
   248  		return false
   249  	}
   250  }
   251  
   252  func (p *Param) astGetParamType(param *ast.Field) {
   253  	switch v := param.Type.(type) {
   254  	case *ast.Ident:
   255  		p.Type = v.Name
   256  		if v.Obj != nil {
   257  			p.Package = "UNDEFINED" // set a placeholder
   258  		}
   259  	case *ast.SelectorExpr:
   260  		p.astGetEltType(v)
   261  	case *ast.ArrayType:
   262  		p.astGetEltType(v.Elt)
   263  		p.IsArray = true
   264  	case *ast.Ellipsis:
   265  		p.astGetEltType(v.Elt)
   266  		p.IsArray = true
   267  	case *ast.MapType:
   268  		p.astGetMapType(v)
   269  	case *ast.InterfaceType:
   270  		p.Type = "interface{}"
   271  	case *ast.StarExpr:
   272  		p.IsPointer = true
   273  		p.astGetEltType(v.X)
   274  	default:
   275  		log.Fatalf("unknow param type: %+v", v)
   276  	}
   277  }
   278  
   279  func (p *Param) astGetEltType(expr ast.Expr) {
   280  	switch v := expr.(type) {
   281  	case *ast.Ident:
   282  		p.Type = v.Name
   283  		if v.Obj != nil {
   284  			p.Package = "UNDEFINED"
   285  		}
   286  	case *ast.SelectorExpr:
   287  		p.Type = v.Sel.Name
   288  		p.astGetPackageName(v.X)
   289  	case *ast.MapType:
   290  		p.astGetMapType(v)
   291  	case *ast.StarExpr:
   292  		p.IsPointer = true
   293  		p.astGetEltType(v.X)
   294  	case *ast.InterfaceType:
   295  		p.Type = "interface{}"
   296  	case *ast.ArrayType:
   297  		p.astGetEltType(v.Elt)
   298  		p.Type = "[]" + p.Type
   299  	default:
   300  		log.Fatalf("unknow param type: %+v", v)
   301  	}
   302  }
   303  
   304  func (p *Param) astGetPackageName(expr ast.Expr) {
   305  	switch v := expr.(type) {
   306  	case *ast.Ident:
   307  		p.Package = v.Name
   308  	}
   309  }
   310  
   311  func (p *Param) astGetMapType(expr *ast.MapType) {
   312  	p.Type = fmt.Sprintf("map[%s]%s", astGetType(expr.Key), astGetType(expr.Value))
   313  }
   314  
   315  func astGetType(expr ast.Expr) string {
   316  	switch v := expr.(type) {
   317  	case *ast.Ident:
   318  		return v.Name
   319  	case *ast.InterfaceType:
   320  		return "interface{}"
   321  	}
   322  	return ""
   323  }