github.com/amarpal/go-tools@v0.0.0-20240422043104-40142f59f616/go/ast/astutil/util.go (about)

     1  package astutil
     2  
     3  import (
     4  	"fmt"
     5  	"go/ast"
     6  	"go/token"
     7  	"reflect"
     8  	"strings"
     9  
    10  	"golang.org/x/tools/go/ast/astutil"
    11  )
    12  
    13  func IsIdent(expr ast.Expr, ident string) bool {
    14  	id, ok := expr.(*ast.Ident)
    15  	return ok && id.Name == ident
    16  }
    17  
    18  // isBlank returns whether id is the blank identifier "_".
    19  // If id == nil, the answer is false.
    20  func IsBlank(id ast.Expr) bool {
    21  	ident, _ := id.(*ast.Ident)
    22  	return ident != nil && ident.Name == "_"
    23  }
    24  
    25  // Deprecated: use code.IsIntegerLiteral instead.
    26  func IsIntLiteral(expr ast.Expr, literal string) bool {
    27  	lit, ok := expr.(*ast.BasicLit)
    28  	return ok && lit.Kind == token.INT && lit.Value == literal
    29  }
    30  
    31  // Deprecated: use IsIntLiteral instead
    32  func IsZero(expr ast.Expr) bool {
    33  	return IsIntLiteral(expr, "0")
    34  }
    35  
    36  func Preamble(f *ast.File) string {
    37  	cutoff := f.Package
    38  	if f.Doc != nil {
    39  		cutoff = f.Doc.Pos()
    40  	}
    41  	var out []string
    42  	for _, cmt := range f.Comments {
    43  		if cmt.Pos() >= cutoff {
    44  			break
    45  		}
    46  		out = append(out, cmt.Text())
    47  	}
    48  	return strings.Join(out, "\n")
    49  }
    50  
    51  func GroupSpecs(fset *token.FileSet, specs []ast.Spec) [][]ast.Spec {
    52  	if len(specs) == 0 {
    53  		return nil
    54  	}
    55  	groups := make([][]ast.Spec, 1)
    56  	groups[0] = append(groups[0], specs[0])
    57  
    58  	for _, spec := range specs[1:] {
    59  		g := groups[len(groups)-1]
    60  		if fset.PositionFor(spec.Pos(), false).Line-1 !=
    61  			fset.PositionFor(g[len(g)-1].End(), false).Line {
    62  
    63  			groups = append(groups, nil)
    64  		}
    65  
    66  		groups[len(groups)-1] = append(groups[len(groups)-1], spec)
    67  	}
    68  
    69  	return groups
    70  }
    71  
    72  // Unparen returns e with any enclosing parentheses stripped.
    73  func Unparen(e ast.Expr) ast.Expr {
    74  	for {
    75  		p, ok := e.(*ast.ParenExpr)
    76  		if !ok {
    77  			return e
    78  		}
    79  		e = p.X
    80  	}
    81  }
    82  
    83  // CopyExpr creates a deep copy of an expression.
    84  // It doesn't support copying FuncLits and returns ok == false when encountering one.
    85  func CopyExpr(node ast.Expr) (ast.Expr, bool) {
    86  	switch node := node.(type) {
    87  	case *ast.BasicLit:
    88  		cp := *node
    89  		return &cp, true
    90  	case *ast.BinaryExpr:
    91  		cp := *node
    92  		var ok1, ok2 bool
    93  		cp.X, ok1 = CopyExpr(cp.X)
    94  		cp.Y, ok2 = CopyExpr(cp.Y)
    95  		return &cp, ok1 && ok2
    96  	case *ast.CallExpr:
    97  		var ok bool
    98  		cp := *node
    99  		cp.Fun, ok = CopyExpr(cp.Fun)
   100  		if !ok {
   101  			return nil, false
   102  		}
   103  		cp.Args = make([]ast.Expr, len(node.Args))
   104  		for i, v := range node.Args {
   105  			cp.Args[i], ok = CopyExpr(v)
   106  			if !ok {
   107  				return nil, false
   108  			}
   109  		}
   110  		return &cp, true
   111  	case *ast.CompositeLit:
   112  		var ok bool
   113  		cp := *node
   114  		cp.Type, ok = CopyExpr(cp.Type)
   115  		if !ok {
   116  			return nil, false
   117  		}
   118  		cp.Elts = make([]ast.Expr, len(node.Elts))
   119  		for i, v := range node.Elts {
   120  			cp.Elts[i], ok = CopyExpr(v)
   121  			if !ok {
   122  				return nil, false
   123  			}
   124  		}
   125  		return &cp, true
   126  	case *ast.Ident:
   127  		cp := *node
   128  		return &cp, true
   129  	case *ast.IndexExpr:
   130  		var ok1, ok2 bool
   131  		cp := *node
   132  		cp.X, ok1 = CopyExpr(cp.X)
   133  		cp.Index, ok2 = CopyExpr(cp.Index)
   134  		return &cp, ok1 && ok2
   135  	case *ast.IndexListExpr:
   136  		var ok bool
   137  		cp := *node
   138  		cp.X, ok = CopyExpr(cp.X)
   139  		if !ok {
   140  			return nil, false
   141  		}
   142  		for i, v := range node.Indices {
   143  			cp.Indices[i], ok = CopyExpr(v)
   144  			if !ok {
   145  				return nil, false
   146  			}
   147  		}
   148  		return &cp, true
   149  	case *ast.KeyValueExpr:
   150  		var ok1, ok2 bool
   151  		cp := *node
   152  		cp.Key, ok1 = CopyExpr(cp.Key)
   153  		cp.Value, ok2 = CopyExpr(cp.Value)
   154  		return &cp, ok1 && ok2
   155  	case *ast.ParenExpr:
   156  		var ok bool
   157  		cp := *node
   158  		cp.X, ok = CopyExpr(cp.X)
   159  		return &cp, ok
   160  	case *ast.SelectorExpr:
   161  		var ok bool
   162  		cp := *node
   163  		cp.X, ok = CopyExpr(cp.X)
   164  		if !ok {
   165  			return nil, false
   166  		}
   167  		sel, ok := CopyExpr(cp.Sel)
   168  		if !ok {
   169  			// this is impossible
   170  			return nil, false
   171  		}
   172  		cp.Sel = sel.(*ast.Ident)
   173  		return &cp, true
   174  	case *ast.SliceExpr:
   175  		var ok1, ok2, ok3, ok4 bool
   176  		cp := *node
   177  		cp.X, ok1 = CopyExpr(cp.X)
   178  		cp.Low, ok2 = CopyExpr(cp.Low)
   179  		cp.High, ok3 = CopyExpr(cp.High)
   180  		cp.Max, ok4 = CopyExpr(cp.Max)
   181  		return &cp, ok1 && ok2 && ok3 && ok4
   182  	case *ast.StarExpr:
   183  		var ok bool
   184  		cp := *node
   185  		cp.X, ok = CopyExpr(cp.X)
   186  		return &cp, ok
   187  	case *ast.TypeAssertExpr:
   188  		var ok1, ok2 bool
   189  		cp := *node
   190  		cp.X, ok1 = CopyExpr(cp.X)
   191  		cp.Type, ok2 = CopyExpr(cp.Type)
   192  		return &cp, ok1 && ok2
   193  	case *ast.UnaryExpr:
   194  		var ok bool
   195  		cp := *node
   196  		cp.X, ok = CopyExpr(cp.X)
   197  		return &cp, ok
   198  	case *ast.MapType:
   199  		var ok1, ok2 bool
   200  		cp := *node
   201  		cp.Key, ok1 = CopyExpr(cp.Key)
   202  		cp.Value, ok2 = CopyExpr(cp.Value)
   203  		return &cp, ok1 && ok2
   204  	case *ast.ArrayType:
   205  		var ok1, ok2 bool
   206  		cp := *node
   207  		cp.Len, ok1 = CopyExpr(cp.Len)
   208  		cp.Elt, ok2 = CopyExpr(cp.Elt)
   209  		return &cp, ok1 && ok2
   210  	case *ast.Ellipsis:
   211  		var ok bool
   212  		cp := *node
   213  		cp.Elt, ok = CopyExpr(cp.Elt)
   214  		return &cp, ok
   215  	case *ast.InterfaceType:
   216  		cp := *node
   217  		return &cp, true
   218  	case *ast.StructType:
   219  		cp := *node
   220  		return &cp, true
   221  	case *ast.FuncLit, *ast.FuncType:
   222  		// TODO(dh): implement copying of function literals and types.
   223  		return nil, false
   224  	case *ast.ChanType:
   225  		var ok bool
   226  		cp := *node
   227  		cp.Value, ok = CopyExpr(cp.Value)
   228  		return &cp, ok
   229  	case nil:
   230  		return nil, true
   231  	default:
   232  		panic(fmt.Sprintf("unreachable: %T", node))
   233  	}
   234  }
   235  
   236  func Equal(a, b ast.Node) bool {
   237  	if a == b {
   238  		return true
   239  	}
   240  	if a == nil || b == nil {
   241  		return false
   242  	}
   243  	if reflect.TypeOf(a) != reflect.TypeOf(b) {
   244  		return false
   245  	}
   246  
   247  	switch a := a.(type) {
   248  	case *ast.BasicLit:
   249  		b := b.(*ast.BasicLit)
   250  		return a.Kind == b.Kind && a.Value == b.Value
   251  	case *ast.BinaryExpr:
   252  		b := b.(*ast.BinaryExpr)
   253  		return Equal(a.X, b.X) && a.Op == b.Op && Equal(a.Y, b.Y)
   254  	case *ast.CallExpr:
   255  		b := b.(*ast.CallExpr)
   256  		if len(a.Args) != len(b.Args) {
   257  			return false
   258  		}
   259  		for i, arg := range a.Args {
   260  			if !Equal(arg, b.Args[i]) {
   261  				return false
   262  			}
   263  		}
   264  		return Equal(a.Fun, b.Fun) &&
   265  			(a.Ellipsis == token.NoPos && b.Ellipsis == token.NoPos || a.Ellipsis != token.NoPos && b.Ellipsis != token.NoPos)
   266  	case *ast.CompositeLit:
   267  		b := b.(*ast.CompositeLit)
   268  		if len(a.Elts) != len(b.Elts) {
   269  			return false
   270  		}
   271  		for i, elt := range b.Elts {
   272  			if !Equal(elt, b.Elts[i]) {
   273  				return false
   274  			}
   275  		}
   276  		return Equal(a.Type, b.Type) && a.Incomplete == b.Incomplete
   277  	case *ast.Ident:
   278  		b := b.(*ast.Ident)
   279  		return a.Name == b.Name
   280  	case *ast.IndexExpr:
   281  		b := b.(*ast.IndexExpr)
   282  		return Equal(a.X, b.X) && Equal(a.Index, b.Index)
   283  	case *ast.IndexListExpr:
   284  		b := b.(*ast.IndexListExpr)
   285  		if len(a.Indices) != len(b.Indices) {
   286  			return false
   287  		}
   288  		for i, v := range a.Indices {
   289  			if !Equal(v, b.Indices[i]) {
   290  				return false
   291  			}
   292  		}
   293  		return Equal(a.X, b.X)
   294  	case *ast.KeyValueExpr:
   295  		b := b.(*ast.KeyValueExpr)
   296  		return Equal(a.Key, b.Key) && Equal(a.Value, b.Value)
   297  	case *ast.ParenExpr:
   298  		b := b.(*ast.ParenExpr)
   299  		return Equal(a.X, b.X)
   300  	case *ast.SelectorExpr:
   301  		b := b.(*ast.SelectorExpr)
   302  		return Equal(a.X, b.X) && Equal(a.Sel, b.Sel)
   303  	case *ast.SliceExpr:
   304  		b := b.(*ast.SliceExpr)
   305  		return Equal(a.X, b.X) && Equal(a.Low, b.Low) && Equal(a.High, b.High) && Equal(a.Max, b.Max) && a.Slice3 == b.Slice3
   306  	case *ast.StarExpr:
   307  		b := b.(*ast.StarExpr)
   308  		return Equal(a.X, b.X)
   309  	case *ast.TypeAssertExpr:
   310  		b := b.(*ast.TypeAssertExpr)
   311  		return Equal(a.X, b.X) && Equal(a.Type, b.Type)
   312  	case *ast.UnaryExpr:
   313  		b := b.(*ast.UnaryExpr)
   314  		return a.Op == b.Op && Equal(a.X, b.X)
   315  	case *ast.MapType:
   316  		b := b.(*ast.MapType)
   317  		return Equal(a.Key, b.Key) && Equal(a.Value, b.Value)
   318  	case *ast.ArrayType:
   319  		b := b.(*ast.ArrayType)
   320  		return Equal(a.Len, b.Len) && Equal(a.Elt, b.Elt)
   321  	case *ast.Ellipsis:
   322  		b := b.(*ast.Ellipsis)
   323  		return Equal(a.Elt, b.Elt)
   324  	case *ast.InterfaceType:
   325  		b := b.(*ast.InterfaceType)
   326  		return a.Incomplete == b.Incomplete && Equal(a.Methods, b.Methods)
   327  	case *ast.StructType:
   328  		b := b.(*ast.StructType)
   329  		return a.Incomplete == b.Incomplete && Equal(a.Fields, b.Fields)
   330  	case *ast.FuncLit:
   331  		// TODO(dh): support function literals
   332  		return false
   333  	case *ast.ChanType:
   334  		b := b.(*ast.ChanType)
   335  		return a.Dir == b.Dir && (a.Arrow == token.NoPos && b.Arrow == token.NoPos || a.Arrow != token.NoPos && b.Arrow != token.NoPos)
   336  	case *ast.FieldList:
   337  		b := b.(*ast.FieldList)
   338  		if len(a.List) != len(b.List) {
   339  			return false
   340  		}
   341  		for i, fieldA := range a.List {
   342  			if !Equal(fieldA, b.List[i]) {
   343  				return false
   344  			}
   345  		}
   346  		return true
   347  	case *ast.Field:
   348  		b := b.(*ast.Field)
   349  		if len(a.Names) != len(b.Names) {
   350  			return false
   351  		}
   352  		for j, name := range a.Names {
   353  			if !Equal(name, b.Names[j]) {
   354  				return false
   355  			}
   356  		}
   357  		if !Equal(a.Type, b.Type) || !Equal(a.Tag, b.Tag) {
   358  			return false
   359  		}
   360  		return true
   361  	default:
   362  		panic(fmt.Sprintf("unreachable: %T", a))
   363  	}
   364  }
   365  
   366  func NegateDeMorgan(expr ast.Expr, recursive bool) ast.Expr {
   367  	switch expr := expr.(type) {
   368  	case *ast.BinaryExpr:
   369  		var out ast.BinaryExpr
   370  		switch expr.Op {
   371  		case token.EQL:
   372  			out.X = expr.X
   373  			out.Op = token.NEQ
   374  			out.Y = expr.Y
   375  		case token.LSS:
   376  			out.X = expr.X
   377  			out.Op = token.GEQ
   378  			out.Y = expr.Y
   379  		case token.GTR:
   380  			out.X = expr.X
   381  			out.Op = token.LEQ
   382  			out.Y = expr.Y
   383  		case token.NEQ:
   384  			out.X = expr.X
   385  			out.Op = token.EQL
   386  			out.Y = expr.Y
   387  		case token.LEQ:
   388  			out.X = expr.X
   389  			out.Op = token.GTR
   390  			out.Y = expr.Y
   391  		case token.GEQ:
   392  			out.X = expr.X
   393  			out.Op = token.LSS
   394  			out.Y = expr.Y
   395  
   396  		case token.LAND:
   397  			out.X = NegateDeMorgan(expr.X, recursive)
   398  			out.Op = token.LOR
   399  			out.Y = NegateDeMorgan(expr.Y, recursive)
   400  		case token.LOR:
   401  			out.X = NegateDeMorgan(expr.X, recursive)
   402  			out.Op = token.LAND
   403  			out.Y = NegateDeMorgan(expr.Y, recursive)
   404  		}
   405  		return &out
   406  
   407  	case *ast.ParenExpr:
   408  		if recursive {
   409  			return &ast.ParenExpr{
   410  				X: NegateDeMorgan(expr.X, recursive),
   411  			}
   412  		} else {
   413  			return &ast.UnaryExpr{
   414  				Op: token.NOT,
   415  				X:  expr,
   416  			}
   417  		}
   418  
   419  	case *ast.UnaryExpr:
   420  		if expr.Op == token.NOT {
   421  			return expr.X
   422  		} else {
   423  			return &ast.UnaryExpr{
   424  				Op: token.NOT,
   425  				X:  expr,
   426  			}
   427  		}
   428  
   429  	default:
   430  		return &ast.UnaryExpr{
   431  			Op: token.NOT,
   432  			X:  expr,
   433  		}
   434  	}
   435  }
   436  
   437  func SimplifyParentheses(node ast.Expr) ast.Expr {
   438  	var changed bool
   439  	// XXX accept list of ops to operate on
   440  	// XXX copy AST node, don't modify in place
   441  	post := func(c *astutil.Cursor) bool {
   442  		out := c.Node()
   443  		if paren, ok := c.Node().(*ast.ParenExpr); ok {
   444  			out = paren.X
   445  		}
   446  
   447  		if binop, ok := out.(*ast.BinaryExpr); ok {
   448  			if right, ok := binop.Y.(*ast.BinaryExpr); ok && binop.Op == right.Op {
   449  				// XXX also check that Op is associative
   450  
   451  				root := binop
   452  				pivot := root.Y.(*ast.BinaryExpr)
   453  				root.Y = pivot.X
   454  				pivot.X = root
   455  				root = pivot
   456  				out = root
   457  			}
   458  		}
   459  
   460  		if out != c.Node() {
   461  			changed = true
   462  			c.Replace(out)
   463  		}
   464  		return true
   465  	}
   466  
   467  	for changed = true; changed; {
   468  		changed = false
   469  		node = astutil.Apply(node, nil, post).(ast.Expr)
   470  	}
   471  
   472  	return node
   473  }