
// keyify transforms unkeyed struct literals into keyed ones.
     2  package main
     4  import (
     5  	"bytes"
     6  	"encoding/json"
     7  	"flag"
     8  	"fmt"
     9  	"go/ast"
    10  	"go/build"
    11  	"go/constant"
    12  	"go/printer"
    13  	"go/token"
    14  	"go/types"
    15  	"log"
    16  	"os"
    17  	"path/filepath"
    19  	""
    21  	""
    22  	""
    23  	""
    24  )
    26  var (
    27  	fRecursive bool
    28  	fOneLine   bool
    29  	fJSON      bool
    30  	fMinify    bool
    31  	fModified  bool
    32  	fVersion   bool
    33  )
    35  func init() {
    36  	flag.BoolVar(&fRecursive, "r", false, "keyify struct initializers recursively")
    37  	flag.BoolVar(&fOneLine, "o", false, "print new struct initializer on a single line")
    38  	flag.BoolVar(&fJSON, "json", false, "print new struct initializer as JSON")
    39  	flag.BoolVar(&fMinify, "m", false, "omit fields that are set to their zero value")
    40  	flag.BoolVar(&fModified, "modified", false, "read an archive of modified files from standard input")
    41  	flag.BoolVar(&fVersion, "version", false, "Print version and exit")
    42  }
    44  func usage() {
    45  	fmt.Printf("Usage: %s [flags] <position>\n\n", os.Args[0])
    46  	flag.PrintDefaults()
    47  }
    49  func main() {
    50  	log.SetFlags(0)
    51  	flag.Usage = usage
    52  	flag.Parse()
    54  	if fVersion {
    55  		version.Print()
    56  		os.Exit(0)
    57  	}
    59  	if flag.NArg() != 1 {
    60  		flag.Usage()
    61  		os.Exit(2)
    62  	}
    63  	pos := flag.Args()[0]
    64  	name, start, _, err := parsePos(pos)
    65  	if err != nil {
    66  		log.Fatal(err)
    67  	}
    68  	eval, err := filepath.EvalSymlinks(name)
    69  	if err != nil {
    70  		log.Fatal(err)
    71  	}
    72  	name, err = filepath.Abs(eval)
    73  	if err != nil {
    74  		log.Fatal(err)
    75  	}
    76  	cwd, err := os.Getwd()
    77  	if err != nil {
    78  		log.Fatal(err)
    79  	}
    80  	ctx := &build.Default
    81  	if fModified {
    82  		overlay, err := buildutil.ParseOverlayArchive(os.Stdin)
    83  		if err != nil {
    84  			log.Fatal(err)
    85  		}
    86  		ctx = buildutil.OverlayContext(ctx, overlay)
    87  	}
    88  	bpkg, err := buildutil.ContainingPackage(ctx, cwd, name)
    89  	if err != nil {
    90  		log.Fatal(err)
    91  	}
    92  	conf := &loader.Config{
    93  		Build: ctx,
    94  	}
    95  	conf.TypeCheckFuncBodies = func(s string) bool {
    96  		return s == bpkg.ImportPath || s == bpkg.ImportPath+"_test"
    97  	}
    98  	conf.ImportWithTests(bpkg.ImportPath)
    99  	lprog, err := conf.Load()
   100  	if err != nil {
   101  		log.Fatal(err)
   102  	}
   103  	var tf *token.File
   104  	var af *ast.File
   105  	var pkg *loader.PackageInfo
   106  outer:
   107  	for _, pkg = range lprog.InitialPackages() {
   108  		for _, ff := range pkg.Files {
   109  			file := lprog.Fset.File(ff.Pos())
   110  			if file.Name() == name {
   111  				af = ff
   112  				tf = file
   113  				break outer
   114  			}
   115  		}
   116  	}
   117  	if tf == nil {
   118  		log.Fatalf("couldn't find file %s", name)
   119  	}
   120  	tstart, tend, err := fileOffsetToPos(tf, start, start)
   121  	if err != nil {
   122  		log.Fatal(err)
   123  	}
   124  	path, _ := astutil.PathEnclosingInterval(af, tstart, tend)
   125  	var complit *ast.CompositeLit
   126  	for _, p := range path {
   127  		if p, ok := p.(*ast.CompositeLit); ok {
   128  			complit = p
   129  			break
   130  		}
   131  	}
   132  	if complit == nil {
   133  		log.Fatal("no composite literal found near point")
   134  	}
   135  	if len(complit.Elts) == 0 {
   136  		printComplit(complit, complit, lprog.Fset, lprog.Fset)
   137  		return
   138  	}
   139  	if _, ok := complit.Elts[0].(*ast.KeyValueExpr); ok {
   140  		lit := complit
   141  		if fOneLine {
   142  			lit = copyExpr(complit, 1).(*ast.CompositeLit)
   143  		}
   144  		printComplit(complit, lit, lprog.Fset, lprog.Fset)
   145  		return
   146  	}
   147  	_, ok := pkg.TypeOf(complit).Underlying().(*types.Struct)
   148  	if !ok {
   149  		log.Fatal("not a struct initialiser")
   150  		return
   151  	}
   153  	newComplit, lines := keyify(pkg, complit)
   154  	newFset := token.NewFileSet()
   155  	newFile := newFset.AddFile("", -1, lines)
   156  	for i := 1; i <= lines; i++ {
   157  		newFile.AddLine(i)
   158  	}
   159  	printComplit(complit, newComplit, lprog.Fset, newFset)
   160  }
   162  func keyify(
   163  	pkg *loader.PackageInfo,
   164  	complit *ast.CompositeLit,
   165  ) (*ast.CompositeLit, int) {
   166  	var calcPos func(int) token.Pos
   167  	if fOneLine {
   168  		calcPos = func(int) token.Pos { return token.Pos(1) }
   169  	} else {
   170  		calcPos = func(i int) token.Pos { return token.Pos(2 + i) }
   171  	}
   173  	st, _ := pkg.TypeOf(complit).Underlying().(*types.Struct)
   174  	newComplit := &ast.CompositeLit{
   175  		Type:   complit.Type,
   176  		Lbrace: 1,
   177  		Rbrace: token.Pos(st.NumFields() + 2),
   178  	}
   179  	if fOneLine {
   180  		newComplit.Rbrace = 1
   181  	}
   182  	numLines := 2 + st.NumFields()
   183  	n := 0
   184  	for i := 0; i < st.NumFields(); i++ {
   185  		field := st.Field(i)
   186  		val := complit.Elts[i]
   187  		if fRecursive {
   188  			if val2, ok := val.(*ast.CompositeLit); ok {
   189  				if _, ok := pkg.TypeOf(val2.Type).Underlying().(*types.Struct); ok {
   190  					var lines int
   191  					numLines += lines
   192  					val, lines = keyify(pkg, val2)
   193  				}
   194  			}
   195  		}
   196  		_, isIface := st.Field(i).Type().Underlying().(*types.Interface)
   197  		if fMinify && (isNil(val, pkg) || (!isIface && isZero(val, pkg))) {
   198  			continue
   199  		}
   200  		elt := &ast.KeyValueExpr{
   201  			Key:   &ast.Ident{NamePos: calcPos(n), Name: field.Name()},
   202  			Value: copyExpr(val, calcPos(n)),
   203  		}
   204  		newComplit.Elts = append(newComplit.Elts, elt)
   205  		n++
   206  	}
   207  	return newComplit, numLines
   208  }
   210  func isNil(val ast.Expr, pkg *loader.PackageInfo) bool {
   211  	ident, ok := val.(*ast.Ident)
   212  	if !ok {
   213  		return false
   214  	}
   215  	if _, ok := pkg.ObjectOf(ident).(*types.Nil); ok {
   216  		return true
   217  	}
   218  	if c, ok := pkg.ObjectOf(ident).(*types.Const); ok {
   219  		if c.Val().Kind() != constant.Bool {
   220  			return false
   221  		}
   222  		return !constant.BoolVal(c.Val())
   223  	}
   224  	return false
   225  }
   227  func isZero(val ast.Expr, pkg *loader.PackageInfo) bool {
   228  	switch val := val.(type) {
   229  	case *ast.BasicLit:
   230  		switch val.Value {
   231  		case `""`, "``", "0", "0.0", "0i", "0.":
   232  			return true
   233  		default:
   234  			return false
   235  		}
   236  	case *ast.Ident:
   237  		return isNil(val, pkg)
   238  	case *ast.CompositeLit:
   239  		typ := pkg.TypeOf(val.Type)
   240  		if typ == nil {
   241  			return false
   242  		}
   243  		isIface := false
   244  		switch typ := typ.Underlying().(type) {
   245  		case *types.Struct:
   246  		case *types.Array:
   247  			_, isIface = typ.Elem().Underlying().(*types.Interface)
   248  		default:
   249  			return false
   250  		}
   251  		for _, elt := range val.Elts {
   252  			if isNil(elt, pkg) || (!isIface && !isZero(elt, pkg)) {
   253  				return false
   254  			}
   255  		}
   256  		return true
   257  	}
   258  	return false
   259  }
   261  func printComplit(oldlit, newlit *ast.CompositeLit, oldfset, newfset *token.FileSet) {
   262  	buf := &bytes.Buffer{}
   263  	cfg := printer.Config{Mode: printer.UseSpaces | printer.TabIndent, Tabwidth: 8}
   264  	_ = cfg.Fprint(buf, newfset, newlit)
   265  	if fJSON {
   266  		output := struct {
   267  			Start       int    `json:"start"`
   268  			End         int    `json:"end"`
   269  			Replacement string `json:"replacement"`
   270  		}{
   271  			oldfset.Position(oldlit.Pos()).Offset,
   272  			oldfset.Position(oldlit.End()).Offset,
   273  			buf.String(),
   274  		}
   275  		_ = json.NewEncoder(os.Stdout).Encode(output)
   276  	} else {
   277  		fmt.Println(buf.String())
   278  	}
   279  }
   281  func copyExpr(expr ast.Expr, line token.Pos) ast.Expr {
   282  	switch expr := expr.(type) {
   283  	case *ast.BasicLit:
   284  		cp := *expr
   285  		cp.ValuePos = 0
   286  		return &cp
   287  	case *ast.BinaryExpr:
   288  		cp := *expr
   289  		cp.X = copyExpr(cp.X, line)
   290  		cp.OpPos = 0
   291  		cp.Y = copyExpr(cp.Y, line)
   292  		return &cp
   293  	case *ast.CallExpr:
   294  		cp := *expr
   295  		cp.Fun = copyExpr(cp.Fun, line)
   296  		cp.Lparen = 0
   297  		for i, v := range cp.Args {
   298  			cp.Args[i] = copyExpr(v, line)
   299  		}
   300  		if cp.Ellipsis != 0 {
   301  			cp.Ellipsis = line
   302  		}
   303  		cp.Rparen = 0
   304  		return &cp
   305  	case *ast.CompositeLit:
   306  		cp := *expr
   307  		cp.Type = copyExpr(cp.Type, line)
   308  		cp.Lbrace = 0
   309  		for i, v := range cp.Elts {
   310  			cp.Elts[i] = copyExpr(v, line)
   311  		}
   312  		cp.Rbrace = 0
   313  		return &cp
   314  	case *ast.Ident:
   315  		cp := *expr
   316  		cp.NamePos = 0
   317  		return &cp
   318  	case *ast.IndexExpr:
   319  		cp := *expr
   320  		cp.X = copyExpr(cp.X, line)
   321  		cp.Lbrack = 0
   322  		cp.Index = copyExpr(cp.Index, line)
   323  		cp.Rbrack = 0
   324  		return &cp
   325  	case *ast.KeyValueExpr:
   326  		cp := *expr
   327  		cp.Key = copyExpr(cp.Key, line)
   328  		cp.Colon = 0
   329  		cp.Value = copyExpr(cp.Value, line)
   330  		return &cp
   331  	case *ast.ParenExpr:
   332  		cp := *expr
   333  		cp.Lparen = 0
   334  		cp.X = copyExpr(cp.X, line)
   335  		cp.Rparen = 0
   336  		return &cp
   337  	case *ast.SelectorExpr:
   338  		cp := *expr
   339  		cp.X = copyExpr(cp.X, line)
   340  		cp.Sel = copyExpr(cp.Sel, line).(*ast.Ident)
   341  		return &cp
   342  	case *ast.SliceExpr:
   343  		cp := *expr
   344  		cp.X = copyExpr(cp.X, line)
   345  		cp.Lbrack = 0
   346  		cp.Low = copyExpr(cp.Low, line)
   347  		cp.High = copyExpr(cp.High, line)
   348  		cp.Max = copyExpr(cp.Max, line)
   349  		cp.Rbrack = 0
   350  		return &cp
   351  	case *ast.StarExpr:
   352  		cp := *expr
   353  		cp.Star = 0
   354  		cp.X = copyExpr(cp.X, line)
   355  		return &cp
   356  	case *ast.TypeAssertExpr:
   357  		cp := *expr
   358  		cp.X = copyExpr(cp.X, line)
   359  		cp.Lparen = 0
   360  		cp.Type = copyExpr(cp.Type, line)
   361  		cp.Rparen = 0
   362  		return &cp
   363  	case *ast.UnaryExpr:
   364  		cp := *expr
   365  		cp.OpPos = 0
   366  		cp.X = copyExpr(cp.X, line)
   367  		return &cp
   368  	case *ast.MapType:
   369  		cp := *expr
   370  		cp.Map = 0
   371  		cp.Key = copyExpr(cp.Key, line)
   372  		cp.Value = copyExpr(cp.Value, line)
   373  		return &cp
   374  	case *ast.ArrayType:
   375  		cp := *expr
   376  		cp.Lbrack = 0
   377  		cp.Len = copyExpr(cp.Len, line)
   378  		cp.Elt = copyExpr(cp.Elt, line)
   379  		return &cp
   380  	case *ast.Ellipsis:
   381  		cp := *expr
   382  		cp.Elt = copyExpr(cp.Elt, line)
   383  		cp.Ellipsis = line
   384  		return &cp
   385  	case *ast.InterfaceType:
   386  		cp := *expr
   387  		cp.Interface = 0
   388  		return &cp
   389  	case *ast.StructType:
   390  		cp := *expr
   391  		cp.Struct = 0
   392  		return &cp
   393  	case *ast.FuncLit:
   394  		return expr
   395  	case *ast.ChanType:
   396  		cp := *expr
   397  		cp.Arrow = 0
   398  		cp.Begin = 0
   399  		cp.Value = copyExpr(cp.Value, line)
   400  		return &cp
   401  	case nil:
   402  		return nil
   403  	default:
   404  		panic(fmt.Sprintf("shouldn't happen: unknown ast.Expr of type %T", expr))
   405  	}
   406  }