github.com/gobwas/gtrace@v0.4.3/cmd/gtrace/writer.go (about)

     1  package main
     2  
     3  import (
     4  	"bufio"
     5  	"container/list"
     6  	"crypto/md5"
     7  	"encoding/hex"
     8  	"fmt"
     9  	"go/build"
    10  	"go/token"
    11  	"go/types"
    12  	"io"
    13  	"io/ioutil"
    14  	"path/filepath"
    15  	"runtime"
    16  	"sort"
    17  	"strconv"
    18  	"strings"
    19  	"sync"
    20  	"unicode"
    21  	"unicode/utf8"
    22  )
    23  
// Writer renders generated trace code for a Package into Output.
type Writer struct {
	// Output receives the generated source; it is wrapped into a buffered
	// writer on first use.
	// BuildTag, when non-empty, is emitted as a "// +build" constraint.
	// Stub selects stub hook generation and negates BuildTag in the
	// emitted constraint.
	// Context supplies GOROOT, which is used to detect standard library
	// packages when grouping imports.
	Output   io.Writer
	BuildTag string
	Stub     bool
	Context  build.Context

	// once guards the one-time setup of bw and scope.
	once sync.Once
	bw   *bufio.Writer

	// atEOL is set after a newline is written so that the next piece of
	// code is prefixed with depth tabs.
	// scope is a stack of *scope values; its back element is the innermost
	// scope.
	atEOL bool
	depth int
	scope *list.List

	// pkg is the package the code is generated for; it is set by Write.
	// std is a lazily built set of directory names found in GOROOT/src.
	pkg *types.Package
	std map[string]bool
}
    40  
// Write renders the generated source for package p and flushes it to the
// underlying output.
//
// The output consists of the "Code generated" header, build constraints,
// the package clause, the import block and the code generated for every
// trace of p. The returned error is the result of flushing the buffer.
func (w *Writer) Write(p Package) error {
	w.pkg = p.Package

	w.init()
	w.line(`// Code generated by gtrace. DO NOT EDIT.`)

	var hasConstraint bool
	for i, line := range p.BuildConstraints {
		hasConstraint = true
		if i == 0 {
			// Blank line between the header and the first constraint.
			w.line()
		}
		w.line(line)
	}
	if tag := w.BuildTag; tag != "" {
		if !hasConstraint {
			// No constraint was written above, so the separating blank
			// line is still missing.
			w.line()
		}
		w.code(`// +build `)
		if w.Stub {
			// Stub files are built only when the tag is NOT set.
			w.code(`!`)
		}
		w.line(w.BuildTag)
	}
	w.line()
	w.line(`package `, p.Name())
	w.line()

	// Collect imports required by all traces before emitting any code.
	var deps []dep
	for _, trace := range p.Traces {
		deps = w.traceImports(deps, trace)
	}
	w.importDeps(deps)

	// This scope is the bottom of the scope stack; identifiers declared
	// directly in it are treated as package-level ones (see isGlobalScope).
	w.newScope(func() {
		for _, trace := range p.Traces {
			w.compose(trace)
			if trace.Nested {
				w.isZero(trace)
			}
			if trace.Flag.Has(GenContext) {
				w.context(trace)
			}
			for _, hook := range trace.Hooks {
				if w.Stub {
					w.stubHook(trace, hook)
				} else {
					w.hook(trace, hook)
				}
			}
		}
		// Shortcuts are emitted after all hook methods of all traces.
		for _, trace := range p.Traces {
			for _, hook := range trace.Hooks {
				if !hook.Flag.Has(GenShortcut) {
					continue
				}
				if w.Stub {
					w.stubHookShortcut(trace, hook)
				} else {
					w.hookShortcut(trace, hook)
				}
			}
		}
	})

	return w.bw.Flush()
}
   108  
   109  func (w *Writer) init() {
   110  	w.once.Do(func() {
   111  		w.bw = bufio.NewWriter(w.Output)
   112  		w.scope = list.New()
   113  	})
   114  }
   115  
   116  func (w *Writer) mustDeclare(name string) {
   117  	s := w.scope.Back().Value.(*scope)
   118  	if !s.set(name) {
   119  		where := s.where(name)
   120  		panic(fmt.Sprintf(
   121  			"gtrace: can't declare identifier: %q: already defined at %q",
   122  			name, where,
   123  		))
   124  	}
   125  }
   126  
   127  func (w *Writer) declare(name string) string {
   128  	if isPredeclared(name) {
   129  		name = firstChar(name)
   130  	}
   131  	s := w.scope.Back().Value.(*scope)
   132  	for i := 0; ; i++ {
   133  		v := name
   134  		if i > 0 {
   135  			v += strconv.Itoa(i)
   136  		}
   137  		if token.IsKeyword(v) {
   138  			continue
   139  		}
   140  		if w.isGlobalScope() && w.pkg.Scope().Lookup(v) != nil {
   141  			continue
   142  		}
   143  		if s.set(v) {
   144  			return v
   145  		}
   146  	}
   147  }
   148  
   149  func isPredeclared(name string) bool {
   150  	return types.Universe.Lookup(name) != nil
   151  }
   152  
   153  func (w *Writer) isGlobalScope() bool {
   154  	return w.scope.Back().Prev() == nil
   155  }
   156  
   157  func (w *Writer) capture(vars ...string) {
   158  	s := w.scope.Back().Value.(*scope)
   159  	for _, v := range vars {
   160  		if !s.set(v) {
   161  			panic(fmt.Sprintf("can't capture variable %q", v))
   162  		}
   163  	}
   164  }
   165  
// dep describes a package that the generated file has to import, together
// with the name of the type that caused the dependency.
type dep struct {
	pkgPath string // import path of the package
	pkgName string // name of the package
	typName string // name of the type referenced from the package
}
   171  
   172  func (w *Writer) typeImports(dst []dep, t types.Type) []dep {
   173  	if p, ok := t.(*types.Pointer); ok {
   174  		return w.typeImports(dst, p.Elem())
   175  	}
   176  	n, ok := t.(*types.Named)
   177  	if !ok {
   178  		return dst
   179  	}
   180  	var (
   181  		obj = n.Obj()
   182  		pkg = obj.Pkg()
   183  	)
   184  	if pkg != nil && pkg.Path() != w.pkg.Path() {
   185  		return append(dst, dep{
   186  			pkgPath: pkg.Path(),
   187  			pkgName: pkg.Name(),
   188  			typName: obj.Name(),
   189  		})
   190  	}
   191  	return dst
   192  }
   193  
   194  func forEachField(s *types.Struct, fn func(*types.Var)) {
   195  	for i := 0; i < s.NumFields(); i++ {
   196  		fn(s.Field(i))
   197  	}
   198  }
   199  
   200  func unwrapStruct(t types.Type) (n *types.Named, s *types.Struct) {
   201  	var ok bool
   202  	n, ok = t.(*types.Named)
   203  	if ok {
   204  		s, _ = n.Underlying().(*types.Struct)
   205  	}
   206  	return
   207  }
   208  
   209  func (w *Writer) funcImports(dst []dep, flag GenFlag, fn *Func) []dep {
   210  	for _, p := range fn.Params {
   211  		dst = w.typeImports(dst, p.Type)
   212  		if !flag.Has(GenShortcut) {
   213  			continue
   214  		}
   215  		if _, s := unwrapStruct(p.Type); s != nil {
   216  			forEachField(s, func(v *types.Var) {
   217  				if v.Exported() {
   218  					dst = w.typeImports(dst, v.Type())
   219  				}
   220  			})
   221  		}
   222  	}
   223  	for _, x := range fn.Result {
   224  		if fn, ok := x.(*Func); ok {
   225  			dst = w.funcImports(dst, flag, fn)
   226  		}
   227  	}
   228  	return dst
   229  }
   230  
   231  func (w *Writer) traceImports(dst []dep, t *Trace) []dep {
   232  	if t.Flag.Has(GenContext) {
   233  		dst = append(dst, dep{
   234  			pkgPath: "context",
   235  			pkgName: "context",
   236  			typName: "Context",
   237  		})
   238  	}
   239  	for _, h := range t.Hooks {
   240  		dst = w.funcImports(dst, h.Flag, h.Func)
   241  	}
   242  	return dst
   243  }
   244  
// importDeps writes the import block for deps. Nothing is written when deps
// is empty.
//
// Packages are deduplicated by import path, then sorted so that standard
// library packages come first and each group is ordered by path. The two
// groups are separated by a blank line. Note that deps is modified in place.
func (w *Writer) importDeps(deps []dep) {
	seen := map[string]bool{}
	for i := 0; i < len(deps); {
		d := deps[i]
		if seen[d.pkgPath] {
			// Drop the duplicate by swapping it with the last element and
			// shrinking the slice; i is not advanced so that the swapped-in
			// element is examined next.
			n := len(deps)
			deps[i], deps[n-1] = deps[n-1], deps[i]
			deps = deps[:n-1]
			continue
		}
		seen[d.pkgPath] = true
		i++
	}
	if len(deps) == 0 {
		return
	}
	sort.Slice(deps, func(i, j int) bool {
		var (
			d0   = deps[i]
			d1   = deps[j]
			std0 = w.isStdLib(d0.pkgPath)
			std1 = w.isStdLib(d1.pkgPath)
		)
		if std0 != std1 {
			// Standard library packages sort before the others.
			return std0
		}
		return d0.pkgPath < d1.pkgPath
	})
	w.line(`import (`)
	var (
		lastStd bool
	)
	for _, d := range deps {
		if w.isStdLib(d.pkgPath) {
			lastStd = true
		} else if lastStd {
			// First non-standard package after the standard ones: separate
			// the groups with a blank line.
			lastStd = false
			w.line()
		}
		w.line("\t", `"`, d.pkgPath, `"`)
	}
	w.line(`)`)
	w.line()
}
   289  
   290  func (w *Writer) isStdLib(pkg string) bool {
   291  	w.ensureStdLibMapping()
   292  	s := strings.Split(pkg, "/")[0]
   293  	return w.std[s]
   294  }
   295  
   296  func (w *Writer) ensureStdLibMapping() {
   297  	if w.std != nil {
   298  		return
   299  	}
   300  	w.std = make(map[string]bool)
   301  
   302  	src := filepath.Join(w.Context.GOROOT, "src")
   303  	files, err := ioutil.ReadDir(src)
   304  	if err != nil {
   305  		panic(fmt.Sprintf("can't list GOROOT's src: %v", err))
   306  	}
   307  	for _, file := range files {
   308  		if !file.IsDir() {
   309  			continue
   310  		}
   311  		name := filepath.Base(file.Name())
   312  		switch name {
   313  		case
   314  			"cmd",
   315  			"internal":
   316  			// Ignored.
   317  
   318  		default:
   319  			w.std[name] = true
   320  		}
   321  	}
   322  }
   323  
   324  func (w *Writer) call(args []string) {
   325  	w.code(`(`)
   326  	for i, name := range args {
   327  		if i > 0 {
   328  			w.code(`, `)
   329  		}
   330  		w.code(name)
   331  	}
   332  	w.line(`)`)
   333  }
   334  
   335  func (w *Writer) isZero(trace *Trace) {
   336  	w.newScope(func() {
   337  		t := w.declare("t")
   338  		w.line(`// isZero checks whether `, t, ` is empty`)
   339  		w.line(`func (`, t, ` `, trace.Name, `) isZero() bool {`)
   340  		w.block(func() {
   341  			for _, hook := range trace.Hooks {
   342  				w.line(`if `, t, `.`, hook.Name, ` != nil {`)
   343  				w.block(func() {
   344  					w.line(`return false`)
   345  				})
   346  				w.line(`}`)
   347  			}
   348  			w.line(`return true`)
   349  		})
   350  		w.line(`}`)
   351  	})
   352  }
   353  
   354  func (w *Writer) compose(trace *Trace) {
   355  	w.newScope(func() {
   356  		t := w.declare("t")
   357  		x := w.declare("x")
   358  		ret := w.declare("ret")
   359  		w.line(`// Compose returns a new `, trace.Name, ` which has functional fields composed`)
   360  		w.line(`// both from `, t, ` and `, x, `.`)
   361  		w.code(`func (`, t, ` `, trace.Name, `) Compose(`, x, ` `, trace.Name, `) `)
   362  		w.line(`(`, ret, ` `, trace.Name, `) {`)
   363  		w.block(func() {
   364  			for _, hook := range trace.Hooks {
   365  				w.composeHook(hook, t, x, ret+"."+hook.Name)
   366  			}
   367  			w.line(`return `, ret)
   368  		})
   369  		w.line(`}`)
   370  	})
   371  }
   372  
   373  func (w *Writer) composeHook(hook Hook, t1, t2, dst string) {
   374  	w.line(`switch {`)
   375  	w.line(`case `, t1, `.`, hook.Name, ` == nil:`)
   376  	w.line("\t", dst, ` = `, t2, `.`, hook.Name)
   377  	w.line(`case `, t2, `.`, hook.Name, ` == nil:`)
   378  	w.line("\t", dst, ` = `, t1, `.`, hook.Name)
   379  	w.line(`default:`)
   380  	w.block(func() {
   381  		h1 := w.declare("h1")
   382  		h2 := w.declare("h2")
   383  		w.line(h1, ` := `, t1, `.`, hook.Name)
   384  		w.line(h2, ` := `, t2, `.`, hook.Name)
   385  		w.code(dst, ` = `)
   386  		w.composeHookCall(hook.Func, h1, h2)
   387  	})
   388  	w.line(`}`)
   389  }
   390  
// composeHookCall generates a function literal with the signature of fn
// which calls h1 and then h2 with its arguments.
//
// If fn has a result, the results of both calls are combined: when one of
// them is empty (see isEmptyResult) the other one is returned; otherwise
// function results are composed recursively and trace results are merged
// with their Compose method.
func (w *Writer) composeHookCall(fn *Func, h1, h2 string) {
	w.newScope(func() {
		// Reserve h1 and h2 so that the parameters declared below do not
		// reuse their names.
		w.capture(h1, h2)
		w.block(func() {
			// Same reservation for the variables declared in the body.
			w.capture(h1, h2)
			w.code(`func`)
			args := w.funcParams(fn.Params)
			w.funcResults(fn)
			w.line(`{`)
			var (
				r1 string
				r2 string
				rs []string
			)
			if fn.HasResult() {
				r1 = w.declare("r1")
				r2 = w.declare("r2")
				rs = []string{r1, r2}
			}
			for i, h := range []string{h1, h2} {
				if fn.HasResult() {
					w.code(rs[i], ` := `)
				}
				w.code(h)
				w.call(args)
			}
			if fn.HasResult() {
				w.line(`switch {`)

				w.code(`case `)
				w.isEmptyResult(r1, fn.Result[0])
				w.line(`:`)
				w.line("\t", `return `, r2)

				w.code(`case `)
				w.isEmptyResult(r2, fn.Result[0])
				w.line(`:`)
				w.line("\t", `return `, r1)

				w.line(`default:`)
				w.block(func() {
					w.code(`return `)
					switch x := fn.Result[0].(type) {
					case *Func:
						w.composeHookCall(x, r1, r2)
					case *Trace:
						w.line(r1, `.Compose(`, r2, `)`)
					default:
						panic("unknown result type")
					}
				})
				w.line(`}`)
			}
		})
		w.line(`}`)
	})
}
   448  
   449  func (w *Writer) isEmptyResult(name string, r FuncResult) {
   450  	switch r.(type) {
   451  	case *Func:
   452  		w.code(name, ` == nil`)
   453  	case *Trace:
   454  		w.code(name, `.isZero()`)
   455  	default:
   456  		panic("unknown result type")
   457  	}
   458  }
   459  
   460  var contextType = (func() types.Type {
   461  	pkg := types.NewPackage("context", "context")
   462  	typ := types.NewInterfaceType(nil, nil)
   463  	name := types.NewTypeName(0, pkg, "Context", typ)
   464  	return types.NewNamed(name, typ, nil)
   465  })()
   466  
   467  func (w *Writer) stubTrace(id string, t *Trace) (name string) {
   468  	name = tempName("gtrace", "noop", t.Name, id)
   469  	name = unexported(name)
   470  	name = w.declare(name)
   471  	w.line(`var `, name, ` `, t.Name)
   472  	return name
   473  }
   474  
   475  func (w *Writer) stubFunc(id string, f *Func) (name string) {
   476  	name = tempName("gtrace", "noop", id)
   477  	name = unexported(name)
   478  	name = w.declare(name)
   479  
   480  	var res string
   481  	for _, r := range f.Result {
   482  		switch x := r.(type) {
   483  		case *Func:
   484  			res = w.stubFunc(id, x)
   485  		case *Trace:
   486  			res = w.stubTrace(id, x)
   487  		default:
   488  			panic("unknown result type")
   489  		}
   490  	}
   491  	w.newScope(func() {
   492  		w.code(`func `, name)
   493  		w.funcParamsUnused(f.Params)
   494  		w.funcResults(f)
   495  		w.line(`{`)
   496  		if f.HasResult() {
   497  			w.block(func() {
   498  				w.line(`return `, res)
   499  			})
   500  		}
   501  		w.line(`}`)
   502  	})
   503  
   504  	return name
   505  }
   506  
// stubHook generates a no-op version of the unexported hook method of trace.
//
// The stub for the hook's result (if any) is generated first and returned
// from the method body. Parameter names are written only when haveNames
// reports that the hook's parameters carry meaningful names.
func (w *Writer) stubHook(trace *Trace, hook Hook) {
	var stubName string
	for _, r := range hook.Func.Result {
		switch x := r.(type) {
		case *Func:
			stubName = w.stubFunc(uniqueTraceHookID(trace, hook), x)
		case *Trace:
			stubName = w.stubTrace(uniqueTraceID(x), x)
		default:
			panic("unexpected result type")
		}
	}
	haveNames := haveNames(hook.Func.Params)
	w.newScope(func() {
		// The receiver is left unnamed: the stub never uses it.
		w.code(`func (`, trace.Name, `) `, unexported(hook.Name))
		w.code(`(`)
		if trace.Flag.Has(GenContext) {
			if haveNames {
				ctx := w.declare("ctx")
				w.code(ctx, ` `)
			}
			w.code(`context.Context`)
		}
		for i, p := range hook.Func.Params {
			if i > 0 || trace.Flag.Has(GenContext) {
				w.code(`, `)
			}
			if haveNames {
				name := w.declare(nameParam(p))
				w.code(name, ` `)
			}
			w.code(w.typeString(p.Type))
		}
		w.code(`) `)
		w.funcResultsFlags(hook.Func, docs)
		w.line(`{`)
		if hook.Func.HasResult() {
			w.block(func() {
				w.line(`return `, stubName)
			})
		}
		w.line(`}`)
	})
}
   551  
   552  func (w *Writer) stubShortcutFunc(id string, f *Func) (name string) {
   553  	name = tempName("gtrace", "noop", id)
   554  	name = w.declare(name)
   555  
   556  	var res string
   557  	for _, r := range f.Result {
   558  		switch x := r.(type) {
   559  		case *Func:
   560  			res = w.stubShortcutFunc(id, x)
   561  		case *Trace:
   562  			res = w.stubTrace(id, x)
   563  		default:
   564  			panic("unexpected result type")
   565  		}
   566  	}
   567  	w.newScope(func() {
   568  		w.code(`func `, name)
   569  		w.code(`(`)
   570  		params := flattenParams(nil, f.Params)
   571  		for i, p := range params {
   572  			if i > 0 {
   573  				w.code(`, `)
   574  			}
   575  			w.code(w.typeString(p.Type))
   576  		}
   577  		w.code(`) `)
   578  		for _, r := range f.Result {
   579  			switch x := r.(type) {
   580  			case *Func:
   581  				w.shortcutFuncSign(x)
   582  			case *Trace:
   583  				w.line(x.Name, ` `)
   584  			default:
   585  				panic("unexpected result type")
   586  			}
   587  		}
   588  		w.line(`{`)
   589  		if f.HasResult() {
   590  			w.block(func() {
   591  				w.line(`return `, res)
   592  			})
   593  		}
   594  		w.line(`}`)
   595  	})
   596  
   597  	return name
   598  }
   599  
// stubHookShortcut generates a no-op version of the shortcut function for
// the given hook of trace.
//
// The function name is built from the trace and hook names and must not be
// declared yet. Struct parameters of the hook are flattened into their
// exported fields (see flattenParams).
func (w *Writer) stubHookShortcut(trace *Trace, hook Hook) {
	name := tempName(trace.Name, hook.Name)
	name = unexported(name)
	w.mustDeclare(name)

	id := uniqueTraceHookID(trace, hook)

	var stubName string
	for _, r := range hook.Func.Result {
		switch x := r.(type) {
		case *Func:
			stubName = w.stubShortcutFunc(id, x)
		case *Trace:
			stubName = w.stubTrace(id, x)
		default:
			panic("unexpected result type")
		}
	}

	params := flattenParams(nil, hook.Func.Params)
	haveNames := haveNames(params)

	w.newScope(func() {
		w.code(`func `, name)
		w.code(`(`)
		if trace.Flag.Has(GenContext) {
			if haveNames {
				ctx := w.declare("ctx")
				w.code(ctx, ` `)
			}
			w.code(`context.Context, `)
		}

		// The trace itself is passed as a regular parameter.
		if haveNames {
			t := w.declare("t")
			w.code(t, ` `)
		}
		w.code(trace.Name)

		for _, p := range params {
			w.code(`, `)
			if haveNames {
				name := w.declare(nameParam(p))
				w.code(name, ` `)
			}
			w.code(w.typeString(p.Type))
		}
		w.code(`) `)
		w.shortcutFuncResultsFlags(hook.Func, docs)
		w.line(`{`)
		if hook.Func.HasResult() {
			w.block(func() {
				w.line(`return `, stubName)
			})
		}
		w.line(`}`)
	})
}
   658  
   659  func (w *Writer) hook(trace *Trace, hook Hook) {
   660  	w.newScope(func() {
   661  		t := w.declare("t")
   662  		x := w.declare("c") // For context's trace.
   663  		fn := w.declare("fn")
   664  
   665  		w.code(`func (`, t, ` `, trace.Name, `) `, unexported(hook.Name))
   666  
   667  		w.code(`(`)
   668  		var ctx string
   669  		if trace.Flag.Has(GenContext) {
   670  			ctx = w.declare("ctx")
   671  			w.code(ctx, ` context.Context`)
   672  		}
   673  		var args []string
   674  		for i, p := range hook.Func.Params {
   675  			if i > 0 || ctx != "" {
   676  				w.code(`, `)
   677  			}
   678  			args = append(args, w.funcParam(p))
   679  		}
   680  		w.code(`) `)
   681  		w.funcResultsFlags(hook.Func, docs)
   682  		w.line(`{`)
   683  		w.block(func() {
   684  			if ctx != "" {
   685  				w.line(x, ` := Context`, trace.Name, `(`, ctx, `)`)
   686  				w.code(`var fn `)
   687  				w.funcSignature(hook.Func)
   688  				w.line()
   689  				w.composeHook(hook, t, x, fn)
   690  			} else {
   691  				w.line(fn, ` := `, t, `.`, hook.Name)
   692  			}
   693  			w.line(`if `, fn, ` == nil {`)
   694  			w.block(func() {
   695  				w.zeroReturn(hook.Func)
   696  			})
   697  			w.line(`}`)
   698  
   699  			w.hookFuncCall(hook.Func, fn, args)
   700  		})
   701  		w.line(`}`)
   702  	})
   703  }
   704  
// hookFuncCall generates the call of the function variable name with args
// and the handling of its result.
//
// If fn has no result only the call is written. A function result is
// checked for nil (a zero value is returned in that case) and, when that
// function has results itself, wrapped into a function literal that applies
// the same handling recursively. Any other result is returned as is.
func (w *Writer) hookFuncCall(fn *Func, name string, args []string) {
	var res string
	if fn.HasResult() {
		res = w.declare("res")
		w.code(res, ` := `)
	}

	w.code(name)
	w.call(args)

	if !fn.HasResult() {
		return
	}

	r, isFunc := fn.Result[0].(*Func)
	if isFunc {
		w.line(`if `, res, ` == nil {`)
		w.block(func() {
			w.zeroReturn(fn)
		})
		w.line(`}`)

		if r.HasResult() {
			// The returned function has results too: wrap it so that its
			// results get the same nil protection.
			w.newScope(func() {
				w.code(`return func`)
				args := w.funcParams(r.Params)
				w.funcResults(r)
				w.line(`{`)
				w.block(func() {
					w.hookFuncCall(r, res, args)
				})
				w.line(`}`)
			})
			return
		}
	}

	w.line(`return `, res)
}
   744  
   745  func (w *Writer) context(trace *Trace) {
   746  	w.line()
   747  	w.line(`type `, unexported(trace.Name), `ContextKey struct{}`)
   748  	w.line()
   749  
   750  	w.newScope(func() {
   751  		var (
   752  			ctx = w.declare("ctx")
   753  			t   = w.declare("t")
   754  		)
   755  		w.line(`// With`, trace.Name, ` returns context which has associated `, trace.Name, ` with it.`)
   756  		w.code(`func With`, trace.Name, `(`)
   757  		w.code(ctx, ` context.Context, `)
   758  		w.code(t, ` `, trace.Name, `) `)
   759  		w.line(`context.Context {`)
   760  		w.block(func() {
   761  			w.line(`return context.WithValue(`, ctx, `,`)
   762  			w.line("\t", unexported(trace.Name), `ContextKey{},`)
   763  			w.line("\t", `Context`, trace.Name, `(`, ctx, `).Compose(`, t, `),`)
   764  			w.line(`)`)
   765  		})
   766  		w.line(`}`)
   767  		w.line()
   768  	})
   769  	w.newScope(func() {
   770  		var (
   771  			ctx = w.declare("ctx")
   772  			t   = w.declare("t")
   773  		)
   774  		w.line(`// Context`, trace.Name, ` returns `, trace.Name, ` associated with `, ctx, `.`)
   775  		w.line(`// If there is no `, trace.Name, ` associated with `, ctx, ` then zero value `)
   776  		w.line(`// of `, trace.Name, ` is returned.`)
   777  		w.code(`func Context`, trace.Name, `(`, ctx, ` context.Context) `)
   778  		w.line(trace.Name, ` {`)
   779  		w.block(func() {
   780  			w.code(t, `, _ := ctx.Value(`, unexported(trace.Name), `ContextKey{})`)
   781  			w.line(`.(`, trace.Name, `)`)
   782  			w.line(`return `, t)
   783  		})
   784  		w.line(`}`)
   785  		w.line()
   786  	})
   787  }
   788  
   789  func nameParam(p Param) (s string) {
   790  	s = p.Name
   791  	if s == "" {
   792  		s = firstChar(ident(typeBasename(p.Type)))
   793  	}
   794  	return unexported(s)
   795  }
   796  
   797  func (w *Writer) declareParams(src []Param) (names []string) {
   798  	names = make([]string, len(src))
   799  	for i, p := range src {
   800  		names[i] = w.declare(nameParam(p))
   801  	}
   802  	return names
   803  }
   804  
   805  func flattenParams(dst, src []Param) []Param {
   806  	for _, p := range src {
   807  		_, s := unwrapStruct(p.Type)
   808  		if s != nil {
   809  			dst = flattenStruct(dst, s)
   810  			continue
   811  		}
   812  		dst = append(dst, p)
   813  	}
   814  	return dst
   815  }
   816  
   817  func typeBasename(t types.Type) (name string) {
   818  	lo, name := rsplit(t.String(), '.')
   819  	if name == "" {
   820  		name = lo
   821  	}
   822  	return name
   823  }
   824  
   825  func flattenStruct(dst []Param, s *types.Struct) []Param {
   826  	forEachField(s, func(f *types.Var) {
   827  		if !f.Exported() {
   828  			return
   829  		}
   830  		var (
   831  			name = f.Name()
   832  			typ  = f.Type()
   833  		)
   834  		if name == typeBasename(typ) {
   835  			// NOTE: field name essentially be empty for embeded structs or
   836  			// fields called exactly as type.
   837  			name = ""
   838  		}
   839  		dst = append(dst, Param{
   840  			Name: name,
   841  			Type: typ,
   842  		})
   843  	})
   844  	return dst
   845  }
   846  
   847  func (w *Writer) constructParams(params []Param, names []string) (res []string) {
   848  	for _, p := range params {
   849  		n, s := unwrapStruct(p.Type)
   850  		if s != nil {
   851  			var v string
   852  			v, names = w.constructStruct(n, s, names)
   853  			res = append(res, v)
   854  			continue
   855  		}
   856  		name := names[0]
   857  		names = names[1:]
   858  		res = append(res, name)
   859  	}
   860  	return res
   861  }
   862  
   863  func (w *Writer) constructStruct(n *types.Named, s *types.Struct, vars []string) (string, []string) {
   864  	p := w.declare("p")
   865  	// TODO Ptr
   866  	// maybe skip pointers from flattening to not allocate anyhing during trace.
   867  	w.line(`var `, p, ` `, w.typeString(n))
   868  	for i := 0; i < s.NumFields(); i++ {
   869  		v := s.Field(i)
   870  		if !v.Exported() {
   871  			continue
   872  		}
   873  		name := vars[0]
   874  		vars = vars[1:]
   875  		w.line(p, `.`, v.Name(), ` = `, name)
   876  	}
   877  	return p, vars
   878  }
   879  
// hookShortcut generates the shortcut function for the given hook of trace.
//
// The function name is built from the trace and hook names and must not be
// declared yet. The shortcut receives the trace and the hook's parameters
// with structs flattened into their exported fields, rebuilds the original
// arguments and calls the unexported hook method. A function result is
// wrapped by hookFuncShortcut, a trace result is returned as is.
func (w *Writer) hookShortcut(trace *Trace, hook Hook) {
	name := tempName(trace.Name, hook.Name)
	name = unexported(name)
	w.mustDeclare(name)

	w.newScope(func() {
		t := w.declare("t")
		w.code(`func `, name)
		w.code(`(`)
		var ctx string
		if trace.Flag.Has(GenContext) {
			ctx = w.declare("ctx")
			w.code(ctx, ` context.Context`)
			w.code(`, `)
		}
		w.code(t, ` `, trace.Name)

		var (
			params = flattenParams(nil, hook.Func.Params)
			names  = w.declareParams(params)
		)
		for i, p := range params {
			w.code(`, `)
			w.code(names[i], ` `, w.typeString(p.Type))
		}
		w.code(`) `)
		w.shortcutFuncResultsFlags(hook.Func, docs)
		w.line(`{`)
		w.block(func() {
			// Reserve the parameter names in the body scope so that the
			// variables declared below do not reuse them.
			for _, name := range names {
				w.capture(name)
			}
			vars := w.constructParams(hook.Func.Params, names)
			var res string
			if hook.Func.HasResult() {
				res = w.declare("res")
				w.code(res, ` := `)
			}
			w.code(t, `.`, unexported(hook.Name))
			if ctx != "" {
				// The context goes first in the hook method's arguments.
				vars = append([]string{ctx}, vars...)
			}
			w.call(vars)
			if hook.Func.HasResult() {
				w.code(`return `)
				r := hook.Func.Result[0]
				switch x := r.(type) {
				case *Func:
					w.hookFuncShortcut(x, res)
				case *Trace:
					w.line(res)
				default:
					panic("unexpected result type")
				}
			}
		})
		w.line(`}`)
	})
}
   939  
// hookFuncShortcut generates a function literal with the shortcut
// (flattened) signature of fn which rebuilds the original arguments and
// calls the function variable name with them. A function result is wrapped
// recursively, a trace result is returned as is.
func (w *Writer) hookFuncShortcut(fn *Func, name string) {
	w.newScope(func() {
		w.code(`func(`)
		var (
			params = flattenParams(nil, fn.Params)
			names  = w.declareParams(params)
		)
		for i, p := range params {
			if i > 0 {
				w.code(`, `)
			}
			w.code(names[i], ` `, w.typeString(p.Type))
		}
		w.code(`) `)
		w.shortcutFuncResults(fn)
		w.line(`{`)
		w.block(func() {
			// Reserve the parameter names in the body scope so that the
			// variables declared below do not reuse them.
			for _, name := range names {
				w.capture(name)
			}
			params := w.constructParams(fn.Params, names)
			var res string
			if fn.HasResult() {
				res = w.declare("res")
				w.code(res, ` := `)
			}
			w.code(name)
			w.call(params)
			if fn.HasResult() {
				r := fn.Result[0]
				w.code(`return `)
				switch x := r.(type) {
				case *Func:
					w.hookFuncShortcut(x, res)
				case *Trace:
					w.line(res)
				default:
					panic("unexpected result type")
				}
			}
		})
		w.line(`}`)
	})
}
   984  
   985  func (w *Writer) zeroReturn(fn *Func) {
   986  	if !fn.HasResult() {
   987  		w.line(`return`)
   988  		return
   989  	}
   990  	w.code(`return `)
   991  	switch x := fn.Result[0].(type) {
   992  	case *Func:
   993  		w.funcSignature(x)
   994  		w.line(`{`)
   995  		w.block(func() {
   996  			w.zeroReturn(x)
   997  		})
   998  		w.line(`}`)
   999  	case *Trace:
  1000  		w.line(x.Name, `{}`)
  1001  	default:
  1002  		panic("unexpected result type")
  1003  	}
  1004  }
  1005  
  1006  func (w *Writer) funcParams(params []Param) (vars []string) {
  1007  	w.code(`(`)
  1008  	for i, p := range params {
  1009  		if i > 0 {
  1010  			w.code(`, `)
  1011  		}
  1012  		vars = append(vars, w.funcParam(p))
  1013  	}
  1014  	w.code(`) `)
  1015  	return
  1016  }
  1017  
  1018  func (w *Writer) funcParamsUnused(params []Param) {
  1019  	w.code(`(`)
  1020  	for i, p := range params {
  1021  		if i > 0 {
  1022  			w.code(`, `)
  1023  		}
  1024  		w.code(w.typeString(p.Type))
  1025  	}
  1026  	w.code(`) `)
  1027  }
  1028  
  1029  func (w *Writer) funcParam(p Param) (name string) {
  1030  	name = w.declare(nameParam(p))
  1031  	w.code(name, ` `)
  1032  	w.code(w.typeString(p.Type))
  1033  	return name
  1034  }
  1035  
  1036  func (w *Writer) funcParamSign(p Param) {
  1037  	name := nameParam(p)
  1038  	if len(name) == 1 || isPredeclared(name) {
  1039  		name = "_"
  1040  	}
  1041  	w.code(name, ` `)
  1042  	w.code(w.typeString(p.Type))
  1043  }
  1044  
  1045  type flags uint8
  1046  
  1047  func (f flags) has(x flags) bool {
  1048  	return f&x != 0
  1049  }
  1050  
  1051  const (
  1052  	zeroFlags flags = 1 << iota >> 1
  1053  	docs
  1054  )
  1055  
  1056  func (w *Writer) funcResultsFlags(fn *Func, flags flags) {
  1057  	for _, r := range fn.Result {
  1058  		switch x := r.(type) {
  1059  		case *Func:
  1060  			w.funcSignatureFlags(x, flags)
  1061  		case *Trace:
  1062  			w.code(x.Name, ` `)
  1063  		default:
  1064  			panic("unexpected result type")
  1065  		}
  1066  	}
  1067  }
  1068  
  1069  func (w *Writer) funcResults(fn *Func) {
  1070  	w.funcResultsFlags(fn, 0)
  1071  }
  1072  
  1073  func (w *Writer) funcSignatureFlags(fn *Func, flags flags) {
  1074  	haveNames := haveNames(fn.Params)
  1075  	w.code(`func(`)
  1076  	for i, p := range fn.Params {
  1077  		if i > 0 {
  1078  			w.code(`, `)
  1079  		}
  1080  		if flags.has(docs) && haveNames {
  1081  			w.funcParamSign(p)
  1082  		} else {
  1083  			w.code(w.typeString(p.Type))
  1084  		}
  1085  	}
  1086  	w.code(`) `)
  1087  	w.funcResultsFlags(fn, flags)
  1088  }
  1089  
  1090  func (w *Writer) funcSignature(fn *Func) {
  1091  	w.funcSignatureFlags(fn, 0)
  1092  }
  1093  
  1094  func (w *Writer) shortcutFuncSignFlags(fn *Func, flags flags) {
  1095  	var (
  1096  		params    = flattenParams(nil, fn.Params)
  1097  		haveNames = haveNames(params)
  1098  	)
  1099  	w.code(`func(`)
  1100  	for i, p := range params {
  1101  		if i > 0 {
  1102  			w.code(`, `)
  1103  		}
  1104  		if flags.has(docs) && haveNames {
  1105  			w.funcParamSign(p)
  1106  		} else {
  1107  			w.code(w.typeString(p.Type))
  1108  		}
  1109  	}
  1110  	w.code(`) `)
  1111  	w.shortcutFuncResultsFlags(fn, flags)
  1112  }
  1113  
  1114  func (w *Writer) shortcutFuncSign(fn *Func) {
  1115  	w.shortcutFuncSignFlags(fn, 0)
  1116  }
  1117  
  1118  func (w *Writer) shortcutFuncResultsFlags(fn *Func, flags flags) {
  1119  	for _, r := range fn.Result {
  1120  		switch x := r.(type) {
  1121  		case *Func:
  1122  			w.shortcutFuncSignFlags(x, flags)
  1123  		case *Trace:
  1124  			w.code(x.Name, ` `)
  1125  		default:
  1126  			panic("unexpected result type")
  1127  		}
  1128  	}
  1129  }
  1130  
  1131  func (w *Writer) shortcutFuncResults(fn *Func) {
  1132  	w.shortcutFuncResultsFlags(fn, 0)
  1133  }
  1134  
  1135  func haveNames(params []Param) bool {
  1136  	for _, p := range params {
  1137  		name := nameParam(p)
  1138  		if len(name) > 1 && !isPredeclared(name) {
  1139  			return true
  1140  		}
  1141  	}
  1142  	return false
  1143  }
  1144  
  1145  func (w *Writer) typeString(t types.Type) string {
  1146  	return types.TypeString(t, func(pkg *types.Package) string {
  1147  		if pkg.Path() == w.pkg.Path() {
  1148  			return "" // same package; unqualified
  1149  		}
  1150  		return pkg.Name()
  1151  	})
  1152  }
  1153  
  1154  func (w *Writer) block(fn func()) {
  1155  	w.depth++
  1156  	w.newScope(fn)
  1157  	w.depth--
  1158  }
  1159  
  1160  func (w *Writer) newScope(fn func()) {
  1161  	w.scope.PushBack(new(scope))
  1162  	fn()
  1163  	w.scope.Remove(w.scope.Back())
  1164  }
  1165  
  1166  func (w *Writer) line(args ...string) {
  1167  	w.code(args...)
  1168  	w.bw.WriteByte('\n')
  1169  	w.atEOL = true
  1170  }
  1171  
  1172  func (w *Writer) code(args ...string) {
  1173  	if w.atEOL {
  1174  		for i := 0; i < w.depth; i++ {
  1175  			w.bw.WriteByte('\t')
  1176  		}
  1177  		w.atEOL = false
  1178  	}
  1179  	for _, arg := range args {
  1180  		w.bw.WriteString(arg)
  1181  	}
  1182  }
  1183  
  1184  func exported(s string) string {
  1185  	r, size := utf8.DecodeRuneInString(s)
  1186  	if r == utf8.RuneError {
  1187  		panic("invalid string")
  1188  	}
  1189  	return string(unicode.ToUpper(r)) + s[size:]
  1190  }
  1191  
  1192  func unexported(s string) string {
  1193  	r, size := utf8.DecodeRuneInString(s)
  1194  	if r == utf8.RuneError {
  1195  		panic("invalid string")
  1196  	}
  1197  	return string(unicode.ToLower(r)) + s[size:]
  1198  }
  1199  
  1200  func firstChar(s string) string {
  1201  	r, _ := utf8.DecodeRuneInString(s)
  1202  	if r == utf8.RuneError {
  1203  		panic("invalid string")
  1204  	}
  1205  	return string(r)
  1206  }
  1207  
  1208  func ident(s string) string {
  1209  	// Identifier must not begin with number.
  1210  	for len(s) > 0 {
  1211  		r, size := utf8.DecodeRuneInString(s)
  1212  		if r == utf8.RuneError {
  1213  			panic("invalid string")
  1214  		}
  1215  		if !unicode.IsNumber(r) {
  1216  			break
  1217  		}
  1218  		s = s[size:]
  1219  	}
  1220  
  1221  	// Filter out non letter/number/underscore characters.
  1222  	s = strings.Map(func(r rune) rune {
  1223  		switch {
  1224  		case r == '_' ||
  1225  			unicode.IsLetter(r) ||
  1226  			unicode.IsNumber(r):
  1227  
  1228  			return r
  1229  		default:
  1230  			return -1
  1231  		}
  1232  	}, s)
  1233  
  1234  	if !token.IsIdentifier(s) {
  1235  		s = "_" + s
  1236  	}
  1237  
  1238  	return s
  1239  }
  1240  
  1241  func tempName(names ...string) string {
  1242  	var sb strings.Builder
  1243  	for i, name := range names {
  1244  		if i == 0 {
  1245  			name = unexported(name)
  1246  		} else {
  1247  			name = exported(name)
  1248  		}
  1249  		sb.WriteString(name)
  1250  	}
  1251  	return sb.String()
  1252  }
  1253  
  1254  type decl struct {
  1255  	where string
  1256  }
  1257  
  1258  type scope struct {
  1259  	vars map[string]decl
  1260  }
  1261  
  1262  func (s *scope) set(v string) bool {
  1263  	if s.vars == nil {
  1264  		s.vars = make(map[string]decl)
  1265  	}
  1266  	if _, has := s.vars[v]; has {
  1267  		return false
  1268  	}
  1269  	_, file, line, _ := runtime.Caller(2)
  1270  	s.vars[v] = decl{
  1271  		where: fmt.Sprintf("%s:%d", file, line),
  1272  	}
  1273  	return true
  1274  }
  1275  
  1276  func (s *scope) where(v string) string {
  1277  	d := s.vars[v]
  1278  	return d.where
  1279  }
  1280  
  1281  func uniqueTraceID(t *Trace) string {
  1282  	hash := md5.New()
  1283  	io.WriteString(hash, t.Name)
  1284  	p := hash.Sum(nil)
  1285  	s := hex.EncodeToString(p)
  1286  	return s[:8]
  1287  }
  1288  
  1289  func uniqueTraceHookID(t *Trace, h Hook) string {
  1290  	hash := md5.New()
  1291  	io.WriteString(hash, t.Name)
  1292  	io.WriteString(hash, h.Name)
  1293  	p := hash.Sum(nil)
  1294  	s := hex.EncodeToString(p)
  1295  	return s[:8]
  1296  }