github.com/mattermosttest/mattermost-server/v5@v5.0.0-20200917143240-9dfa12e121f9/store/layer_generators/main.go (about)

     1  // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
     2  // See LICENSE.txt for license information.
     3  
     4  package main
     5  
     6  import (
     7  	"bytes"
     8  	"fmt"
     9  	"go/ast"
    10  	"go/format"
    11  	"go/parser"
    12  	"go/token"
    13  	"io/ioutil"
    14  	"log"
    15  	"os"
    16  	"path"
    17  	"strings"
    18  	"text/template"
    19  )
    20  
    21  const (
    22  	OPEN_TRACING_PARAMS_MARKER = "@openTracingParams"
    23  	APP_ERROR_TYPE             = "*model.AppError"
    24  	ERROR_TYPE                 = "error"
    25  )
    26  
    27  func isError(typeName string) bool {
    28  	return strings.Contains(typeName, APP_ERROR_TYPE) || strings.Contains(typeName, ERROR_TYPE)
    29  }
    30  
    31  func main() {
    32  	if err := buildTimerLayer(); err != nil {
    33  		log.Fatal(err)
    34  	}
    35  	if err := buildOpenTracingLayer(); err != nil {
    36  		log.Fatal(err)
    37  	}
    38  }
    39  
    40  func buildTimerLayer() error {
    41  	code, err := generateLayer("TimerLayer", "timer_layer.go.tmpl")
    42  	if err != nil {
    43  		return err
    44  	}
    45  	formatedCode, err := format.Source(code)
    46  	if err != nil {
    47  		return err
    48  	}
    49  
    50  	return ioutil.WriteFile(path.Join("timerlayer", "timerlayer.go"), formatedCode, 0644)
    51  }
    52  
    53  func buildOpenTracingLayer() error {
    54  	code, err := generateLayer("OpenTracingLayer", "opentracing_layer.go.tmpl")
    55  	if err != nil {
    56  		return err
    57  	}
    58  	formatedCode, err := format.Source(code)
    59  	if err != nil {
    60  		return err
    61  	}
    62  
    63  	return ioutil.WriteFile(path.Join("opentracinglayer", "opentracinglayer.go"), formatedCode, 0644)
    64  }
    65  
    66  type methodParam struct {
    67  	Name string
    68  	Type string
    69  }
    70  
    71  type methodData struct {
    72  	Params        []methodParam
    73  	Results       []string
    74  	ParamsToTrace map[string]bool
    75  }
    76  
    77  type subStore struct {
    78  	Methods map[string]methodData
    79  }
    80  
    81  type storeMetadata struct {
    82  	Name      string
    83  	SubStores map[string]subStore
    84  	Methods   map[string]methodData
    85  }
    86  
    87  func extractMethodMetadata(method *ast.Field, src []byte) methodData {
    88  	params := []methodParam{}
    89  	results := []string{}
    90  	paramsToTrace := map[string]bool{}
    91  	ast.Inspect(method.Type, func(expr ast.Node) bool {
    92  		switch e := expr.(type) {
    93  		case *ast.FuncType:
    94  			if method.Doc != nil {
    95  				for _, comment := range method.Doc.List {
    96  					s := comment.Text
    97  					if idx := strings.Index(s, OPEN_TRACING_PARAMS_MARKER); idx != -1 {
    98  						for _, p := range strings.Split(s[idx+len(OPEN_TRACING_PARAMS_MARKER):], ",") {
    99  							paramsToTrace[strings.TrimSpace(p)] = true
   100  						}
   101  					}
   102  				}
   103  			}
   104  			if e.Params != nil {
   105  				for _, param := range e.Params.List {
   106  					for _, paramName := range param.Names {
   107  						params = append(params, methodParam{Name: paramName.Name, Type: string(src[param.Type.Pos()-1 : param.Type.End()-1])})
   108  					}
   109  				}
   110  			}
   111  			if e.Results != nil {
   112  				for _, result := range e.Results.List {
   113  					results = append(results, string(src[result.Type.Pos()-1:result.Type.End()-1]))
   114  				}
   115  			}
   116  
   117  			for paramName := range paramsToTrace {
   118  				found := false
   119  				for _, param := range params {
   120  					if param.Name == paramName {
   121  						found = true
   122  						break
   123  					}
   124  				}
   125  				if !found {
   126  					log.Fatalf("Unable to find a parameter called '%s' (method '%s') that is mentioned in the '%s' comment. Maybe it was renamed?", paramName, method.Names[0].Name, OPEN_TRACING_PARAMS_MARKER)
   127  				}
   128  			}
   129  		}
   130  		return true
   131  	})
   132  	return methodData{Params: params, Results: results, ParamsToTrace: paramsToTrace}
   133  }
   134  
   135  func extractStoreMetadata() (*storeMetadata, error) {
   136  	// Create the AST by parsing src.
   137  	fset := token.NewFileSet() // positions are relative to fset
   138  
   139  	file, err := os.Open("store.go")
   140  	if err != nil {
   141  		return nil, fmt.Errorf("Unable to open store/store.go file: %w", err)
   142  	}
   143  	src, err := ioutil.ReadAll(file)
   144  	if err != nil {
   145  		return nil, err
   146  	}
   147  	file.Close()
   148  	f, err := parser.ParseFile(fset, "", src, parser.AllErrors|parser.ParseComments)
   149  	if err != nil {
   150  		return nil, err
   151  	}
   152  
   153  	topLevelFunctions := map[string]bool{
   154  		"MarkSystemRanUnitTests":   false,
   155  		"Close":                    false,
   156  		"LockToMaster":             false,
   157  		"UnlockFromMaster":         false,
   158  		"DropAllTables":            false,
   159  		"TotalMasterDbConnections": true,
   160  		"TotalReadDbConnections":   true,
   161  		"SetContext":               true,
   162  		"TotalSearchDbConnections": true,
   163  		"GetCurrentSchemaVersion":  true,
   164  	}
   165  
   166  	metadata := storeMetadata{Methods: map[string]methodData{}, SubStores: map[string]subStore{}}
   167  
   168  	ast.Inspect(f, func(n ast.Node) bool {
   169  		switch x := n.(type) {
   170  		case *ast.TypeSpec:
   171  			if x.Name.Name == "Store" {
   172  				for _, method := range x.Type.(*ast.InterfaceType).Methods.List {
   173  					methodName := method.Names[0].Name
   174  					if _, ok := topLevelFunctions[methodName]; ok {
   175  						metadata.Methods[methodName] = extractMethodMetadata(method, src)
   176  					}
   177  				}
   178  			} else if strings.HasSuffix(x.Name.Name, "Store") {
   179  				subStoreName := strings.TrimSuffix(x.Name.Name, "Store")
   180  				metadata.SubStores[subStoreName] = subStore{Methods: map[string]methodData{}}
   181  				for _, method := range x.Type.(*ast.InterfaceType).Methods.List {
   182  					methodName := method.Names[0].Name
   183  					metadata.SubStores[subStoreName].Methods[methodName] = extractMethodMetadata(method, src)
   184  				}
   185  			}
   186  		}
   187  		return true
   188  	})
   189  
   190  	return &metadata, nil
   191  }
   192  
   193  func generateLayer(name, templateFile string) ([]byte, error) {
   194  	out := bytes.NewBufferString("")
   195  	metadata, err := extractStoreMetadata()
   196  	if err != nil {
   197  		return nil, err
   198  	}
   199  	metadata.Name = name
   200  
   201  	myFuncs := template.FuncMap{
   202  		"joinResults": func(results []string) string {
   203  			return strings.Join(results, ", ")
   204  		},
   205  		"joinResultsForSignature": func(results []string) string {
   206  			if len(results) == 0 {
   207  				return ""
   208  			}
   209  			if len(results) == 1 {
   210  				return strings.Join(results, ", ")
   211  			}
   212  			return fmt.Sprintf("(%s)", strings.Join(results, ", "))
   213  		},
   214  		"genResultsVars": func(results []string) string {
   215  			vars := []string{}
   216  			for i := range results {
   217  				vars = append(vars, fmt.Sprintf("resultVar%d", i))
   218  			}
   219  			return strings.Join(vars, ", ")
   220  		},
   221  		"errorToBoolean": func(results []string) string {
   222  			for i, typeName := range results {
   223  				if isError(typeName) {
   224  					return fmt.Sprintf("resultVar%d == nil", i)
   225  				}
   226  			}
   227  			return "true"
   228  		},
   229  		"errorPresent": func(results []string) bool {
   230  			for _, typeName := range results {
   231  				if isError(typeName) {
   232  					return true
   233  				}
   234  			}
   235  			return false
   236  		},
   237  		"errorVar": func(results []string) string {
   238  			for i, typeName := range results {
   239  				if isError(typeName) {
   240  					return fmt.Sprintf("resultVar%d", i)
   241  				}
   242  			}
   243  			return ""
   244  		},
   245  		"joinParams": func(params []methodParam) string {
   246  			paramsNames := make([]string, 0, len(params))
   247  			for _, param := range params {
   248  				paramsNames = append(paramsNames, param.Name)
   249  			}
   250  			return strings.Join(paramsNames, ", ")
   251  		},
   252  		"joinParamsWithType": func(params []methodParam) string {
   253  			paramsWithType := []string{}
   254  			for _, param := range params {
   255  				if param.Type == "ChannelSearchOpts" || param.Type == "UserGetByIdsOpts" {
   256  					paramsWithType = append(paramsWithType, fmt.Sprintf("%s store.%s", param.Name, param.Type))
   257  				} else if param.Type == "*UserGetByIdsOpts" {
   258  					paramsWithType = append(paramsWithType, fmt.Sprintf("%s *store.UserGetByIdsOpts", param.Name))
   259  				} else {
   260  					paramsWithType = append(paramsWithType, fmt.Sprintf("%s %s", param.Name, param.Type))
   261  				}
   262  			}
   263  			return strings.Join(paramsWithType, ", ")
   264  		},
   265  	}
   266  
   267  	t := template.Must(template.New(templateFile).Funcs(myFuncs).ParseFiles("layer_generators/" + templateFile))
   268  	if err = t.Execute(out, metadata); err != nil {
   269  		return nil, err
   270  	}
   271  	return out.Bytes(), nil
   272  }