github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/internal/cmd/gtrace/main.go (about) 1 package main 2 3 import ( 4 "bufio" 5 "bytes" 6 "flag" 7 "fmt" 8 "go/ast" 9 "go/build" 10 "go/importer" 11 "go/parser" 12 "go/token" 13 "go/types" 14 "io" 15 "log" 16 "os" 17 "path/filepath" 18 "strings" 19 20 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" 21 ) 22 23 //nolint:gocyclo 24 func main() { 25 var ( 26 // Reports whether we were called from go:generate. 27 isGoGenerate bool 28 29 gofile string 30 workDir string 31 err error 32 ) 33 if gofile = os.Getenv("GOFILE"); gofile != "" { 34 // NOTE: GOFILE is always a filename without path. 35 isGoGenerate = true 36 workDir, err = os.Getwd() 37 if err != nil { 38 log.Fatal(err) 39 } 40 } else { 41 args := flag.Args() 42 if len(args) == 0 { 43 log.Fatal("no $GOFILE env nor file parameter were given") 44 } 45 gofile = filepath.Base(args[0]) 46 workDir = filepath.Dir(args[0]) 47 } 48 { 49 prefix := filepath.Join(filepath.Base(workDir), gofile) 50 log.SetPrefix("[" + prefix + "] ") 51 } 52 buildCtx := build.Default 53 buildPkg, err := buildCtx.ImportDir(workDir, build.IgnoreVendor) 54 if err != nil { 55 log.Fatal(err) 56 } 57 58 srcFilePath := filepath.Join(workDir, gofile) 59 60 var writers []*Writer 61 if isGoGenerate { 62 openFile := func(name string) (*os.File, func()) { 63 var f *os.File 64 //nolint:gofumpt 65 //nolint:nolintlint 66 //nolint:gosec 67 f, err = os.OpenFile( 68 filepath.Join(workDir, filepath.Clean(name)), 69 os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 70 0o600, 71 ) 72 if err != nil { 73 log.Fatal(err) 74 } 75 76 return f, func() { f.Close() } 77 } 78 ext := filepath.Ext(gofile) 79 name := strings.TrimSuffix(gofile, ext) 80 f, clean := openFile(name + "_gtrace" + ext) 81 defer clean() 82 writers = append(writers, &Writer{ 83 Context: buildCtx, 84 Output: f, 85 }) 86 } else { 87 writers = append(writers, &Writer{ 88 Context: buildCtx, 89 Output: os.Stdout, 90 }) 91 } 92 93 var ( 94 pkgFiles = make([]*os.File, 0, len(buildPkg.GoFiles)) 95 astFiles = make([]*ast.File, 0, len(buildPkg.GoFiles)) 96 97 buildConstraints []string 98 ) 99 fset := token.NewFileSet() 100 for _, name := range buildPkg.GoFiles { 101 base := strings.TrimSuffix(name, filepath.Ext(name)) 102 if isGenerated(base, "_gtrace") { 103 continue 104 } 105 var file *os.File 106 file, err = os.Open(filepath.Join(workDir, name)) 107 if err != nil { 108 panic(err) 109 } 110 defer file.Close() //nolint:gocritic 111 112 var ast *ast.File 113 ast, err = parser.ParseFile(fset, file.Name(), file, parser.ParseComments) 114 if err != nil { 115 panic(fmt.Sprintf("parse %q error: %v", file.Name(), err)) 116 } 117 118 pkgFiles = append(pkgFiles, file) 119 astFiles = append(astFiles, ast) 120 121 if name == gofile { 122 if _, err = file.Seek(0, io.SeekStart); err != nil { 123 panic(err) 124 } 125 buildConstraints, err = scanBuildConstraints(file) 126 if err != nil { 127 panic(err) 128 } 129 } 130 } 131 info := &types.Info{ 132 Types: make(map[ast.Expr]types.TypeAndValue), 133 Defs: make(map[*ast.Ident]types.Object), 134 Uses: make(map[*ast.Ident]types.Object), 135 } 136 conf := types.Config{ 137 IgnoreFuncBodies: true, 138 DisableUnusedImportCheck: true, 139 Importer: importer.ForCompiler(fset, "source", nil), 140 } 141 pkg, err := conf.Check(".", fset, astFiles, info) 142 if err != nil { 143 panic(fmt.Sprintf("type error: %v", err)) 144 } 145 var items []*GenItem 146 for i, astFile := range astFiles { 147 if pkgFiles[i].Name() != srcFilePath { 148 continue 149 } 150 var ( 151 depth int 152 item *GenItem 153 ) 154 ast.Inspect(astFile, func(n ast.Node) (next bool) { 155 if n == nil { 156 item = nil 157 depth-- 158 159 return true 160 } 161 defer func() { 162 if next { 163 depth++ 164 } 165 }() 166 167 switch v := n.(type) { 168 case *ast.FuncDecl, *ast.ValueSpec: 169 return false 170 171 case *ast.Ident: 172 if item != nil { 173 item.Ident = v 174 } 175 176 return false 177 178 case *ast.CommentGroup: 179 for _, c := range v.List { 180 if strings.Contains(strings.TrimPrefix(c.Text, "//"), "gtrace:gen") { 181 if item == nil { 182 item = &GenItem{} 183 } 184 } 185 } 186 187 return false 188 189 case *ast.StructType: 190 if item != nil { 191 item.StructType = v 192 items = append(items, item) 193 item = nil 194 } 195 196 return false 197 } 198 199 return true 200 }) 201 } 202 p := Package{ 203 Package: pkg, 204 BuildConstraints: buildConstraints, 205 } 206 traces := make(map[string]*Trace) 207 for _, item := range items { 208 t := &Trace{ 209 Name: item.Ident.Name, 210 } 211 p.Traces = append(p.Traces, t) 212 traces[item.Ident.Name] = t 213 } 214 for i, item := range items { 215 t := p.Traces[i] 216 for _, field := range item.StructType.Fields.List { 217 if _, ok := field.Type.(*ast.FuncType); !ok { 218 continue 219 } 220 name := field.Names[0].Name 221 fn, ok := field.Type.(*ast.FuncType) 222 if !ok { 223 continue 224 } 225 f, err := buildFunc(info, traces, fn) 226 if err != nil { 227 log.Printf( 228 "skipping hook %s due to error: %v", 229 name, err, 230 ) 231 232 continue 233 } 234 t.Hooks = append(t.Hooks, Hook{ 235 Name: name, 236 Func: f, 237 }) 238 } 239 } 240 for _, w := range writers { 241 if err := w.Write(p); err != nil { 242 panic(err) 243 } 244 } 245 246 log.Println("OK") 247 } 248 249 func buildFunc(info *types.Info, traces map[string]*Trace, fn *ast.FuncType) (ret *Func, err error) { 250 ret = new(Func) 251 for _, p := range fn.Params.List { 252 t := info.TypeOf(p.Type) 253 if t == nil { 254 log.Fatalf("unknown type: %s", p.Type) 255 } 256 var names []string 257 for _, n := range p.Names { 258 name := n.Name 259 if name == "_" { 260 name = "" 261 } 262 names = append(names, name) 263 } 264 if len(names) == 0 { 265 // Case where arg is not named. 266 names = []string{""} 267 } 268 for _, name := range names { 269 ret.Params = append(ret.Params, Param{ 270 Name: name, 271 Type: t, 272 }) 273 } 274 } 275 if fn.Results == nil { 276 return ret, nil 277 } 278 if len(fn.Results.List) > 1 { 279 return nil, fmt.Errorf( 280 "unsupported number of function results", 281 ) 282 } 283 284 r := fn.Results.List[0] 285 286 switch x := r.Type.(type) { 287 case *ast.FuncType: 288 result, err := buildFunc(info, traces, x) 289 if err != nil { 290 return nil, xerrors.WithStackTrace(err) 291 } 292 ret.Result = append(ret.Result, result) 293 294 return ret, nil 295 296 case *ast.Ident: 297 if t, ok := traces[x.Name]; ok { 298 t.Nested = true 299 ret.Result = append(ret.Result, t) 300 301 return ret, nil 302 } 303 } 304 305 return nil, fmt.Errorf( 306 "unsupported function result type %s", 307 info.TypeOf(r.Type), 308 ) 309 } 310 311 type Package struct { 312 *types.Package 313 314 BuildConstraints []string 315 Traces []*Trace 316 } 317 318 type Trace struct { 319 Name string 320 Hooks []Hook 321 Nested bool 322 } 323 324 func (*Trace) isFuncResult() bool { return true } 325 326 type Hook struct { 327 Name string 328 Func *Func 329 } 330 331 type Param struct { 332 Name string // Might be empty. 333 Type types.Type 334 } 335 336 func (p Param) String() string { 337 return p.Name + " " + p.Type.String() 338 } 339 340 type FuncResult interface { 341 isFuncResult() bool 342 } 343 344 type Func struct { 345 Params []Param 346 Result []FuncResult // 0 or 1. 347 } 348 349 func (*Func) isFuncResult() bool { return true } 350 351 func (f *Func) HasResult() bool { 352 return len(f.Result) > 0 353 } 354 355 type GenFlag uint8 356 357 func (f GenFlag) Has(x GenFlag) bool { 358 return f&x != 0 359 } 360 361 type GenItem struct { 362 Ident *ast.Ident 363 StructType *ast.StructType 364 } 365 366 func rsplit(s string, c byte) (s1, s2 string) { 367 i := strings.LastIndexByte(s, c) 368 if i == -1 { 369 return s, "" 370 } 371 372 return s[:i], s[i+1:] 373 } 374 375 func scanBuildConstraints(r io.Reader) (cs []string, err error) { 376 br := bufio.NewReader(r) 377 for { 378 line, err := br.ReadBytes('\n') 379 if err != nil { 380 return nil, xerrors.WithStackTrace(err) 381 } 382 line = bytes.TrimSpace(line) 383 if comm := bytes.TrimPrefix(line, []byte("//")); !bytes.Equal(comm, line) { 384 comm = bytes.TrimSpace(comm) 385 if bytes.HasPrefix(comm, []byte("+build")) { 386 cs = append(cs, string(line)) 387 388 continue 389 } 390 } 391 if bytes.HasPrefix(line, []byte("package ")) { 392 break 393 } 394 } 395 396 return cs, nil 397 } 398 399 func isGenerated(base, suffix string) bool { 400 i := strings.Index(base, suffix) 401 if i == -1 { 402 return false 403 } 404 n := len(base) 405 m := i + len(suffix) 406 407 return m == n || base[m] == '_' 408 }