github.com/unionj-cloud/go-doudou/v2@v2.3.5/toolkit/astutils/funcs.go (about)

     1  package astutils
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"encoding/json"
     7  	"fmt"
     8  	"github.com/pkg/errors"
     9  	"github.com/samber/lo"
    10  	"github.com/sirupsen/logrus"
    11  	"github.com/unionj-cloud/go-doudou/v2/toolkit/constants"
    12  	"github.com/unionj-cloud/go-doudou/v2/toolkit/stringutils"
    13  	"go/ast"
    14  	"go/format"
    15  	"golang.org/x/tools/imports"
    16  	"io/ioutil"
    17  	"os"
    18  	"path/filepath"
    19  	"reflect"
    20  	"regexp"
    21  	"strconv"
    22  	"strings"
    23  	"text/template"
    24  	"unicode"
    25  )
    26  
    27  func GetImportStatements(input []byte) []byte {
    28  	reg := regexp.MustCompile("(?s)import \\((.*?)\\)")
    29  	if !reg.Match(input) {
    30  		return nil
    31  	}
    32  	matches := reg.FindSubmatch(input)
    33  	return matches[1]
    34  }
    35  
    36  func AppendImportStatements(src []byte, appendImports []byte) []byte {
    37  	reg := regexp.MustCompile("(?s)import \\((.*?)\\)")
    38  	if !reg.Match(src) {
    39  		return src
    40  	}
    41  	matches := reg.FindSubmatch(src)
    42  	old := matches[1]
    43  	re := regexp.MustCompile(`[\r\n]+`)
    44  	splits := re.Split(string(old), -1)
    45  	oldmap := make(map[string]struct{})
    46  	for _, item := range splits {
    47  		oldmap[strings.TrimSpace(item)] = struct{}{}
    48  	}
    49  	splits = re.Split(string(appendImports), -1)
    50  	var newimps []string
    51  	for _, item := range splits {
    52  		key := strings.TrimSpace(item)
    53  		if _, ok := oldmap[key]; !ok {
    54  			newimps = append(newimps, "\t"+key)
    55  		}
    56  	}
    57  	if len(newimps) == 0 {
    58  		return src
    59  	}
    60  	appendImports = []byte(constants.LineBreak + strings.Join(newimps, constants.LineBreak) + constants.LineBreak)
    61  	return reg.ReplaceAllFunc(src, func(i []byte) []byte {
    62  		old = append([]byte("import ("), old...)
    63  		old = append(old, appendImports...)
    64  		old = append(old, []byte(")")...)
    65  		return old
    66  	})
    67  }
    68  
    69  func GrpcRelatedModify(src []byte, metaName string, grpcSvcName string) []byte {
    70  	expr := fmt.Sprintf(`type %sImpl struct {`, metaName)
    71  	reg := regexp.MustCompile(expr)
    72  	unimpl := fmt.Sprintf("pb.Unimplemented%sServer", grpcSvcName)
    73  	if !strings.Contains(string(src), unimpl) {
    74  		appendUnimpl := []byte(constants.LineBreak + unimpl + constants.LineBreak)
    75  		src = reg.ReplaceAllFunc(src, func(i []byte) []byte {
    76  			return append([]byte(expr), appendUnimpl...)
    77  		})
    78  	}
    79  	var_pb := fmt.Sprintf("var _ pb.%sServer = (*%sImpl)(nil)", grpcSvcName, metaName)
    80  	if !strings.Contains(string(src), var_pb) {
    81  		appendVarPb := []byte(constants.LineBreak + var_pb + constants.LineBreak)
    82  		src = reg.ReplaceAllFunc(src, func(i []byte) []byte {
    83  			return append(appendVarPb, []byte(expr)...)
    84  		})
    85  	}
    86  	return src
    87  }
    88  
    89  func RestRelatedModify(src []byte, metaName string) []byte {
    90  	expr := fmt.Sprintf(`type %sImpl struct {`, metaName)
    91  	reg := regexp.MustCompile(expr)
    92  	var_ := fmt.Sprintf("var _ %s = (*%sImpl)(nil)", metaName, metaName)
    93  	if !strings.Contains(string(src), var_) {
    94  		appendVarPb := []byte(constants.LineBreak + var_ + constants.LineBreak)
    95  		src = reg.ReplaceAllFunc(src, func(i []byte) []byte {
    96  			return append(appendVarPb, []byte(expr)...)
    97  		})
    98  	}
    99  	return src
   100  }
   101  
   102  // FixImport format source code and add missing import syntax automatically
   103  func FixImport(src []byte, file string) {
   104  	var (
   105  		res []byte
   106  		err error
   107  	)
   108  	if res, err = imports.Process(file, src, &imports.Options{
   109  		TabWidth:  8,
   110  		TabIndent: true,
   111  		Comments:  true,
   112  		Fragment:  true,
   113  	}); err != nil {
   114  		lines := strings.Split(string(src), "\n")
   115  		errLine, _ := strconv.Atoi(strings.Split(err.Error(), ":")[1])
   116  		startLine, endLine := errLine-5, errLine+5
   117  		fmt.Println("Format fail:", errLine, err)
   118  		if startLine < 0 {
   119  			startLine = 0
   120  		}
   121  		if endLine > len(lines)-1 {
   122  			endLine = len(lines) - 1
   123  		}
   124  		for i := startLine; i <= endLine; i++ {
   125  			fmt.Println(i, lines[i])
   126  		}
   127  		errors.WithStack(fmt.Errorf("cannot format file: %w", err))
   128  	} else {
   129  		_ = ioutil.WriteFile(file, res, os.ModePerm)
   130  		return
   131  	}
   132  	_ = ioutil.WriteFile(file, src, os.ModePerm)
   133  }
   134  
   135  // GetMethodMeta get method name then new MethodMeta struct from *ast.FuncDecl
   136  func GetMethodMeta(spec *ast.FuncDecl) MethodMeta {
   137  	methodName := ExprString(spec.Name)
   138  	mm := NewMethodMeta(spec.Type, ExprString)
   139  	mm.Name = methodName
   140  	return mm
   141  }
   142  
   143  // NewMethodMeta new MethodMeta struct from *ast.FuncDecl
   144  func NewMethodMeta(ft *ast.FuncType, exprString func(ast.Expr) string) MethodMeta {
   145  	var params, results []FieldMeta
   146  	for _, param := range ft.Params.List {
   147  		pt := exprString(param.Type)
   148  		if len(param.Names) > 0 {
   149  			for _, name := range param.Names {
   150  				params = append(params, FieldMeta{
   151  					Name: name.Name,
   152  					Type: pt,
   153  					Tag:  "",
   154  				})
   155  			}
   156  			continue
   157  		}
   158  		params = append(params, FieldMeta{
   159  			Name: "",
   160  			Type: pt,
   161  			Tag:  "",
   162  		})
   163  	}
   164  	if ft.Results != nil {
   165  		for _, result := range ft.Results.List {
   166  			rt := exprString(result.Type)
   167  			if len(result.Names) > 0 {
   168  				for _, name := range result.Names {
   169  					results = append(results, FieldMeta{
   170  						Name: name.Name,
   171  						Type: rt,
   172  						Tag:  "",
   173  					})
   174  				}
   175  				continue
   176  			}
   177  			results = append(results, FieldMeta{
   178  				Name: "",
   179  				Type: rt,
   180  				Tag:  "",
   181  			})
   182  		}
   183  	}
   184  	return MethodMeta{
   185  		Params:  params,
   186  		Results: results,
   187  	}
   188  }
   189  
   190  // NewStructMeta new StructMeta from *ast.StructType
   191  func NewStructMeta(structType *ast.StructType, exprString func(ast.Expr) string) StructMeta {
   192  	var fields []FieldMeta
   193  	re := regexp.MustCompile(`json:"(.*?)"`)
   194  	for _, field := range structType.Fields.List {
   195  		var fieldComments []string
   196  		if field.Doc != nil {
   197  			for _, comment := range field.Doc.List {
   198  				fieldComments = append(fieldComments, strings.TrimSpace(strings.TrimPrefix(comment.Text, "//")))
   199  			}
   200  		}
   201  
   202  		fieldType := exprString(field.Type)
   203  
   204  		var tag string
   205  		var docName string
   206  		if field.Tag != nil {
   207  			tag = strings.Trim(field.Tag.Value, "`")
   208  			if re.MatchString(tag) {
   209  				docName = strings.TrimSuffix(re.FindStringSubmatch(tag)[1], ",omitempty")
   210  			}
   211  		}
   212  
   213  		if len(field.Names) > 0 {
   214  			for _, name := range field.Names {
   215  				_docName := docName
   216  				if stringutils.IsEmpty(_docName) {
   217  					_docName = name.Name
   218  				}
   219  				fields = append(fields, FieldMeta{
   220  					Name:     name.Name,
   221  					Type:     fieldType,
   222  					Tag:      tag,
   223  					Comments: fieldComments,
   224  					IsExport: unicode.IsUpper(rune(name.Name[0])),
   225  					DocName:  _docName,
   226  				})
   227  			}
   228  		} else {
   229  			splits := strings.Split(fieldType, ".")
   230  			name := splits[len(splits)-1]
   231  			fieldType = "embed:" + fieldType
   232  			_docName := docName
   233  			if stringutils.IsEmpty(_docName) {
   234  				_docName = name
   235  			}
   236  			fields = append(fields, FieldMeta{
   237  				Name:     name,
   238  				Type:     fieldType,
   239  				Tag:      tag,
   240  				Comments: fieldComments,
   241  				IsExport: unicode.IsUpper(rune(name[0])),
   242  				DocName:  _docName,
   243  			})
   244  		}
   245  	}
   246  	return StructMeta{
   247  		Fields: fields,
   248  	}
   249  }
   250  
// PackageMeta wraps package info. Currently it holds only the package Name.
type PackageMeta struct {
	Name string
}
   255  
   256  // FieldMeta wraps field info
   257  type FieldMeta struct {
   258  	Name     string
   259  	Type     string
   260  	Tag      string
   261  	Comments []string
   262  	IsExport bool
   263  	// used in OpenAPI 3.0 spec as property name
   264  	DocName string
   265  	// Annotations of the field
   266  	Annotations []Annotation
   267  	// ValidateTag based on https://github.com/go-playground/validator
   268  	// please refer to its documentation https://pkg.go.dev/github.com/go-playground/validator/v10
   269  	ValidateTag    string
   270  	IsPathVariable bool
   271  }
   272  
// StructMeta wraps struct info.
//
// NewStructMeta fills in only Fields. The other fields are presumably
// populated by the code that collects declarations; that code is not visible
// here. Field order matters: ExprString serializes StructMeta to JSON for
// anonymous structs.
type StructMeta struct {
	Name     string
	Fields   []FieldMeta
	Comments []string
	Methods  []MethodMeta
	IsExport bool
	// go-doudou version
	Version string
}
   283  
// EnumMeta wraps enum info: the enum type's Name and its Values.
// NOTE(review): Values presumably lists the enum's constant names; confirm
// against the code that populates it.
type EnumMeta struct {
	Name   string
	Values []string
}
   289  
   290  // ExprString return string representation from ast.Expr
   291  func ExprString(expr ast.Expr) string {
   292  	switch _expr := expr.(type) {
   293  	case *ast.Ident:
   294  		return _expr.Name
   295  	case *ast.StarExpr:
   296  		return "*" + ExprString(_expr.X)
   297  	case *ast.SelectorExpr:
   298  		return ExprString(_expr.X) + "." + _expr.Sel.Name
   299  	case *ast.InterfaceType:
   300  		return "interface{}"
   301  	case *ast.ArrayType:
   302  		if _expr.Len == nil {
   303  			return "[]" + ExprString(_expr.Elt)
   304  		}
   305  		return "[" + ExprString(_expr.Len) + "]" + ExprString(_expr.Elt)
   306  	case *ast.BasicLit:
   307  		return _expr.Value
   308  	case *ast.MapType:
   309  		return "map[" + ExprString(_expr.Key) + "]" + ExprString(_expr.Value)
   310  	case *ast.StructType:
   311  		structmeta := NewStructMeta(_expr, ExprString)
   312  		b, _ := json.Marshal(structmeta)
   313  		return "anonystruct«" + string(b) + "»"
   314  	case *ast.FuncType:
   315  		return NewMethodMeta(_expr, ExprString).String()
   316  	case *ast.ChanType:
   317  		var result string
   318  		if _expr.Dir == ast.SEND {
   319  			result += "chan<- "
   320  		} else if _expr.Dir == ast.RECV {
   321  			result += "<-chan "
   322  		} else {
   323  			result += "chan "
   324  		}
   325  		return result + ExprString(_expr.Value)
   326  	case *ast.Ellipsis:
   327  		if _expr.Ellipsis.IsValid() {
   328  			return "..." + ExprString(_expr.Elt)
   329  		}
   330  		panic(fmt.Sprintf("invalid ellipsis expression: %+v\n", expr))
   331  	case *ast.IndexExpr:
   332  		return ExprString(_expr.X) + "[" + ExprString(_expr.Index) + "]"
   333  	case *ast.IndexListExpr:
   334  		typeParams := lo.Map[ast.Expr, string](_expr.Indices, func(item ast.Expr, index int) string {
   335  			return ExprString(item)
   336  		})
   337  		return ExprString(_expr.X) + "[" + strings.Join(typeParams, ", ") + "]"
   338  	default:
   339  		logrus.Infof("not support expression: %+v\n", expr)
   340  		logrus.Infof("not support expression: %+v\n", reflect.TypeOf(expr))
   341  		logrus.Infof("not support expression: %#v\n", reflect.TypeOf(expr))
   342  		logrus.Infof("not support expression: %v\n", reflect.TypeOf(expr).String())
   343  		return ""
   344  		//panic(fmt.Sprintf("not support expression: %+v\n", expr))
   345  	}
   346  }
   347  
   348  type Annotation struct {
   349  	Name   string
   350  	Params []string
   351  }
   352  
   353  var reAnno = regexp.MustCompile(`@(\S+?)\((.*?)\)`)
   354  
   355  func GetAnnotations(text string) []Annotation {
   356  	if !reAnno.MatchString(text) {
   357  		return nil
   358  	}
   359  	var annotations []Annotation
   360  	matches := reAnno.FindAllStringSubmatch(text, -1)
   361  	for _, item := range matches {
   362  		name := fmt.Sprintf(`@%s`, item[1])
   363  		var params []string
   364  		if stringutils.IsNotEmpty(item[2]) {
   365  			params = strings.Split(strings.TrimSpace(item[2]), ",")
   366  		}
   367  		annotations = append(annotations, Annotation{
   368  			Name:   name,
   369  			Params: params,
   370  		})
   371  	}
   372  	return annotations
   373  }
   374  
// MethodMeta represents an API method: its Go signature (receiver, name,
// parameters, results) plus the HTTP-related metadata used when generating
// client code.
type MethodMeta struct {
	// Recv is the method receiver.
	Recv string
	// Name is the method name.
	Name string
	// Params holds the method's input parameters.
	// When client code is generated from an OpenAPI 3 spec JSON file, Params holds all method input parameters.
	// When client code is generated from the service interface in svc.go, a struct-typed param is put into the
	// request body and the others are sent in the URL query string. If there is no struct-typed param and the
	// API is a GET request, all params go into the query string; if it is not a GET request, all go into the
	// request body as application/x-www-form-urlencoded data.
	// In particular, if there are one or more v3.FileModel or []v3.FileModel params, all params are put into
	// the request body as multipart/form-data.
	Params []FieldMeta
	// Results holds the method's results.
	Results []FieldMeta
	// PathVars is not supported when generating client code from the service interface in svc.go.
	// When generating client code from an OpenAPI 3 spec JSON file, PathVars are the URL path variables.
	PathVars []FieldMeta
	// HeaderVars is not supported when generating client code from the service interface in svc.go.
	// When generating client code from an OpenAPI 3 spec JSON file, HeaderVars are the header parameters.
	HeaderVars []FieldMeta
	// BodyParams is not supported when generating client code from the service interface in svc.go.
	// When generating client code from an OpenAPI 3 spec JSON file, BodyParams are the parameters sent in
	// the request body encoded as a query string.
	BodyParams *FieldMeta
	// BodyJSON is not supported when generating client code from the service interface in svc.go.
	// When generating client code from an OpenAPI 3 spec JSON file, BodyJSON is the parameter sent in the
	// request body as JSON.
	BodyJSON *FieldMeta
	// Files is not supported when generating client code from the service interface in svc.go.
	// When generating client code from an OpenAPI 3 spec JSON file, Files are the parameters sent in the
	// request body as multipart files.
	Files []FieldMeta
	// Comments of the method.
	Comments []string
	// Path is the API path.
	// It is not supported when generating client code from the service interface in svc.go.
	Path string
	// QueryParams is not supported when generating client code from the service interface in svc.go.
	// When generating client code from an OpenAPI 3 spec JSON file, QueryParams are the parameters sent in
	// the URL query string.
	QueryParams *FieldMeta
	// Annotations of the method.
	Annotations []Annotation
	// HasPathVariable presumably reports whether any parameter is a path
	// variable; confirm against the code that sets it.
	HasPathVariable bool
	// HttpMethod only accepts GET, PUT, POST and DELETE.
	HttpMethod string
}
   419  
   420  const methodTmpl = `func {{ if .Recv }}(receiver {{.Recv}}){{ end }} {{.Name}}({{- range $i, $p := .Params}}
   421      {{- if $i}},{{end}}
   422      {{- $p.Name}} {{$p.Type}}
   423      {{- end }}) ({{- range $i, $r := .Results}}
   424                       {{- if $i}},{{end}}
   425                       {{- $r.Name}} {{$r.Type}}
   426                       {{- end }})`
   427  
   428  func (mm MethodMeta) String() string {
   429  	if stringutils.IsNotEmpty(mm.Recv) && stringutils.IsEmpty(mm.Name) {
   430  		panic("not valid code")
   431  	}
   432  	var isAnony bool
   433  	if stringutils.IsEmpty(mm.Name) {
   434  		isAnony = true
   435  		mm.Name = "placeholder"
   436  	}
   437  	t, _ := template.New("method.tmpl").Parse(methodTmpl)
   438  	var buf bytes.Buffer
   439  	_ = t.Execute(&buf, mm)
   440  	var res []byte
   441  	res, _ = format.Source(buf.Bytes())
   442  	result := string(res)
   443  	if isAnony {
   444  		return strings.Replace(result, "func placeholder(", "func(", 1)
   445  	}
   446  	return result
   447  }
   448  
// InterfaceMeta wraps interface info: the interface's Name, its Methods and
// its Comments. NOTE(review): Comments is presumably the declaration's doc
// comment lines; confirm against the code that populates it.
type InterfaceMeta struct {
	Name     string
	Methods  []MethodMeta
	Comments []string
}
   455  
   456  // Visit visit each files
   457  func Visit(files *[]string) filepath.WalkFunc {
   458  	return func(path string, info os.FileInfo, err error) error {
   459  		if err != nil {
   460  			logrus.Panicln(err)
   461  		}
   462  		if !info.IsDir() {
   463  			*files = append(*files, path)
   464  		}
   465  		return nil
   466  	}
   467  }
   468  
   469  // GetMod get module name from go.mod file
   470  func GetMod() string {
   471  	var (
   472  		f         *os.File
   473  		err       error
   474  		firstLine string
   475  	)
   476  	dir, _ := os.Getwd()
   477  	mod := filepath.Join(dir, "go.mod")
   478  	if f, err = os.Open(mod); err != nil {
   479  		panic(err)
   480  	}
   481  	reader := bufio.NewReader(f)
   482  	firstLine, _ = reader.ReadString('\n')
   483  	return strings.TrimSpace(strings.TrimPrefix(firstLine, "module"))
   484  }
   485  
   486  // GetImportPath get import path of pkg from dir
   487  func GetImportPath(dir string) string {
   488  	wd, _ := os.Getwd()
   489  	return GetMod() + strings.ReplaceAll(strings.TrimPrefix(dir, wd), `\`, `/`)
   490  }