github.com/haalcala/mattermost-server-change-repo@v0.0.0-20210713015153-16753fbeee5f/store/layer_generators/main.go (about)

     1  // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
     2  // See LICENSE.txt for license information.
     3  
     4  package main
     5  
     6  import (
     7  	"bytes"
     8  	"fmt"
     9  	"go/ast"
    10  	"go/format"
    11  	"go/parser"
    12  	"go/token"
    13  	"io/ioutil"
    14  	"log"
    15  	"os"
    16  	"path"
    17  	"strings"
    18  	"text/template"
    19  )
    20  
    21  const (
    22  	OpenTracingParamsMarker = "@openTracingParams"
    23  	ErrorType               = "error"
    24  )
    25  
    26  func isError(typeName string) bool {
    27  	return strings.Contains(typeName, ErrorType)
    28  }
    29  
    30  func main() {
    31  	if err := buildTimerLayer(); err != nil {
    32  		log.Fatal(err)
    33  	}
    34  	if err := buildOpenTracingLayer(); err != nil {
    35  		log.Fatal(err)
    36  	}
    37  	if err := buildRetryLayer(); err != nil {
    38  		log.Fatal(err)
    39  	}
    40  }
    41  
    42  func buildRetryLayer() error {
    43  	code, err := generateLayer("RetryLayer", "retry_layer.go.tmpl")
    44  	if err != nil {
    45  		return err
    46  	}
    47  	formatedCode, err := format.Source(code)
    48  	if err != nil {
    49  		return err
    50  	}
    51  
    52  	return ioutil.WriteFile(path.Join("retrylayer/retrylayer.go"), formatedCode, 0644)
    53  }
    54  
    55  func buildTimerLayer() error {
    56  	code, err := generateLayer("TimerLayer", "timer_layer.go.tmpl")
    57  	if err != nil {
    58  		return err
    59  	}
    60  	formatedCode, err := format.Source(code)
    61  	if err != nil {
    62  		return err
    63  	}
    64  
    65  	return ioutil.WriteFile(path.Join("timerlayer", "timerlayer.go"), formatedCode, 0644)
    66  }
    67  
    68  func buildOpenTracingLayer() error {
    69  	code, err := generateLayer("OpenTracingLayer", "opentracing_layer.go.tmpl")
    70  	if err != nil {
    71  		return err
    72  	}
    73  	formatedCode, err := format.Source(code)
    74  	if err != nil {
    75  		return err
    76  	}
    77  
    78  	return ioutil.WriteFile(path.Join("opentracinglayer", "opentracinglayer.go"), formatedCode, 0644)
    79  }
    80  
    81  type methodParam struct {
    82  	Name string
    83  	Type string
    84  }
    85  
    86  type methodData struct {
    87  	Params        []methodParam
    88  	Results       []string
    89  	ParamsToTrace map[string]bool
    90  }
    91  
    92  type subStore struct {
    93  	Methods map[string]methodData
    94  }
    95  
    96  type storeMetadata struct {
    97  	Name      string
    98  	SubStores map[string]subStore
    99  	Methods   map[string]methodData
   100  }
   101  
   102  func extractMethodMetadata(method *ast.Field, src []byte) methodData {
   103  	params := []methodParam{}
   104  	results := []string{}
   105  	paramsToTrace := map[string]bool{}
   106  	ast.Inspect(method.Type, func(expr ast.Node) bool {
   107  		switch e := expr.(type) {
   108  		case *ast.FuncType:
   109  			if method.Doc != nil {
   110  				for _, comment := range method.Doc.List {
   111  					s := comment.Text
   112  					if idx := strings.Index(s, OpenTracingParamsMarker); idx != -1 {
   113  						for _, p := range strings.Split(s[idx+len(OpenTracingParamsMarker):], ",") {
   114  							paramsToTrace[strings.TrimSpace(p)] = true
   115  						}
   116  					}
   117  				}
   118  			}
   119  			if e.Params != nil {
   120  				for _, param := range e.Params.List {
   121  					for _, paramName := range param.Names {
   122  						params = append(params, methodParam{Name: paramName.Name, Type: string(src[param.Type.Pos()-1 : param.Type.End()-1])})
   123  					}
   124  				}
   125  			}
   126  			if e.Results != nil {
   127  				for _, result := range e.Results.List {
   128  					results = append(results, string(src[result.Type.Pos()-1:result.Type.End()-1]))
   129  				}
   130  			}
   131  
   132  			for paramName := range paramsToTrace {
   133  				found := false
   134  				for _, param := range params {
   135  					if param.Name == paramName {
   136  						found = true
   137  						break
   138  					}
   139  				}
   140  				if !found {
   141  					log.Fatalf("Unable to find a parameter called '%s' (method '%s') that is mentioned in the '%s' comment. Maybe it was renamed?", paramName, method.Names[0].Name, OpenTracingParamsMarker)
   142  				}
   143  			}
   144  		}
   145  		return true
   146  	})
   147  	return methodData{Params: params, Results: results, ParamsToTrace: paramsToTrace}
   148  }
   149  
   150  func extractStoreMetadata() (*storeMetadata, error) {
   151  	// Create the AST by parsing src.
   152  	fset := token.NewFileSet() // positions are relative to fset
   153  
   154  	file, err := os.Open("store.go")
   155  	if err != nil {
   156  		return nil, fmt.Errorf("Unable to open store/store.go file: %w", err)
   157  	}
   158  	src, err := ioutil.ReadAll(file)
   159  	if err != nil {
   160  		return nil, err
   161  	}
   162  	file.Close()
   163  	f, err := parser.ParseFile(fset, "", src, parser.AllErrors|parser.ParseComments)
   164  	if err != nil {
   165  		return nil, err
   166  	}
   167  
   168  	topLevelFunctions := map[string]bool{
   169  		"MarkSystemRanUnitTests":   false,
   170  		"Close":                    false,
   171  		"LockToMaster":             false,
   172  		"UnlockFromMaster":         false,
   173  		"DropAllTables":            false,
   174  		"TotalMasterDbConnections": true,
   175  		"TotalReadDbConnections":   true,
   176  		"SetContext":               true,
   177  		"TotalSearchDbConnections": true,
   178  		"GetCurrentSchemaVersion":  true,
   179  	}
   180  
   181  	metadata := storeMetadata{Methods: map[string]methodData{}, SubStores: map[string]subStore{}}
   182  
   183  	ast.Inspect(f, func(n ast.Node) bool {
   184  		switch x := n.(type) {
   185  		case *ast.TypeSpec:
   186  			if x.Name.Name == "Store" {
   187  				for _, method := range x.Type.(*ast.InterfaceType).Methods.List {
   188  					methodName := method.Names[0].Name
   189  					if _, ok := topLevelFunctions[methodName]; ok {
   190  						metadata.Methods[methodName] = extractMethodMetadata(method, src)
   191  					}
   192  				}
   193  			} else if strings.HasSuffix(x.Name.Name, "Store") {
   194  				subStoreName := strings.TrimSuffix(x.Name.Name, "Store")
   195  				metadata.SubStores[subStoreName] = subStore{Methods: map[string]methodData{}}
   196  				for _, method := range x.Type.(*ast.InterfaceType).Methods.List {
   197  					methodName := method.Names[0].Name
   198  					metadata.SubStores[subStoreName].Methods[methodName] = extractMethodMetadata(method, src)
   199  				}
   200  			}
   201  		}
   202  		return true
   203  	})
   204  
   205  	return &metadata, nil
   206  }
   207  
   208  func generateLayer(name, templateFile string) ([]byte, error) {
   209  	out := bytes.NewBufferString("")
   210  	metadata, err := extractStoreMetadata()
   211  	if err != nil {
   212  		return nil, err
   213  	}
   214  	metadata.Name = name
   215  
   216  	myFuncs := template.FuncMap{
   217  		"joinResults": func(results []string) string {
   218  			return strings.Join(results, ", ")
   219  		},
   220  		"joinResultsForSignature": func(results []string) string {
   221  			if len(results) == 0 {
   222  				return ""
   223  			}
   224  			if len(results) == 1 {
   225  				return strings.Join(results, ", ")
   226  			}
   227  			return fmt.Sprintf("(%s)", strings.Join(results, ", "))
   228  		},
   229  		"genResultsVars": func(results []string, withNilError bool) string {
   230  			vars := []string{}
   231  			for i, typeName := range results {
   232  				if isError(typeName) {
   233  					if withNilError {
   234  						vars = append(vars, "nil")
   235  					} else {
   236  						vars = append(vars, "err")
   237  					}
   238  				} else if i == 0 {
   239  					vars = append(vars, "result")
   240  				} else {
   241  					vars = append(vars, fmt.Sprintf("resultVar%d", i))
   242  				}
   243  			}
   244  			return strings.Join(vars, ", ")
   245  		},
   246  		"errorToBoolean": func(results []string) string {
   247  			for _, typeName := range results {
   248  				if isError(typeName) {
   249  					return "err == nil"
   250  				}
   251  			}
   252  			return "true"
   253  		},
   254  		"errorPresent": func(results []string) bool {
   255  			for _, typeName := range results {
   256  				if isError(typeName) {
   257  					return true
   258  				}
   259  			}
   260  			return false
   261  		},
   262  		"errorVar": func(results []string) string {
   263  			for _, typeName := range results {
   264  				if isError(typeName) {
   265  					return "err"
   266  				}
   267  			}
   268  			return ""
   269  		},
   270  		"joinParams": func(params []methodParam) string {
   271  			paramsNames := make([]string, 0, len(params))
   272  			for _, param := range params {
   273  				paramsNames = append(paramsNames, param.Name)
   274  			}
   275  			return strings.Join(paramsNames, ", ")
   276  		},
   277  		"joinParamsWithType": func(params []methodParam) string {
   278  			paramsWithType := []string{}
   279  			for _, param := range params {
   280  				if param.Type == "ChannelSearchOpts" || param.Type == "UserGetByIdsOpts" {
   281  					paramsWithType = append(paramsWithType, fmt.Sprintf("%s store.%s", param.Name, param.Type))
   282  				} else if param.Type == "*UserGetByIdsOpts" {
   283  					paramsWithType = append(paramsWithType, fmt.Sprintf("%s *store.UserGetByIdsOpts", param.Name))
   284  				} else {
   285  					paramsWithType = append(paramsWithType, fmt.Sprintf("%s %s", param.Name, param.Type))
   286  				}
   287  			}
   288  			return strings.Join(paramsWithType, ", ")
   289  		},
   290  		"joinParamsWithTypeOutsideStore": func(params []methodParam) string {
   291  			paramsWithType := []string{}
   292  			for _, param := range params {
   293  				if param.Type == "ChannelSearchOpts" || param.Type == "UserGetByIdsOpts" {
   294  					paramsWithType = append(paramsWithType, fmt.Sprintf("%s store.%s", param.Name, param.Type))
   295  				} else if param.Type == "*UserGetByIdsOpts" {
   296  					paramsWithType = append(paramsWithType, fmt.Sprintf("%s *store.UserGetByIdsOpts", param.Name))
   297  				} else {
   298  					paramsWithType = append(paramsWithType, fmt.Sprintf("%s %s", param.Name, param.Type))
   299  				}
   300  			}
   301  			return strings.Join(paramsWithType, ", ")
   302  		},
   303  	}
   304  
   305  	t := template.Must(template.New(templateFile).Funcs(myFuncs).ParseFiles("layer_generators/" + templateFile))
   306  	if err = t.Execute(out, metadata); err != nil {
   307  		return nil, err
   308  	}
   309  	return out.Bytes(), nil
   310  }