// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.

// Command layer_generators renders the timer, OpenTracing and retry store
// layers from the interfaces declared in store.go, using the text/template
// files in the layer_generators directory. All paths are relative, so it is
// expected to run from the directory that contains store.go.
package main

import (
	"bytes"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"io/ioutil"
	"log"
	"path/filepath"
	"strings"
	"text/template"
)

const (
	// OpenTracingParamsMarker, when present in a store method's doc comment,
	// is followed by a comma-separated list of parameter names that end up in
	// methodData.ParamsToTrace.
	OpenTracingParamsMarker = "@openTracingParams"
	// ErrorType is the (case-sensitive) substring that marks a result type as
	// an error result.
	ErrorType = "error"
)

// isError reports whether typeName contains ErrorType.
func isError(typeName string) bool {
	return strings.Contains(typeName, ErrorType)
}

// main generates the timer, OpenTracing and retry layers in that order and
// exits on the first failure.
func main() {
	if err := buildTimerLayer(); err != nil {
		log.Fatal(err)
	}
	if err := buildOpenTracingLayer(); err != nil {
		log.Fatal(err)
	}
	if err := buildRetryLayer(); err != nil {
		log.Fatal(err)
	}
}

// buildLayer renders templateFile for the layer called name, gofmt-formats the
// generated code and writes it to outFile.
func buildLayer(name, templateFile, outFile string) error {
	code, err := generateLayer(name, templateFile)
	if err != nil {
		return err
	}
	formattedCode, err := format.Source(code)
	if err != nil {
		return fmt.Errorf("formatting generated %s: %w", name, err)
	}
	return ioutil.WriteFile(outFile, formattedCode, 0644)
}

// buildRetryLayer writes retrylayer/retrylayer.go.
func buildRetryLayer() error {
	return buildLayer("RetryLayer", "retry_layer.go.tmpl", filepath.Join("retrylayer", "retrylayer.go"))
}

// buildTimerLayer writes timerlayer/timerlayer.go.
func buildTimerLayer() error {
	return buildLayer("TimerLayer", "timer_layer.go.tmpl", filepath.Join("timerlayer", "timerlayer.go"))
}

// buildOpenTracingLayer writes opentracinglayer/opentracinglayer.go.
func buildOpenTracingLayer() error {
	return buildLayer("OpenTracingLayer", "opentracing_layer.go.tmpl", filepath.Join("opentracinglayer", "opentracinglayer.go"))
}

// methodParam is one parameter of a store method.
type methodParam struct {
	Name string // parameter name as declared in store.go
	Type string // source text of the parameter's type, e.g. "*UserGetByIdsOpts"
}

// methodData describes the signature of one store interface method.
type methodData struct {
	Params  []methodParam // parameters in declaration order
	Results []string      // source text of each result type, in order
	// ParamsToTrace holds the parameter names listed after
	// OpenTracingParamsMarker in the method's doc comment.
	ParamsToTrace map[string]bool
}

// subStore holds the methods of one "<Name>Store" interface.
type subStore struct {
	Methods map[string]methodData
}

// storeMetadata is the data passed to the layer templates.
type storeMetadata struct {
	Name      string              // layer name, set by generateLayer (e.g. "TimerLayer")
	SubStores map[string]subStore // keyed by interface name without the "Store" suffix
	Methods   map[string]methodData
}

// nodeSource returns the source text of n. The positions come from a FileSet
// holding only this file; its base is 1, so Pos()-1 is the byte offset in src.
func nodeSource(src []byte, n ast.Node) string {
	return string(src[n.Pos()-1 : n.End()-1])
}

// extractMethodMetadata collects the parameters, results and traced
// parameter names of an interface method. It calls log.Fatalf when the
// OpenTracingParamsMarker comment names a parameter the method does not have.
func extractMethodMetadata(method *ast.Field, src []byte) methodData {
	params := []methodParam{}
	results := []string{}
	paramsToTrace := map[string]bool{}

	// Look only at the method's own signature: walking it with ast.Inspect
	// would also pick up the params/results of function-typed parameters.
	fn, ok := method.Type.(*ast.FuncType)
	if !ok {
		return methodData{Params: params, Results: results, ParamsToTrace: paramsToTrace}
	}

	if method.Doc != nil {
		for _, comment := range method.Doc.List {
			idx := strings.Index(comment.Text, OpenTracingParamsMarker)
			if idx == -1 {
				continue
			}
			for _, p := range strings.Split(comment.Text[idx+len(OpenTracingParamsMarker):], ",") {
				// Skip empty entries, e.g. from a trailing comma.
				if name := strings.TrimSpace(p); name != "" {
					paramsToTrace[name] = true
				}
			}
		}
	}

	if fn.Params != nil {
		for _, param := range fn.Params.List {
			typ := nodeSource(src, param.Type)
			for _, paramName := range param.Names {
				params = append(params, methodParam{Name: paramName.Name, Type: typ})
			}
		}
	}

	if fn.Results != nil {
		for _, result := range fn.Results.List {
			typ := nodeSource(src, result.Type)
			// A named group such as (n, m int64) declares one result per name.
			count := len(result.Names)
			if count == 0 {
				count = 1
			}
			for i := 0; i < count; i++ {
				results = append(results, typ)
			}
		}
	}

	for paramName := range paramsToTrace {
		found := false
		for _, param := range params {
			if param.Name == paramName {
				found = true
				break
			}
		}
		if !found {
			log.Fatalf("Unable to find a parameter called '%s' (method '%s') that is mentioned in the '%s' comment. Maybe it was renamed?", paramName, method.Names[0].Name, OpenTracingParamsMarker)
		}
	}

	return methodData{Params: params, Results: results, ParamsToTrace: paramsToTrace}
}

// extractStoreMetadata reads store.go from the working directory and parses
// its store interfaces.
func extractStoreMetadata() (*storeMetadata, error) {
	src, err := ioutil.ReadFile("store.go")
	if err != nil {
		return nil, fmt.Errorf("Unable to open store/store.go file: %w", err)
	}
	return parseStoreMetadata(src)
}

// namedMethods returns the methods declared in iface. Embedded interfaces are
// rejected, since their methods would be missing from the generated layer.
func namedMethods(typeName string, iface *ast.InterfaceType, src []byte) ([]*ast.Field, error) {
	methods := make([]*ast.Field, 0, len(iface.Methods.List))
	for _, m := range iface.Methods.List {
		if len(m.Names) == 0 {
			return nil, fmt.Errorf("interface %s embeds %s: embedded interfaces are not supported by the layer generator", typeName, nodeSource(src, m.Type))
		}
		methods = append(methods, m)
	}
	return methods, nil
}

// parseStoreMetadata parses Go source and collects:
//   - the methods of the "Store" interface whose names are in topLevelFunctions;
//   - every method of each other interface whose name ends in "Store",
//     keyed by the interface name without that suffix.
//
// Types that are not interfaces are ignored.
func parseStoreMetadata(src []byte) (*storeMetadata, error) {
	fset := token.NewFileSet() // positions are relative to fset
	f, err := parser.ParseFile(fset, "", src, parser.AllErrors|parser.ParseComments)
	if err != nil {
		return nil, err
	}

	// Store methods that are included in the generated layers. Only key
	// presence is checked below; the boolean values are not read here.
	topLevelFunctions := map[string]bool{
		"MarkSystemRanUnitTests":   false,
		"Close":                    false,
		"LockToMaster":             false,
		"UnlockFromMaster":         false,
		"DropAllTables":            false,
		"TotalMasterDbConnections": true,
		"TotalReadDbConnections":   true,
		"SetContext":               true,
		"TotalSearchDbConnections": true,
		"GetCurrentSchemaVersion":  true,
	}

	metadata := storeMetadata{Methods: map[string]methodData{}, SubStores: map[string]subStore{}}

	var walkErr error
	ast.Inspect(f, func(n ast.Node) bool {
		if walkErr != nil {
			return false
		}
		spec, ok := n.(*ast.TypeSpec)
		if !ok {
			return true
		}
		iface, ok := spec.Type.(*ast.InterfaceType)
		if !ok {
			return true
		}
		typeName := spec.Name.Name
		if !strings.HasSuffix(typeName, "Store") {
			return true
		}
		methods, err := namedMethods(typeName, iface, src)
		if err != nil {
			walkErr = err
			return false
		}

		if typeName == "Store" {
			for _, method := range methods {
				methodName := method.Names[0].Name
				if _, ok := topLevelFunctions[methodName]; ok {
					metadata.Methods[methodName] = extractMethodMetadata(method, src)
				}
			}
			return true
		}

		sub := subStore{Methods: make(map[string]methodData, len(methods))}
		for _, method := range methods {
			sub.Methods[method.Names[0].Name] = extractMethodMetadata(method, src)
		}
		metadata.SubStores[strings.TrimSuffix(typeName, "Store")] = sub
		return true
	})
	if walkErr != nil {
		return nil, walkErr
	}

	return &metadata, nil
}

// hasErrorResult reports whether any result type is an error type.
func hasErrorResult(results []string) bool {
	for _, typeName := range results {
		if isError(typeName) {
			return true
		}
	}
	return false
}

// joinParamsWithStoreTypes renders params as "name type" pairs joined by ", ".
// store.go refers to these store-package types unqualified. The generated
// files are written to sub-directories (presumably separate packages), so the
// types are prefixed with "store.".
func joinParamsWithStoreTypes(params []methodParam) string {
	parts := make([]string, 0, len(params))
	for _, param := range params {
		typ := param.Type
		switch typ {
		case "ChannelSearchOpts", "UserGetByIdsOpts":
			typ = "store." + typ
		case "*UserGetByIdsOpts":
			typ = "*store.UserGetByIdsOpts"
		}
		parts = append(parts, param.Name+" "+typ)
	}
	return strings.Join(parts, ", ")
}

// layerTemplateFuncs returns the helper functions available to the layer
// templates.
func layerTemplateFuncs() template.FuncMap {
	return template.FuncMap{
		"joinResults": func(results []string) string {
			return strings.Join(results, ", ")
		},
		// Result list as written in a signature: bare for a single result,
		// parenthesized for several, empty for none.
		"joinResultsForSignature": func(results []string) string {
			switch len(results) {
			case 0:
				return ""
			case 1:
				return results[0]
			default:
				return fmt.Sprintf("(%s)", strings.Join(results, ", "))
			}
		},
		// Variable names for the results: error results become "err" (or
		// "nil" when withNilError), the first other result is "result" and
		// later ones are "resultVar<index>".
		"genResultsVars": func(results []string, withNilError bool) string {
			vars := []string{}
			for i, typeName := range results {
				if isError(typeName) {
					if withNilError {
						vars = append(vars, "nil")
					} else {
						vars = append(vars, "err")
					}
				} else if i == 0 {
					vars = append(vars, "result")
				} else {
					vars = append(vars, fmt.Sprintf("resultVar%d", i))
				}
			}
			return strings.Join(vars, ", ")
		},
		"errorToBoolean": func(results []string) string {
			if hasErrorResult(results) {
				return "err == nil"
			}
			return "true"
		},
		"errorPresent": hasErrorResult,
		"errorVar": func(results []string) string {
			if hasErrorResult(results) {
				return "err"
			}
			return ""
		},
		"joinParams": func(params []methodParam) string {
			paramsNames := make([]string, 0, len(params))
			for _, param := range params {
				paramsNames = append(paramsNames, param.Name)
			}
			return strings.Join(paramsNames, ", ")
		},
		// Both names are used by the templates and have identical behavior.
		"joinParamsWithType":             joinParamsWithStoreTypes,
		"joinParamsWithTypeOutsideStore": joinParamsWithStoreTypes,
	}
}

// generateLayer renders layer_generators/<templateFile> with the metadata
// from store.go, using name as the layer name. It returns the unformatted
// generated source.
func generateLayer(name, templateFile string) ([]byte, error) {
	metadata, err := extractStoreMetadata()
	if err != nil {
		return nil, err
	}
	metadata.Name = name

	// The template name must equal the file's base name so that Execute runs
	// the template parsed from that file.
	t, err := template.New(templateFile).Funcs(layerTemplateFuncs()).ParseFiles(filepath.Join("layer_generators", templateFile))
	if err != nil {
		return nil, err
	}

	var out bytes.Buffer
	if err := t.Execute(&out, metadata); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}