github.com/masterhung0112/hk_server/v5@v5.0.0-20220302090640-ec71aef15e1c/app/layer_generators/main.go (about)

     1  // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
     2  // See LICENSE.txt for license information.
     3  
     4  package main
     5  
     6  import (
     7  	"bytes"
     8  	"flag"
     9  	"fmt"
    10  	"go/ast"
    11  	"go/parser"
    12  	"go/token"
    13  	"io/ioutil"
    14  	"log"
    15  	"os"
    16  	"path"
    17  	"regexp"
    18  	"strings"
    19  	"text/template"
    20  
    21  	"golang.org/x/tools/imports"
    22  )
    23  
var (
	// reserved lists AppIface method names that extractStoreMetadata skips:
	// no methodData is collected for a method whose name appears here.
	reserved           = []string{"AcceptLanguage", "AccountMigration", "Cluster", "Compliance", "Context", "DataRetention", "Elasticsearch", "HTTPService", "ImageProxy", "IpAddress", "Ldap", "Log", "MessageExport", "Metrics", "Notification", "NotificationsLog", "Path", "RequestId", "Saml", "Session", "SetIpAddress", "SetRequestId", "SetSession", "SetStore", "SetT", "Srv", "Store", "T", "Timezones", "UserAgent", "SetUserAgent", "SetAcceptLanguage", "SetPath", "SetContext", "SetServer", "GetT"}
	// outputFile is the path the generated code is written to (-out flag).
	outputFile         string
	// inputFile is the Go file that is read and parsed for the AppIface
	// interface (-in flag).
	inputFile          string
	// outputFileTemplate is the template file handed to generateLayer
	// (-template flag).
	outputFileTemplate string
	// basicTypes holds the type names that fixTypeName leaves without an
	// "app." qualifier.
	// NOTE(review): "float" is not a predeclared Go type, while float32 and
	// float64 are not listed — confirm whether that is intended.
	basicTypes         = map[string]bool{"int": true, "uint": true, "string": true, "float": true, "bool": true, "byte": true, "int64": true, "uint64": true, "error": true}
	// textRegexp matches the run of word characters at the very end of a
	// string; fixTypeName uses it to pick the trailing identifier of a type.
	textRegexp         = regexp.MustCompile(`\w+$`)
)
    32  
const (
	// OpenTracingParamsMarker is the tag searched for in a method's doc
	// comment; the comma-separated text that follows it on the same comment
	// line names the parameters to trace.
	OpenTracingParamsMarker = "@openTracingParams"
	// AppErrorType is the type text isError looks for to recognise an
	// application error result.
	AppErrorType            = "*model.AppError"
	// ErrorType is the type text isError looks for to recognise a plain
	// error result.
	ErrorType               = "error"
)
    38  
    39  func isError(typeName string) bool {
    40  	return strings.Contains(typeName, AppErrorType) || strings.Contains(typeName, ErrorType)
    41  }
    42  
    43  func init() {
    44  	flag.StringVar(&inputFile, "in", path.Join("..", "app_iface.go"), "App interface file")
    45  	flag.StringVar(&outputFile, "out", path.Join("..", "opentracing_layer.go"), "Output file")
    46  	flag.StringVar(&outputFileTemplate, "template", "opentracing_layer.go.tmpl", "Output template file")
    47  }
    48  
    49  func main() {
    50  	flag.Parse()
    51  
    52  	code, err := generateLayer("OpenTracingAppLayer", outputFileTemplate)
    53  	if err != nil {
    54  		log.Fatal(err)
    55  	}
    56  	formattedCode, err := imports.Process(outputFile, code, &imports.Options{Comments: true})
    57  	if err != nil {
    58  		log.Fatal(err)
    59  	}
    60  
    61  	err = ioutil.WriteFile(outputFile, formattedCode, 0644)
    62  	if err != nil {
    63  		log.Fatal(err)
    64  	}
    65  }
    66  
// methodParam describes one named parameter of an interface method.
type methodParam struct {
	// Name is the parameter identifier as written in the interface.
	Name string
	// Type is the parameter's type text taken from the source and passed
	// through fixTypeName.
	Type string
}
    71  
// methodData is what extractMethodMetadata collects for a single interface
// method.
type methodData struct {
	// ParamsToTrace is the set of names listed after OpenTracingParamsMarker
	// in the method's doc comment; every stored value is true.
	ParamsToTrace map[string]bool
	// Params holds the named parameters in declaration order.
	Params        []methodParam
	// Results holds one entry per result: the type text, prefixed with
	// "name " when the result is named.
	Results       []string
}
    77  
// storeMetadata is the value handed to the template when it is executed.
type storeMetadata struct {
	// Name is the layer name; generateLayer sets it from its name argument.
	Name    string
	// Methods maps each collected AppIface method name to its metadata.
	Methods map[string]methodData
}
    82  
    83  func fixTypeName(t string) string {
    84  	// don't want to dive into AST to parse this, add exception
    85  	if t == "...func(*UploadFileTask)" {
    86  		t = "...func(*app.UploadFileTask)"
    87  	}
    88  	if strings.Contains(t, ".") || strings.Contains(t, "{}") {
    89  		return t
    90  	}
    91  	typeOnly := textRegexp.FindString(t)
    92  
    93  	if _, basicType := basicTypes[typeOnly]; !basicType {
    94  		t = t[:len(t)-len(typeOnly)] + "app." + typeOnly
    95  	}
    96  	return t
    97  }
    98  
    99  func formatNode(src []byte, node ast.Expr) string {
   100  	return string(src[node.Pos()-1 : node.End()-1])
   101  }
   102  
   103  func extractMethodMetadata(method *ast.Field, src []byte) methodData {
   104  	params := []methodParam{}
   105  	paramsToTrace := map[string]bool{}
   106  	results := []string{}
   107  	e := method.Type.(*ast.FuncType)
   108  	if method.Doc != nil {
   109  		for _, comment := range method.Doc.List {
   110  			s := comment.Text
   111  			if idx := strings.Index(s, OpenTracingParamsMarker); idx != -1 {
   112  				for _, p := range strings.Split(s[idx+len(OpenTracingParamsMarker):], ",") {
   113  					paramsToTrace[strings.TrimSpace(p)] = true
   114  				}
   115  			}
   116  
   117  		}
   118  	}
   119  	if e.Params != nil {
   120  		for _, param := range e.Params.List {
   121  			for _, paramName := range param.Names {
   122  				paramType := fixTypeName(formatNode(src, param.Type))
   123  				params = append(params, methodParam{Name: paramName.Name, Type: paramType})
   124  			}
   125  		}
   126  	}
   127  
   128  	if e.Results != nil {
   129  		for _, r := range e.Results.List {
   130  			typeStr := fixTypeName(formatNode(src, r.Type))
   131  
   132  			if len(r.Names) > 0 {
   133  				for _, k := range r.Names {
   134  					results = append(results, fmt.Sprintf("%s %s", k.Name, typeStr))
   135  				}
   136  			} else {
   137  				results = append(results, typeStr)
   138  			}
   139  		}
   140  	}
   141  
   142  	for paramName := range paramsToTrace {
   143  		found := false
   144  		for _, param := range params {
   145  			if param.Name == paramName || strings.HasPrefix(paramName, param.Name+".") {
   146  				found = true
   147  				break
   148  			}
   149  		}
   150  		if !found {
   151  			log.Fatalf("Unable to find a parameter called '%s' (method '%s') that is mentioned in the '%s' comment. Maybe it was renamed?", paramName, method.Names[0].Name, OpenTracingParamsMarker)
   152  		}
   153  	}
   154  	return methodData{Params: params, Results: results, ParamsToTrace: paramsToTrace}
   155  }
   156  
   157  func extractStoreMetadata() (*storeMetadata, error) {
   158  	// Create the AST by parsing src.
   159  	fset := token.NewFileSet() // positions are relative to fset
   160  
   161  	file, err := os.Open(inputFile)
   162  	if err != nil {
   163  		return nil, fmt.Errorf("unable to open %s file: %w", inputFile, err)
   164  	}
   165  	src, err := ioutil.ReadAll(file)
   166  	if err != nil {
   167  		return nil, err
   168  	}
   169  	defer file.Close()
   170  	f, err := parser.ParseFile(fset, "../app_iface.go", src, parser.AllErrors|parser.ParseComments)
   171  	if err != nil {
   172  		return nil, err
   173  	}
   174  
   175  	metadata := storeMetadata{Methods: map[string]methodData{}}
   176  
   177  	ast.Inspect(f, func(n ast.Node) bool {
   178  		switch x := n.(type) {
   179  		case *ast.TypeSpec:
   180  			if x.Name.Name == "AppIface" {
   181  				for _, method := range x.Type.(*ast.InterfaceType).Methods.List {
   182  					methodName := method.Names[0].Name
   183  					found := false
   184  					for _, reservedMethod := range reserved {
   185  						if methodName == reservedMethod {
   186  							found = true
   187  							break
   188  						}
   189  					}
   190  					if found {
   191  						continue
   192  					}
   193  					metadata.Methods[methodName] = extractMethodMetadata(method, src)
   194  				}
   195  			}
   196  		}
   197  
   198  		return true
   199  	})
   200  	return &metadata, err
   201  }
   202  
   203  func generateLayer(name, templateFile string) ([]byte, error) {
   204  	out := bytes.NewBufferString("")
   205  	metadata, err := extractStoreMetadata()
   206  	if err != nil {
   207  		return nil, err
   208  	}
   209  	metadata.Name = name
   210  
   211  	myFuncs := template.FuncMap{
   212  		"joinResults": func(results []string) string {
   213  			return strings.Join(results, ", ")
   214  		},
   215  		"joinResultsForSignature": func(results []string) string {
   216  			return fmt.Sprintf("(%s)", strings.Join(results, ", "))
   217  		},
   218  		"genResultsVars": func(results []string) string {
   219  			vars := make([]string, 0, len(results))
   220  			for i := range results {
   221  				vars = append(vars, fmt.Sprintf("resultVar%d", i))
   222  			}
   223  			return strings.Join(vars, ", ")
   224  		},
   225  		"errorToBoolean": func(results []string) string {
   226  			for i, typeName := range results {
   227  				if isError(typeName) {
   228  					return fmt.Sprintf("resultVar%d == nil", i)
   229  				}
   230  			}
   231  			return "true"
   232  		},
   233  		"errorPresent": func(results []string) bool {
   234  			for _, typeName := range results {
   235  				if isError(typeName) {
   236  					return true
   237  				}
   238  			}
   239  			return false
   240  		},
   241  		"errorVar": func(results []string) string {
   242  			for i, typeName := range results {
   243  				if isError(typeName) {
   244  					return fmt.Sprintf("resultVar%d", i)
   245  				}
   246  			}
   247  			return ""
   248  		},
   249  		"shouldTrace": func(params map[string]bool, param string) string {
   250  			if _, ok := params[param]; ok {
   251  				return fmt.Sprintf(`span.SetTag("%s", %s)`, param, param)
   252  			}
   253  			for pName := range params {
   254  				if strings.HasPrefix(pName, param+".") {
   255  					return fmt.Sprintf(`span.SetTag("%s", %s)`, pName, pName)
   256  				}
   257  			}
   258  			return ""
   259  		},
   260  		"joinParams": func(params []methodParam) string {
   261  			paramsNames := []string{}
   262  			for _, param := range params {
   263  				s := param.Name
   264  				if strings.HasPrefix(param.Type, "...") {
   265  					s += "..."
   266  				}
   267  				paramsNames = append(paramsNames, s)
   268  			}
   269  			return strings.Join(paramsNames, ", ")
   270  		},
   271  		"joinParamsWithType": func(params []methodParam) string {
   272  			paramsWithType := []string{}
   273  			for _, param := range params {
   274  				paramsWithType = append(paramsWithType, fmt.Sprintf("%s %s", param.Name, param.Type))
   275  			}
   276  			return strings.Join(paramsWithType, ", ")
   277  		},
   278  	}
   279  
   280  	t := template.Must(template.New("opentracing_layer.go.tmpl").Funcs(myFuncs).ParseFiles(templateFile))
   281  	err = t.Execute(out, metadata)
   282  	if err != nil {
   283  		return nil, err
   284  	}
   285  	return out.Bytes(), nil
   286  }