// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.

// Command layer_generators reads the store interfaces declared in store.go and
// renders the TimerLayer and OpenTracingLayer wrappers from the templates in
// layer_generators/. All paths are relative to the working directory, which is
// expected to be the store package directory.
package main

import (
	"bytes"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"io/ioutil"
	"log"
	"os"
	"path"
	"strings"
	"text/template"
)

const (
	// OPEN_TRACING_PARAMS_MARKER, when present in a store method's doc comment,
	// is followed by a comma-separated list of parameter names; those names end
	// up in methodData.ParamsToTrace.
	OPEN_TRACING_PARAMS_MARKER = "@openTracingParams"
	APP_ERROR_TYPE             = "*model.AppError"
	ERROR_TYPE                 = "error"
)

// isError reports whether typeName looks like an error type. It is a plain
// substring test, so composite types such as "[]error" also match.
func isError(typeName string) bool {
	return strings.Contains(typeName, APP_ERROR_TYPE) || strings.Contains(typeName, ERROR_TYPE)
}

func main() {
	if err := buildTimerLayer(); err != nil {
		log.Fatal(err)
	}
	if err := buildOpenTracingLayer(); err != nil {
		log.Fatal(err)
	}
}

// buildTimerLayer regenerates timerlayer/timerlayer.go from timer_layer.go.tmpl.
func buildTimerLayer() error {
	return buildLayer("TimerLayer", "timer_layer.go.tmpl", path.Join("timerlayer", "timerlayer.go"))
}

// buildOpenTracingLayer regenerates opentracinglayer/opentracinglayer.go from
// opentracing_layer.go.tmpl.
func buildOpenTracingLayer() error {
	return buildLayer("OpenTracingLayer", "opentracing_layer.go.tmpl", path.Join("opentracinglayer", "opentracinglayer.go"))
}

// buildLayer renders templateFile for the layer called name, gofmt-formats the
// output and writes it to outputPath.
func buildLayer(name, templateFile, outputPath string) error {
	code, err := generateLayer(name, templateFile)
	if err != nil {
		return err
	}
	formattedCode, err := format.Source(code)
	if err != nil {
		return fmt.Errorf("formatting generated %s code: %w", name, err)
	}
	return ioutil.WriteFile(outputPath, formattedCode, 0644)
}

// methodParam is one named parameter of a store method; Type is the parameter's
// type expression copied verbatim from store.go.
type methodParam struct {
	Name string
	Type string
}

// methodData describes one store method for the templates: its parameters, its
// result types (one entry per returned value) and the parameter names listed
// after OPEN_TRACING_PARAMS_MARKER in its doc comment.
type methodData struct {
	Params        []methodParam
	Results       []string
	ParamsToTrace map[string]bool
}

// subStore holds the methods of one FooStore interface, keyed by method name.
type subStore struct {
	Methods map[string]methodData
}

// storeMetadata is the root value passed to the layer templates.
type storeMetadata struct {
	Name      string
	SubStores map[string]subStore
	Methods   map[string]methodData
}

// exprSource returns the source text of expr. It assumes src is the file expr
// was parsed from, added as the first file of a fresh token.FileSet, so that
// positions equal byte offsets plus one.
func exprSource(src []byte, expr ast.Expr) string {
	return string(src[expr.Pos()-1 : expr.End()-1])
}

// extractMethodMetadata collects the parameters, result types and traced
// parameter names of a single interface method. src must satisfy the contract
// of exprSource.
//
// Only the method's own signature is inspected: parameters or results of
// function-typed parameters are not mixed in. Fields that are not methods
// (e.g. embedded interfaces) produce empty metadata. The program exits via
// log.Fatalf if the marker comment names a parameter the method lacks.
func extractMethodMetadata(method *ast.Field, src []byte) methodData {
	params := []methodParam{}
	results := []string{}
	paramsToTrace := map[string]bool{}

	funcType, ok := method.Type.(*ast.FuncType)
	if !ok {
		return methodData{Params: params, Results: results, ParamsToTrace: paramsToTrace}
	}

	if method.Doc != nil {
		for _, comment := range method.Doc.List {
			idx := strings.Index(comment.Text, OPEN_TRACING_PARAMS_MARKER)
			if idx == -1 {
				continue
			}
			for _, p := range strings.Split(comment.Text[idx+len(OPEN_TRACING_PARAMS_MARKER):], ",") {
				// Skip empty entries such as the one produced by a trailing comma.
				if name := strings.TrimSpace(p); name != "" {
					paramsToTrace[name] = true
				}
			}
		}
	}

	if funcType.Params != nil {
		for _, param := range funcType.Params.List {
			typeSrc := exprSource(src, param.Type)
			// Unnamed parameters have no Names and are therefore not recorded.
			for _, paramName := range param.Names {
				params = append(params, methodParam{Name: paramName.Name, Type: typeSrc})
			}
		}
	}

	if funcType.Results != nil {
		for _, result := range funcType.Results.List {
			typeSrc := exprSource(src, result.Type)
			// A named group such as (n, m int) declares one value per name.
			count := len(result.Names)
			if count == 0 {
				count = 1
			}
			for i := 0; i < count; i++ {
				results = append(results, typeSrc)
			}
		}
	}

	declared := make(map[string]bool, len(params))
	for _, param := range params {
		declared[param.Name] = true
	}
	for paramName := range paramsToTrace {
		if !declared[paramName] {
			log.Fatalf("Unable to find a parameter called '%s' (method '%s') that is mentioned in the '%s' comment. Maybe it was renamed?", paramName, method.Names[0].Name, OPEN_TRACING_PARAMS_MARKER)
		}
	}

	return methodData{Params: params, Results: results, ParamsToTrace: paramsToTrace}
}

// extractStoreMetadata parses store.go in the working directory. It collects:
//   - the methods of the Store interface that are listed in topLevelFunctions;
//   - every method of each other interface whose name ends in "Store", keyed
//     by the name without that suffix (e.g. "Team" for TeamStore).
//
// Non-interface types and embedded interfaces are skipped.
func extractStoreMetadata() (*storeMetadata, error) {
	file, err := os.Open("store.go")
	if err != nil {
		return nil, fmt.Errorf("Unable to open store/store.go file: %w", err)
	}
	defer file.Close()

	src, err := ioutil.ReadAll(file)
	if err != nil {
		return nil, err
	}

	// A fresh FileSet makes positions in f equal byte offsets in src plus one,
	// which exprSource relies on.
	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, "", src, parser.AllErrors|parser.ParseComments)
	if err != nil {
		return nil, fmt.Errorf("parsing store.go: %w", err)
	}

	// Store methods to include in the layers. Only membership is consulted
	// here; the boolean values are not read by this generator.
	topLevelFunctions := map[string]bool{
		"MarkSystemRanUnitTests":   false,
		"Close":                    false,
		"LockToMaster":             false,
		"UnlockFromMaster":         false,
		"DropAllTables":            false,
		"TotalMasterDbConnections": true,
		"TotalReadDbConnections":   true,
		"SetContext":               true,
		"TotalSearchDbConnections": true,
		"GetCurrentSchemaVersion":  true,
	}

	metadata := storeMetadata{Methods: map[string]methodData{}, SubStores: map[string]subStore{}}

	ast.Inspect(f, func(n ast.Node) bool {
		spec, ok := n.(*ast.TypeSpec)
		if !ok {
			return true
		}
		iface, ok := spec.Type.(*ast.InterfaceType)
		if !ok {
			return true
		}

		typeName := spec.Name.Name
		switch {
		case typeName == "Store":
			for _, method := range iface.Methods.List {
				if len(method.Names) == 0 {
					continue // embedded interface
				}
				methodName := method.Names[0].Name
				if _, ok := topLevelFunctions[methodName]; ok {
					metadata.Methods[methodName] = extractMethodMetadata(method, src)
				}
			}
		case strings.HasSuffix(typeName, "Store"):
			methods := map[string]methodData{}
			for _, method := range iface.Methods.List {
				if len(method.Names) == 0 {
					continue // embedded interface
				}
				methods[method.Names[0].Name] = extractMethodMetadata(method, src)
			}
			metadata.SubStores[strings.TrimSuffix(typeName, "Store")] = subStore{Methods: methods}
		}
		return true
	})

	return &metadata, nil
}

// errorResultIndex returns the index of the first result type for which
// isError reports true, or -1 if there is none.
func errorResultIndex(results []string) int {
	for i, typeName := range results {
		if isError(typeName) {
			return i
		}
	}
	return -1
}

// layerTemplateFuncs returns the helper functions available to the layer
// templates. Result variables are named resultVar0, resultVar1, ... in order.
func layerTemplateFuncs() template.FuncMap {
	return template.FuncMap{
		"joinResults": func(results []string) string {
			return strings.Join(results, ", ")
		},
		// joinResultsForSignature renders a result list as it appears in a
		// function signature: nothing, a bare type, or a parenthesized list.
		"joinResultsForSignature": func(results []string) string {
			switch len(results) {
			case 0:
				return ""
			case 1:
				return results[0]
			default:
				return fmt.Sprintf("(%s)", strings.Join(results, ", "))
			}
		},
		"genResultsVars": func(results []string) string {
			vars := make([]string, len(results))
			for i := range results {
				vars[i] = fmt.Sprintf("resultVar%d", i)
			}
			return strings.Join(vars, ", ")
		},
		// errorToBoolean yields a Go boolean expression that is true when the
		// first error-typed result is nil, or the literal "true" when no result
		// is error-typed.
		"errorToBoolean": func(results []string) string {
			if i := errorResultIndex(results); i >= 0 {
				return fmt.Sprintf("resultVar%d == nil", i)
			}
			return "true"
		},
		"errorPresent": func(results []string) bool {
			return errorResultIndex(results) >= 0
		},
		"errorVar": func(results []string) string {
			if i := errorResultIndex(results); i >= 0 {
				return fmt.Sprintf("resultVar%d", i)
			}
			return ""
		},
		"joinParams": func(params []methodParam) string {
			paramsNames := make([]string, 0, len(params))
			for _, param := range params {
				paramsNames = append(paramsNames, param.Name)
			}
			return strings.Join(paramsNames, ", ")
		},
		"joinParamsWithType": func(params []methodParam) string {
			paramsWithType := make([]string, 0, len(params))
			for _, param := range params {
				switch param.Type {
				// These types are declared in the store package, while the
				// generated layers are written to their own directories, so
				// they need the store qualifier.
				case "ChannelSearchOpts", "UserGetByIdsOpts":
					paramsWithType = append(paramsWithType, fmt.Sprintf("%s store.%s", param.Name, param.Type))
				case "*UserGetByIdsOpts":
					paramsWithType = append(paramsWithType, fmt.Sprintf("%s *store.UserGetByIdsOpts", param.Name))
				default:
					paramsWithType = append(paramsWithType, fmt.Sprintf("%s %s", param.Name, param.Type))
				}
			}
			return strings.Join(paramsWithType, ", ")
		},
	}
}

// generateLayer executes layer_generators/<templateFile> with the store
// metadata, whose Name is set to name. The returned code is not yet formatted.
func generateLayer(name, templateFile string) ([]byte, error) {
	metadata, err := extractStoreMetadata()
	if err != nil {
		return nil, err
	}
	metadata.Name = name

	// Return template parse errors instead of panicking (as template.Must
	// would), so main reports them like every other failure.
	t, err := template.New(templateFile).Funcs(layerTemplateFuncs()).ParseFiles(path.Join("layer_generators", templateFile))
	if err != nil {
		return nil, err
	}

	var out bytes.Buffer
	if err := t.Execute(&out, metadata); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}