github.com/amarpal/go-tools@v0.0.0-20240422043104-40142f59f616/cmd/keyify/keyify.go (about)

     1  // keyify transforms unkeyed struct literals into keyed ones.
     2  package main
     3  
     4  import (
     5  	"bytes"
     6  	"encoding/json"
     7  	"flag"
     8  	"fmt"
     9  	"go/ast"
    10  	"go/build"
    11  	"go/constant"
    12  	"go/printer"
    13  	"go/token"
    14  	"go/types"
    15  	"log"
    16  	"os"
    17  	"path/filepath"
    18  
    19  	"github.com/amarpal/go-tools/go/ast/astutil"
    20  	"github.com/amarpal/go-tools/lintcmd/version"
    21  
    22  	"golang.org/x/tools/go/buildutil"
    23  
    24  	//lint:ignore SA1019 this tool is unmaintained, just keep it working
    25  	"golang.org/x/tools/go/loader"
    26  )
    27  
// Command-line flags. They are registered with the flag package in init.
var (
	fRecursive bool // -r: keyify struct initializers recursively
	fOneLine   bool // -o: print new struct initializer on a single line
	fJSON      bool // -json: print new struct initializer as JSON
	fMinify    bool // -m: omit fields that are set to their zero value
	fModified  bool // -modified: read an archive of modified files from standard input
	fVersion   bool // -version: print version and exit
)
    36  
    37  func init() {
    38  	flag.BoolVar(&fRecursive, "r", false, "keyify struct initializers recursively")
    39  	flag.BoolVar(&fOneLine, "o", false, "print new struct initializer on a single line")
    40  	flag.BoolVar(&fJSON, "json", false, "print new struct initializer as JSON")
    41  	flag.BoolVar(&fMinify, "m", false, "omit fields that are set to their zero value")
    42  	flag.BoolVar(&fModified, "modified", false, "read an archive of modified files from standard input")
    43  	flag.BoolVar(&fVersion, "version", false, "Print version and exit")
    44  }
    45  
    46  func usage() {
    47  	fmt.Printf("Usage: %s [flags] <position>\n\n", os.Args[0])
    48  	flag.PrintDefaults()
    49  }
    50  
    51  func main() {
    52  	log.SetFlags(0)
    53  	flag.Usage = usage
    54  	flag.Parse()
    55  
    56  	if fVersion {
    57  		version.Print(version.Version, version.MachineVersion)
    58  		os.Exit(0)
    59  	}
    60  
    61  	if flag.NArg() != 1 {
    62  		flag.Usage()
    63  		os.Exit(2)
    64  	}
    65  	pos := flag.Args()[0]
    66  	name, start, _, err := parsePos(pos)
    67  	if err != nil {
    68  		log.Fatal(err)
    69  	}
    70  	eval, err := filepath.EvalSymlinks(name)
    71  	if err != nil {
    72  		log.Fatal(err)
    73  	}
    74  	name, err = filepath.Abs(eval)
    75  	if err != nil {
    76  		log.Fatal(err)
    77  	}
    78  	cwd, err := os.Getwd()
    79  	if err != nil {
    80  		log.Fatal(err)
    81  	}
    82  	ctx := &build.Default
    83  	if fModified {
    84  		overlay, err := buildutil.ParseOverlayArchive(os.Stdin)
    85  		if err != nil {
    86  			log.Fatal(err)
    87  		}
    88  		ctx = buildutil.OverlayContext(ctx, overlay)
    89  	}
    90  	bpkg, err := buildutil.ContainingPackage(ctx, cwd, name)
    91  	if err != nil {
    92  		log.Fatal(err)
    93  	}
    94  	conf := &loader.Config{
    95  		Build: ctx,
    96  	}
    97  	conf.TypeCheckFuncBodies = func(s string) bool {
    98  		return s == bpkg.ImportPath || s == bpkg.ImportPath+"_test"
    99  	}
   100  	conf.ImportWithTests(bpkg.ImportPath)
   101  	lprog, err := conf.Load()
   102  	if err != nil {
   103  		log.Fatal(err)
   104  	}
   105  	var tf *token.File
   106  	var af *ast.File
   107  	var pkg *loader.PackageInfo
   108  outer:
   109  	for _, pkg = range lprog.InitialPackages() {
   110  		for _, ff := range pkg.Files {
   111  			file := lprog.Fset.File(ff.Pos())
   112  			if file.Name() == name {
   113  				af = ff
   114  				tf = file
   115  				break outer
   116  			}
   117  		}
   118  	}
   119  	if tf == nil {
   120  		log.Fatalf("couldn't find file %s", name)
   121  	}
   122  	tstart, tend, err := fileOffsetToPos(tf, start, start)
   123  	if err != nil {
   124  		log.Fatal(err)
   125  	}
   126  	path, _ := astutil.PathEnclosingInterval(af, tstart, tend)
   127  	var complit *ast.CompositeLit
   128  	for _, p := range path {
   129  		if p, ok := p.(*ast.CompositeLit); ok {
   130  			complit = p
   131  			break
   132  		}
   133  	}
   134  	if complit == nil {
   135  		log.Fatal("no composite literal found near point")
   136  	}
   137  	if len(complit.Elts) == 0 {
   138  		printComplit(complit, complit, lprog.Fset, lprog.Fset)
   139  		return
   140  	}
   141  	if _, ok := complit.Elts[0].(*ast.KeyValueExpr); ok {
   142  		lit := complit
   143  		if fOneLine {
   144  			lit = copyExpr(complit, 1).(*ast.CompositeLit)
   145  		}
   146  		printComplit(complit, lit, lprog.Fset, lprog.Fset)
   147  		return
   148  	}
   149  	_, ok := pkg.TypeOf(complit).Underlying().(*types.Struct)
   150  	if !ok {
   151  		log.Fatal("not a struct initialiser")
   152  		return
   153  	}
   154  
   155  	newComplit, lines := keyify(pkg, complit)
   156  	newFset := token.NewFileSet()
   157  	newFile := newFset.AddFile("", -1, lines)
   158  	for i := 1; i <= lines; i++ {
   159  		newFile.AddLine(i)
   160  	}
   161  	printComplit(complit, newComplit, lprog.Fset, newFset)
   162  }
   163  
   164  func keyify(
   165  	pkg *loader.PackageInfo,
   166  	complit *ast.CompositeLit,
   167  ) (*ast.CompositeLit, int) {
   168  	var calcPos func(int) token.Pos
   169  	if fOneLine {
   170  		calcPos = func(int) token.Pos { return token.Pos(1) }
   171  	} else {
   172  		calcPos = func(i int) token.Pos { return token.Pos(2 + i) }
   173  	}
   174  
   175  	st, _ := pkg.TypeOf(complit).Underlying().(*types.Struct)
   176  	newComplit := &ast.CompositeLit{
   177  		Type:   complit.Type,
   178  		Lbrace: 1,
   179  		Rbrace: token.Pos(st.NumFields() + 2),
   180  	}
   181  	if fOneLine {
   182  		newComplit.Rbrace = 1
   183  	}
   184  	numLines := 2 + st.NumFields()
   185  	n := 0
   186  	for i := 0; i < st.NumFields(); i++ {
   187  		field := st.Field(i)
   188  		val := complit.Elts[i]
   189  		if fRecursive {
   190  			if val2, ok := val.(*ast.CompositeLit); ok {
   191  				if _, ok := pkg.TypeOf(val2.Type).Underlying().(*types.Struct); ok {
   192  					// FIXME(dh): this code is obviously wrong. But
   193  					// what were we intending to do here?
   194  					var lines int
   195  					numLines += lines
   196  					//lint:ignore SA4006 See FIXME above.
   197  					val, lines = keyify(pkg, val2)
   198  				}
   199  			}
   200  		}
   201  		_, isIface := st.Field(i).Type().Underlying().(*types.Interface)
   202  		if fMinify && (isNil(val, pkg) || (!isIface && isZero(val, pkg))) {
   203  			continue
   204  		}
   205  		elt := &ast.KeyValueExpr{
   206  			Key:   &ast.Ident{NamePos: calcPos(n), Name: field.Name()},
   207  			Value: copyExpr(val, calcPos(n)),
   208  		}
   209  		newComplit.Elts = append(newComplit.Elts, elt)
   210  		n++
   211  	}
   212  	return newComplit, numLines
   213  }
   214  
   215  func isNil(val ast.Expr, pkg *loader.PackageInfo) bool {
   216  	ident, ok := val.(*ast.Ident)
   217  	if !ok {
   218  		return false
   219  	}
   220  	if _, ok := pkg.ObjectOf(ident).(*types.Nil); ok {
   221  		return true
   222  	}
   223  	if c, ok := pkg.ObjectOf(ident).(*types.Const); ok {
   224  		if c.Val().Kind() != constant.Bool {
   225  			return false
   226  		}
   227  		return !constant.BoolVal(c.Val())
   228  	}
   229  	return false
   230  }
   231  
   232  func isZero(val ast.Expr, pkg *loader.PackageInfo) bool {
   233  	switch val := val.(type) {
   234  	case *ast.BasicLit:
   235  		switch val.Value {
   236  		case `""`, "``", "0", "0.0", "0i", "0.":
   237  			return true
   238  		default:
   239  			return false
   240  		}
   241  	case *ast.Ident:
   242  		return isNil(val, pkg)
   243  	case *ast.CompositeLit:
   244  		typ := pkg.TypeOf(val.Type)
   245  		if typ == nil {
   246  			return false
   247  		}
   248  		isIface := false
   249  		switch typ := typ.Underlying().(type) {
   250  		case *types.Struct:
   251  		case *types.Array:
   252  			_, isIface = typ.Elem().Underlying().(*types.Interface)
   253  		default:
   254  			return false
   255  		}
   256  		for _, elt := range val.Elts {
   257  			if isNil(elt, pkg) || (!isIface && !isZero(elt, pkg)) {
   258  				return false
   259  			}
   260  		}
   261  		return true
   262  	}
   263  	return false
   264  }
   265  
   266  func printComplit(oldlit, newlit *ast.CompositeLit, oldfset, newfset *token.FileSet) {
   267  	buf := &bytes.Buffer{}
   268  	cfg := printer.Config{Mode: printer.UseSpaces | printer.TabIndent, Tabwidth: 8}
   269  	_ = cfg.Fprint(buf, newfset, newlit)
   270  	if fJSON {
   271  		output := struct {
   272  			Start       int    `json:"start"`
   273  			End         int    `json:"end"`
   274  			Replacement string `json:"replacement"`
   275  		}{
   276  			oldfset.Position(oldlit.Pos()).Offset,
   277  			oldfset.Position(oldlit.End()).Offset,
   278  			buf.String(),
   279  		}
   280  		_ = json.NewEncoder(os.Stdout).Encode(output)
   281  	} else {
   282  		fmt.Println(buf.String())
   283  	}
   284  }
   285  
   286  func copyExpr(expr ast.Expr, line token.Pos) ast.Expr {
   287  	switch expr := expr.(type) {
   288  	case *ast.BasicLit:
   289  		cp := *expr
   290  		cp.ValuePos = 0
   291  		return &cp
   292  	case *ast.BinaryExpr:
   293  		cp := *expr
   294  		cp.X = copyExpr(cp.X, line)
   295  		cp.OpPos = 0
   296  		cp.Y = copyExpr(cp.Y, line)
   297  		return &cp
   298  	case *ast.CallExpr:
   299  		cp := *expr
   300  		cp.Fun = copyExpr(cp.Fun, line)
   301  		cp.Lparen = 0
   302  		for i, v := range cp.Args {
   303  			cp.Args[i] = copyExpr(v, line)
   304  		}
   305  		if cp.Ellipsis != 0 {
   306  			cp.Ellipsis = line
   307  		}
   308  		cp.Rparen = 0
   309  		return &cp
   310  	case *ast.CompositeLit:
   311  		cp := *expr
   312  		cp.Type = copyExpr(cp.Type, line)
   313  		cp.Lbrace = 0
   314  		for i, v := range cp.Elts {
   315  			cp.Elts[i] = copyExpr(v, line)
   316  		}
   317  		cp.Rbrace = 0
   318  		return &cp
   319  	case *ast.Ident:
   320  		cp := *expr
   321  		cp.NamePos = 0
   322  		return &cp
   323  	case *ast.IndexExpr:
   324  		cp := *expr
   325  		cp.X = copyExpr(cp.X, line)
   326  		cp.Lbrack = 0
   327  		cp.Index = copyExpr(cp.Index, line)
   328  		cp.Rbrack = 0
   329  		return &cp
   330  	case *ast.KeyValueExpr:
   331  		cp := *expr
   332  		cp.Key = copyExpr(cp.Key, line)
   333  		cp.Colon = 0
   334  		cp.Value = copyExpr(cp.Value, line)
   335  		return &cp
   336  	case *ast.ParenExpr:
   337  		cp := *expr
   338  		cp.Lparen = 0
   339  		cp.X = copyExpr(cp.X, line)
   340  		cp.Rparen = 0
   341  		return &cp
   342  	case *ast.SelectorExpr:
   343  		cp := *expr
   344  		cp.X = copyExpr(cp.X, line)
   345  		cp.Sel = copyExpr(cp.Sel, line).(*ast.Ident)
   346  		return &cp
   347  	case *ast.SliceExpr:
   348  		cp := *expr
   349  		cp.X = copyExpr(cp.X, line)
   350  		cp.Lbrack = 0
   351  		cp.Low = copyExpr(cp.Low, line)
   352  		cp.High = copyExpr(cp.High, line)
   353  		cp.Max = copyExpr(cp.Max, line)
   354  		cp.Rbrack = 0
   355  		return &cp
   356  	case *ast.StarExpr:
   357  		cp := *expr
   358  		cp.Star = 0
   359  		cp.X = copyExpr(cp.X, line)
   360  		return &cp
   361  	case *ast.TypeAssertExpr:
   362  		cp := *expr
   363  		cp.X = copyExpr(cp.X, line)
   364  		cp.Lparen = 0
   365  		cp.Type = copyExpr(cp.Type, line)
   366  		cp.Rparen = 0
   367  		return &cp
   368  	case *ast.UnaryExpr:
   369  		cp := *expr
   370  		cp.OpPos = 0
   371  		cp.X = copyExpr(cp.X, line)
   372  		return &cp
   373  	case *ast.MapType:
   374  		cp := *expr
   375  		cp.Map = 0
   376  		cp.Key = copyExpr(cp.Key, line)
   377  		cp.Value = copyExpr(cp.Value, line)
   378  		return &cp
   379  	case *ast.ArrayType:
   380  		cp := *expr
   381  		cp.Lbrack = 0
   382  		cp.Len = copyExpr(cp.Len, line)
   383  		cp.Elt = copyExpr(cp.Elt, line)
   384  		return &cp
   385  	case *ast.Ellipsis:
   386  		cp := *expr
   387  		cp.Elt = copyExpr(cp.Elt, line)
   388  		cp.Ellipsis = line
   389  		return &cp
   390  	case *ast.InterfaceType:
   391  		cp := *expr
   392  		cp.Interface = 0
   393  		return &cp
   394  	case *ast.StructType:
   395  		cp := *expr
   396  		cp.Struct = 0
   397  		return &cp
   398  	case *ast.FuncLit:
   399  		return expr
   400  	case *ast.ChanType:
   401  		cp := *expr
   402  		cp.Arrow = 0
   403  		cp.Begin = 0
   404  		cp.Value = copyExpr(cp.Value, line)
   405  		return &cp
   406  	case nil:
   407  		return nil
   408  	default:
   409  		panic(fmt.Sprintf("shouldn't happen: unknown ast.Expr of type %T", expr))
   410  	}
   411  }