
     1  package main
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"flag"
     7  	"fmt"
     8  	"go/ast"
     9  	"go/build"
    10  	"go/importer"
    11  	"go/parser"
    12  	"go/token"
    13  	"go/types"
    14  	"io"
    15  	"log"
    16  	"os"
    17  	"path/filepath"
    18  	"strings"
    20  	""
    21  )
    23  //nolint:gocyclo
    24  func main() {
    25  	var (
    26  		// Reports whether we were called from go:generate.
    27  		isGoGenerate bool
    29  		gofile  string
    30  		workDir string
    31  		err     error
    32  	)
    33  	if gofile = os.Getenv("GOFILE"); gofile != "" {
    34  		// NOTE: GOFILE is always a filename without path.
    35  		isGoGenerate = true
    36  		workDir, err = os.Getwd()
    37  		if err != nil {
    38  			log.Fatal(err)
    39  		}
    40  	} else {
    41  		args := flag.Args()
    42  		if len(args) == 0 {
    43  			log.Fatal("no $GOFILE env nor file parameter were given")
    44  		}
    45  		gofile = filepath.Base(args[0])
    46  		workDir = filepath.Dir(args[0])
    47  	}
    48  	{
    49  		prefix := filepath.Join(filepath.Base(workDir), gofile)
    50  		log.SetPrefix("[" + prefix + "] ")
    51  	}
    52  	buildCtx := build.Default
    53  	buildPkg, err := buildCtx.ImportDir(workDir, build.IgnoreVendor)
    54  	if err != nil {
    55  		log.Fatal(err)
    56  	}
    58  	srcFilePath := filepath.Join(workDir, gofile)
    60  	var writers []*Writer
    61  	if isGoGenerate {
    62  		openFile := func(name string) (*os.File, func()) {
    63  			var f *os.File
    64  			//nolint:gofumpt
    65  			//nolint:nolintlint
    66  			//nolint:gosec
    67  			f, err = os.OpenFile(
    68  				filepath.Join(workDir, filepath.Clean(name)),
    69  				os.O_WRONLY|os.O_CREATE|os.O_TRUNC,
    70  				0o600,
    71  			)
    72  			if err != nil {
    73  				log.Fatal(err)
    74  			}
    76  			return f, func() { f.Close() }
    77  		}
    78  		ext := filepath.Ext(gofile)
    79  		name := strings.TrimSuffix(gofile, ext)
    80  		f, clean := openFile(name + "_gtrace" + ext)
    81  		defer clean()
    82  		writers = append(writers, &Writer{
    83  			Context: buildCtx,
    84  			Output:  f,
    85  		})
    86  	} else {
    87  		writers = append(writers, &Writer{
    88  			Context: buildCtx,
    89  			Output:  os.Stdout,
    90  		})
    91  	}
    93  	var (
    94  		pkgFiles = make([]*os.File, 0, len(buildPkg.GoFiles))
    95  		astFiles = make([]*ast.File, 0, len(buildPkg.GoFiles))
    97  		buildConstraints []string
    98  	)
    99  	fset := token.NewFileSet()
   100  	for _, name := range buildPkg.GoFiles {
   101  		base := strings.TrimSuffix(name, filepath.Ext(name))
   102  		if isGenerated(base, "_gtrace") {
   103  			continue
   104  		}
   105  		var file *os.File
   106  		file, err = os.Open(filepath.Join(workDir, name))
   107  		if err != nil {
   108  			panic(err)
   109  		}
   110  		defer file.Close() //nolint:gocritic
   112  		var ast *ast.File
   113  		ast, err = parser.ParseFile(fset, file.Name(), file, parser.ParseComments)
   114  		if err != nil {
   115  			panic(fmt.Sprintf("parse %q error: %v", file.Name(), err))
   116  		}
   118  		pkgFiles = append(pkgFiles, file)
   119  		astFiles = append(astFiles, ast)
   121  		if name == gofile {
   122  			if _, err = file.Seek(0, io.SeekStart); err != nil {
   123  				panic(err)
   124  			}
   125  			buildConstraints, err = scanBuildConstraints(file)
   126  			if err != nil {
   127  				panic(err)
   128  			}
   129  		}
   130  	}
   131  	info := &types.Info{
   132  		Types: make(map[ast.Expr]types.TypeAndValue),
   133  		Defs:  make(map[*ast.Ident]types.Object),
   134  		Uses:  make(map[*ast.Ident]types.Object),
   135  	}
   136  	conf := types.Config{
   137  		IgnoreFuncBodies:         true,
   138  		DisableUnusedImportCheck: true,
   139  		Importer:                 importer.ForCompiler(fset, "source", nil),
   140  	}
   141  	pkg, err := conf.Check(".", fset, astFiles, info)
   142  	if err != nil {
   143  		panic(fmt.Sprintf("type error: %v", err))
   144  	}
   145  	var items []*GenItem
   146  	for i, astFile := range astFiles {
   147  		if pkgFiles[i].Name() != srcFilePath {
   148  			continue
   149  		}
   150  		var (
   151  			depth int
   152  			item  *GenItem
   153  		)
   154  		ast.Inspect(astFile, func(n ast.Node) (next bool) {
   155  			if n == nil {
   156  				item = nil
   157  				depth--
   159  				return true
   160  			}
   161  			defer func() {
   162  				if next {
   163  					depth++
   164  				}
   165  			}()
   167  			switch v := n.(type) {
   168  			case *ast.FuncDecl, *ast.ValueSpec:
   169  				return false
   171  			case *ast.Ident:
   172  				if item != nil {
   173  					item.Ident = v
   174  				}
   176  				return false
   178  			case *ast.CommentGroup:
   179  				for _, c := range v.List {
   180  					if strings.Contains(strings.TrimPrefix(c.Text, "//"), "gtrace:gen") {
   181  						if item == nil {
   182  							item = &GenItem{}
   183  						}
   184  					}
   185  				}
   187  				return false
   189  			case *ast.StructType:
   190  				if item != nil {
   191  					item.StructType = v
   192  					items = append(items, item)
   193  					item = nil
   194  				}
   196  				return false
   197  			}
   199  			return true
   200  		})
   201  	}
   202  	p := Package{
   203  		Package:          pkg,
   204  		BuildConstraints: buildConstraints,
   205  	}
   206  	traces := make(map[string]*Trace)
   207  	for _, item := range items {
   208  		t := &Trace{
   209  			Name: item.Ident.Name,
   210  		}
   211  		p.Traces = append(p.Traces, t)
   212  		traces[item.Ident.Name] = t
   213  	}
   214  	for i, item := range items {
   215  		t := p.Traces[i]
   216  		for _, field := range item.StructType.Fields.List {
   217  			if _, ok := field.Type.(*ast.FuncType); !ok {
   218  				continue
   219  			}
   220  			name := field.Names[0].Name
   221  			fn, ok := field.Type.(*ast.FuncType)
   222  			if !ok {
   223  				continue
   224  			}
   225  			f, err := buildFunc(info, traces, fn)
   226  			if err != nil {
   227  				log.Printf(
   228  					"skipping hook %s due to error: %v",
   229  					name, err,
   230  				)
   232  				continue
   233  			}
   234  			t.Hooks = append(t.Hooks, Hook{
   235  				Name: name,
   236  				Func: f,
   237  			})
   238  		}
   239  	}
   240  	for _, w := range writers {
   241  		if err := w.Write(p); err != nil {
   242  			panic(err)
   243  		}
   244  	}
   246  	log.Println("OK")
   247  }
   249  func buildFunc(info *types.Info, traces map[string]*Trace, fn *ast.FuncType) (ret *Func, err error) {
   250  	ret = new(Func)
   251  	for _, p := range fn.Params.List {
   252  		t := info.TypeOf(p.Type)
   253  		if t == nil {
   254  			log.Fatalf("unknown type: %s", p.Type)
   255  		}
   256  		var names []string
   257  		for _, n := range p.Names {
   258  			name := n.Name
   259  			if name == "_" {
   260  				name = ""
   261  			}
   262  			names = append(names, name)
   263  		}
   264  		if len(names) == 0 {
   265  			// Case where arg is not named.
   266  			names = []string{""}
   267  		}
   268  		for _, name := range names {
   269  			ret.Params = append(ret.Params, Param{
   270  				Name: name,
   271  				Type: t,
   272  			})
   273  		}
   274  	}
   275  	if fn.Results == nil {
   276  		return ret, nil
   277  	}
   278  	if len(fn.Results.List) > 1 {
   279  		return nil, fmt.Errorf(
   280  			"unsupported number of function results",
   281  		)
   282  	}
   284  	r := fn.Results.List[0]
   286  	switch x := r.Type.(type) {
   287  	case *ast.FuncType:
   288  		result, err := buildFunc(info, traces, x)
   289  		if err != nil {
   290  			return nil, xerrors.WithStackTrace(err)
   291  		}
   292  		ret.Result = append(ret.Result, result)
   294  		return ret, nil
   296  	case *ast.Ident:
   297  		if t, ok := traces[x.Name]; ok {
   298  			t.Nested = true
   299  			ret.Result = append(ret.Result, t)
   301  			return ret, nil
   302  		}
   303  	}
   305  	return nil, fmt.Errorf(
   306  		"unsupported function result type %s",
   307  		info.TypeOf(r.Type),
   308  	)
   309  }
// Package is the input handed to a Writer: the type-checked package together
// with the data collected from the source file.
//
// BuildConstraints holds the lines returned by scanBuildConstraints for the
// source file. Traces holds one Trace per struct that main collected for
// generation.
type Package struct {
	*types.Package

	BuildConstraints []string
	Traces           []*Trace
}
   318  type Trace struct {
   319  	Name   string
   320  	Hooks  []Hook
   321  	Nested bool
   322  }
   324  func (*Trace) isFuncResult() bool { return true }
// Hook is a single func-typed field of a trace struct: Name is the field name
// and Func is the description produced by buildFunc for the field's type.
type Hook struct {
	Name string
	Func *Func
}
   331  type Param struct {
   332  	Name string // Might be empty.
   333  	Type types.Type
   334  }
   336  func (p Param) String() string {
   337  	return p.Name + " " + p.Type.String()
   338  }
// FuncResult is the set of things a hook function may return. The unexported
// marker method keeps the set closed to this package; in this file it is
// implemented by *Trace and *Func.
type FuncResult interface {
	isFuncResult() bool
}
   344  type Func struct {
   345  	Params []Param
   346  	Result []FuncResult // 0 or 1.
   347  }
   349  func (*Func) isFuncResult() bool { return true }
   351  func (f *Func) HasResult() bool {
   352  	return len(f.Result) > 0
   353  }
   355  type GenFlag uint8
   357  func (f GenFlag) Has(x GenFlag) bool {
   358  	return f&x != 0
   359  }
// GenItem is a struct declaration picked up by main while walking the source
// file's AST: Ident is the most recent identifier visited before the struct
// type (its Name becomes the trace name) and StructType is the struct itself.
type GenItem struct {
	Ident      *ast.Ident
	StructType *ast.StructType
}
   366  func rsplit(s string, c byte) (s1, s2 string) {
   367  	i := strings.LastIndexByte(s, c)
   368  	if i == -1 {
   369  		return s, ""
   370  	}
   372  	return s[:i], s[i+1:]
   373  }
   375  func scanBuildConstraints(r io.Reader) (cs []string, err error) {
   376  	br := bufio.NewReader(r)
   377  	for {
   378  		line, err := br.ReadBytes('\n')
   379  		if err != nil {
   380  			return nil, xerrors.WithStackTrace(err)
   381  		}
   382  		line = bytes.TrimSpace(line)
   383  		if comm := bytes.TrimPrefix(line, []byte("//")); !bytes.Equal(comm, line) {
   384  			comm = bytes.TrimSpace(comm)
   385  			if bytes.HasPrefix(comm, []byte("+build")) {
   386  				cs = append(cs, string(line))
   388  				continue
   389  			}
   390  		}
   391  		if bytes.HasPrefix(line, []byte("package ")) {
   392  			break
   393  		}
   394  	}
   396  	return cs, nil
   397  }
   399  func isGenerated(base, suffix string) bool {
   400  	i := strings.Index(base, suffix)
   401  	if i == -1 {
   402  		return false
   403  	}
   404  	n := len(base)
   405  	m := i + len(suffix)
   407  	return m == n || base[m] == '_'
   408  }