gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/tools/go_generics/globals/globals_visitor.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package globals provides an AST visitor that calls the visit function for all
    16  // global identifiers.
    17  package globals
    18  
    19  import (
    20  	"fmt"
    21  
    22  	"go/ast"
    23  	"go/token"
    24  	"path/filepath"
    25  	"strconv"
    26  )
    27  
// globalsVisitor holds the state used while traversing the nodes of a file in
// search of globals.
//
// The visitor does two passes on the global declarations: the first one adds
// all globals to the global scope (since Go allows references to globals that
// haven't been declared yet), and the second one calls f() for the definition
// and uses of globals found in the first pass.
//
// The implementation correctly handles cases when globals are aliased (that
// is, shadowed) by locals; in such cases, f() is not called.
type globalsVisitor struct {
	// file is the file whose nodes are being visited.
	file *ast.File

	// fset is the file set the file being visited belongs to. It is used to
	// turn token positions into readable locations when an unexpected node
	// is reported.
	fset *token.FileSet

	// f is the visit function to be called when a global symbol is reached.
	// It receives the identifier and the kind of the symbol it refers to. It
	// is also called, with KindTag, for every struct field tag encountered.
	f func(*ast.Ident, SymKind)

	// scope is the current (innermost) scope as nodes are visited.
	scope *scope

	// processAnon indicates whether we should process anonymous struct fields.
	// When set, the selected name of a selector expression (the Sel in X.Sel)
	// is visited as an ordinary expression as well.
	// It does not perform strict checking on parameter types that share the same name
	// as the global type and therefore will rename them as well.
	processAnon bool
}
    56  
    57  // unexpected is called when an unexpected node appears in the AST. It dumps
    58  // the location of the associated token and panics because this should only
    59  // happen when there is a bug in the traversal code.
    60  func (v *globalsVisitor) unexpected(p token.Pos) {
    61  	panic(fmt.Sprintf("Unable to parse at %v", v.fset.Position(p)))
    62  }
    63  
    64  // pushScope creates a new scope and pushes it to the top of the scope stack.
    65  func (v *globalsVisitor) pushScope() {
    66  	v.scope = newScope(v.scope)
    67  }
    68  
    69  // popScope removes the scope created by the last call to pushScope.
    70  func (v *globalsVisitor) popScope() {
    71  	v.scope = v.scope.outer
    72  }
    73  
    74  // visitType is called when an expression is known to be a type, for example,
    75  // on the first argument of make(). It visits all children nodes and reports
    76  // any globals.
    77  func (v *globalsVisitor) visitType(ge ast.Expr) {
    78  	switch e := ge.(type) {
    79  	case *ast.Ident:
    80  		if s := v.scope.deepLookup(e.Name); s != nil && s.scope.isGlobal() {
    81  			v.f(e, s.kind)
    82  		}
    83  
    84  	case *ast.SelectorExpr:
    85  		id := GetIdent(e.X)
    86  		if id == nil {
    87  			v.unexpected(e.X.Pos())
    88  		}
    89  
    90  	case *ast.StarExpr:
    91  		v.visitType(e.X)
    92  	case *ast.ParenExpr:
    93  		v.visitType(e.X)
    94  	case *ast.ChanType:
    95  		v.visitType(e.Value)
    96  	case *ast.Ellipsis:
    97  		v.visitType(e.Elt)
    98  	case *ast.ArrayType:
    99  		v.visitExpr(e.Len)
   100  		v.visitType(e.Elt)
   101  	case *ast.MapType:
   102  		v.visitType(e.Key)
   103  		v.visitType(e.Value)
   104  	case *ast.StructType:
   105  		v.visitFields(e.Fields, KindUnknown)
   106  	case *ast.FuncType:
   107  		v.visitFields(e.Params, KindUnknown)
   108  		v.visitFields(e.Results, KindUnknown)
   109  	case *ast.InterfaceType:
   110  		v.visitFields(e.Methods, KindUnknown)
   111  	default:
   112  		v.unexpected(ge.Pos())
   113  	}
   114  }
   115  
   116  // visitFields visits all fields, and add symbols if kind isn't KindUnknown.
   117  func (v *globalsVisitor) visitFields(l *ast.FieldList, kind SymKind) {
   118  	if l == nil {
   119  		return
   120  	}
   121  
   122  	for _, f := range l.List {
   123  		if kind != KindUnknown {
   124  			for _, n := range f.Names {
   125  				v.scope.add(n.Name, kind, n.Pos())
   126  			}
   127  		}
   128  		v.visitType(f.Type)
   129  		if f.Tag != nil {
   130  			tag := ast.NewIdent(f.Tag.Value)
   131  			v.f(tag, KindTag)
   132  			// Replace the tag if updated.
   133  			if tag.Name != f.Tag.Value {
   134  				f.Tag.Value = tag.Name
   135  			}
   136  		}
   137  	}
   138  }
   139  
   140  // visitGenDecl is called when a generic declaration is encountered, for example,
   141  // on variable, constant and type declarations. It adds all newly defined
   142  // symbols to the current scope and reports them if the current scope is the
   143  // global one.
   144  func (v *globalsVisitor) visitGenDecl(d *ast.GenDecl) {
   145  	switch d.Tok {
   146  	case token.IMPORT:
   147  	case token.TYPE:
   148  		for _, gs := range d.Specs {
   149  			s := gs.(*ast.TypeSpec)
   150  			v.scope.add(s.Name.Name, KindType, s.Name.Pos())
   151  			if v.scope.isGlobal() {
   152  				v.f(s.Name, KindType)
   153  			}
   154  			v.visitType(s.Type)
   155  		}
   156  	case token.CONST, token.VAR:
   157  		kind := KindConst
   158  		if d.Tok == token.VAR {
   159  			kind = KindVar
   160  		}
   161  
   162  		for _, gs := range d.Specs {
   163  			s := gs.(*ast.ValueSpec)
   164  			if s.Type != nil {
   165  				v.visitType(s.Type)
   166  			}
   167  
   168  			for _, e := range s.Values {
   169  				v.visitExpr(e)
   170  			}
   171  
   172  			for _, n := range s.Names {
   173  				if v.scope.isGlobal() {
   174  					v.f(n, kind)
   175  				}
   176  				v.scope.add(n.Name, kind, n.Pos())
   177  			}
   178  		}
   179  	default:
   180  		v.unexpected(d.Pos())
   181  	}
   182  }
   183  
   184  // isViableType determines if the given expression is a viable type expression,
   185  // that is, if it could be interpreted as a type, for example, sync.Mutex,
   186  // myType, func(int)int, as opposed to -1, 2 * 2, a + b, etc.
   187  func (v *globalsVisitor) isViableType(expr ast.Expr) bool {
   188  	switch e := expr.(type) {
   189  	case *ast.Ident:
   190  		// This covers the plain identifier case. When we see it, we
   191  		// have to check if it resolves to a type; if the symbol is not
   192  		// known, we'll claim it's viable as a type.
   193  		s := v.scope.deepLookup(e.Name)
   194  		return s == nil || s.kind == KindType
   195  
   196  	case *ast.ChanType, *ast.ArrayType, *ast.MapType, *ast.StructType, *ast.FuncType, *ast.InterfaceType, *ast.Ellipsis:
   197  		// This covers the following cases:
   198  		// 1. ChanType:
   199  		//   chan T
   200  		//   <-chan T
   201  		//   chan<- T
   202  		// 2. ArrayType:
   203  		//   [Expr]T
   204  		// 3. MapType:
   205  		//   map[T]U
   206  		// 4. StructType:
   207  		//   struct { Fields }
   208  		// 5. FuncType:
   209  		//   func(Fields)Returns
   210  		// 6. Interface:
   211  		//   interface { Fields }
   212  		// 7. Ellipsis:
   213  		//   ...T
   214  		return true
   215  
   216  	case *ast.SelectorExpr:
   217  		// The only case in which an expression involving a selector can
   218  		// be a type is if it has the following form X.T, where X is an
   219  		// import, and T is a type exported by X.
   220  		//
   221  		// There's no way to know whether T is a type because we don't
   222  		// parse imports. So we just claim that this is a viable type;
   223  		// it doesn't affect the general result because we don't visit
   224  		// imported symbols.
   225  		id := GetIdent(e.X)
   226  		if id == nil {
   227  			return false
   228  		}
   229  
   230  		s := v.scope.deepLookup(id.Name)
   231  		return s != nil && s.kind == KindImport
   232  
   233  	case *ast.StarExpr:
   234  		// This covers the *T case. The expression is a viable type if
   235  		// T is.
   236  		return v.isViableType(e.X)
   237  
   238  	case *ast.ParenExpr:
   239  		// This covers the (T) case. The expression is a viable type if
   240  		// T is.
   241  		return v.isViableType(e.X)
   242  
   243  	default:
   244  		return false
   245  	}
   246  }
   247  
   248  // visitCallExpr visits a "call expression" which can be either a
   249  // function/method call (e.g., f(), pkg.f(), obj.f(), etc.) call or a type
   250  // conversion (e.g., int32(1), (*sync.Mutex)(ptr), etc.).
   251  func (v *globalsVisitor) visitCallExpr(e *ast.CallExpr) {
   252  	if v.isViableType(e.Fun) {
   253  		v.visitType(e.Fun)
   254  	} else {
   255  		v.visitExpr(e.Fun)
   256  	}
   257  
   258  	// If the function being called is new or make, the first argument is
   259  	// a type, so it needs to be visited as such.
   260  	first := 0
   261  	if id := GetIdent(e.Fun); id != nil && (id.Name == "make" || id.Name == "new") {
   262  		if len(e.Args) > 0 {
   263  			v.visitType(e.Args[0])
   264  		}
   265  		first = 1
   266  	}
   267  
   268  	for i := first; i < len(e.Args); i++ {
   269  		v.visitExpr(e.Args[i])
   270  	}
   271  }
   272  
   273  // visitExpr visits all nodes of an expression, and reports any globals that it
   274  // finds.
   275  func (v *globalsVisitor) visitExpr(ge ast.Expr) {
   276  	switch e := ge.(type) {
   277  	case nil:
   278  	case *ast.Ident:
   279  		if s := v.scope.deepLookup(e.Name); s != nil && s.scope.isGlobal() {
   280  			v.f(e, s.kind)
   281  		}
   282  
   283  	case *ast.BasicLit:
   284  	case *ast.CompositeLit:
   285  		v.visitType(e.Type)
   286  		for _, ne := range e.Elts {
   287  			v.visitExpr(ne)
   288  		}
   289  	case *ast.FuncLit:
   290  		v.pushScope()
   291  		v.visitFields(e.Type.Params, KindParameter)
   292  		v.visitFields(e.Type.Results, KindResult)
   293  		v.visitBlockStmt(e.Body)
   294  		v.popScope()
   295  
   296  	case *ast.BinaryExpr:
   297  		v.visitExpr(e.X)
   298  		v.visitExpr(e.Y)
   299  
   300  	case *ast.CallExpr:
   301  		v.visitCallExpr(e)
   302  
   303  	case *ast.IndexExpr:
   304  		v.visitExpr(e.X)
   305  		v.visitExpr(e.Index)
   306  
   307  	case *ast.KeyValueExpr:
   308  		v.visitExpr(e.Value)
   309  
   310  	case *ast.ParenExpr:
   311  		v.visitExpr(e.X)
   312  
   313  	case *ast.SelectorExpr:
   314  		v.visitExpr(e.X)
   315  		if v.processAnon {
   316  			v.visitExpr(e.Sel)
   317  		}
   318  
   319  	case *ast.SliceExpr:
   320  		v.visitExpr(e.X)
   321  		v.visitExpr(e.Low)
   322  		v.visitExpr(e.High)
   323  		v.visitExpr(e.Max)
   324  
   325  	case *ast.StarExpr:
   326  		v.visitExpr(e.X)
   327  
   328  	case *ast.TypeAssertExpr:
   329  		v.visitExpr(e.X)
   330  		if e.Type != nil {
   331  			v.visitType(e.Type)
   332  		}
   333  
   334  	case *ast.UnaryExpr:
   335  		v.visitExpr(e.X)
   336  
   337  	default:
   338  		v.unexpected(ge.Pos())
   339  	}
   340  }
   341  
   342  // GetIdent returns the identifier associated with the given expression by
   343  // removing parentheses if needed.
   344  func GetIdent(expr ast.Expr) *ast.Ident {
   345  	switch e := expr.(type) {
   346  	case *ast.Ident:
   347  		return e
   348  	case *ast.ParenExpr:
   349  		return GetIdent(e.X)
   350  	default:
   351  		return nil
   352  	}
   353  }
   354  
   355  // visitStmt visits all nodes of a statement, and reports any globals that it
   356  // finds. It also adds to the current scope new symbols defined/declared.
   357  func (v *globalsVisitor) visitStmt(gs ast.Stmt) {
   358  	switch s := gs.(type) {
   359  	case nil, *ast.BranchStmt, *ast.EmptyStmt:
   360  	case *ast.AssignStmt:
   361  		for _, e := range s.Rhs {
   362  			v.visitExpr(e)
   363  		}
   364  
   365  		// We visit the LHS after the RHS because the symbols we'll
   366  		// potentially add to the table aren't meant to be visible to
   367  		// the RHS.
   368  		for _, e := range s.Lhs {
   369  			if s.Tok == token.DEFINE {
   370  				if n := GetIdent(e); n != nil {
   371  					v.scope.add(n.Name, KindVar, n.Pos())
   372  				}
   373  			}
   374  			v.visitExpr(e)
   375  		}
   376  
   377  	case *ast.BlockStmt:
   378  		v.visitBlockStmt(s)
   379  
   380  	case *ast.DeclStmt:
   381  		v.visitGenDecl(s.Decl.(*ast.GenDecl))
   382  
   383  	case *ast.DeferStmt:
   384  		v.visitCallExpr(s.Call)
   385  
   386  	case *ast.ExprStmt:
   387  		v.visitExpr(s.X)
   388  
   389  	case *ast.ForStmt:
   390  		v.pushScope()
   391  		v.visitStmt(s.Init)
   392  		v.visitExpr(s.Cond)
   393  		v.visitStmt(s.Post)
   394  		v.visitBlockStmt(s.Body)
   395  		v.popScope()
   396  
   397  	case *ast.GoStmt:
   398  		v.visitCallExpr(s.Call)
   399  
   400  	case *ast.IfStmt:
   401  		v.pushScope()
   402  		v.visitStmt(s.Init)
   403  		v.visitExpr(s.Cond)
   404  		v.visitBlockStmt(s.Body)
   405  		v.visitStmt(s.Else)
   406  		v.popScope()
   407  
   408  	case *ast.IncDecStmt:
   409  		v.visitExpr(s.X)
   410  
   411  	case *ast.LabeledStmt:
   412  		v.visitStmt(s.Stmt)
   413  
   414  	case *ast.RangeStmt:
   415  		v.pushScope()
   416  		v.visitExpr(s.X)
   417  		if s.Tok == token.DEFINE {
   418  			if n := GetIdent(s.Key); n != nil {
   419  				v.scope.add(n.Name, KindVar, n.Pos())
   420  			}
   421  
   422  			if n := GetIdent(s.Value); n != nil {
   423  				v.scope.add(n.Name, KindVar, n.Pos())
   424  			}
   425  		}
   426  		v.visitExpr(s.Key)
   427  		v.visitExpr(s.Value)
   428  		v.visitBlockStmt(s.Body)
   429  		v.popScope()
   430  
   431  	case *ast.ReturnStmt:
   432  		for _, r := range s.Results {
   433  			v.visitExpr(r)
   434  		}
   435  
   436  	case *ast.SelectStmt:
   437  		for _, ns := range s.Body.List {
   438  			c := ns.(*ast.CommClause)
   439  
   440  			v.pushScope()
   441  			v.visitStmt(c.Comm)
   442  			for _, bs := range c.Body {
   443  				v.visitStmt(bs)
   444  			}
   445  			v.popScope()
   446  		}
   447  
   448  	case *ast.SendStmt:
   449  		v.visitExpr(s.Chan)
   450  		v.visitExpr(s.Value)
   451  
   452  	case *ast.SwitchStmt:
   453  		v.pushScope()
   454  		v.visitStmt(s.Init)
   455  		v.visitExpr(s.Tag)
   456  		for _, ns := range s.Body.List {
   457  			c := ns.(*ast.CaseClause)
   458  			v.pushScope()
   459  			for _, ce := range c.List {
   460  				v.visitExpr(ce)
   461  			}
   462  			for _, bs := range c.Body {
   463  				v.visitStmt(bs)
   464  			}
   465  			v.popScope()
   466  		}
   467  		v.popScope()
   468  
   469  	case *ast.TypeSwitchStmt:
   470  		v.pushScope()
   471  		v.visitStmt(s.Init)
   472  		v.visitStmt(s.Assign)
   473  		for _, ns := range s.Body.List {
   474  			c := ns.(*ast.CaseClause)
   475  			v.pushScope()
   476  			for _, ce := range c.List {
   477  				v.visitType(ce)
   478  			}
   479  			for _, bs := range c.Body {
   480  				v.visitStmt(bs)
   481  			}
   482  			v.popScope()
   483  		}
   484  		v.popScope()
   485  
   486  	default:
   487  		v.unexpected(gs.Pos())
   488  	}
   489  }
   490  
   491  // visitBlockStmt visits all statements in the block, adding symbols to a newly
   492  // created scope.
   493  func (v *globalsVisitor) visitBlockStmt(s *ast.BlockStmt) {
   494  	v.pushScope()
   495  	for _, c := range s.List {
   496  		v.visitStmt(c)
   497  	}
   498  	v.popScope()
   499  }
   500  
   501  // visitFuncDecl is called when a function or method declaration is encountered.
   502  // it creates a new scope for the function [optional] receiver, parameters and
   503  // results, and visits all children nodes.
   504  func (v *globalsVisitor) visitFuncDecl(d *ast.FuncDecl) {
   505  	// We don't report methods.
   506  	if d.Recv == nil {
   507  		v.f(d.Name, KindFunction)
   508  	}
   509  
   510  	v.pushScope()
   511  	v.visitFields(d.Recv, KindReceiver)
   512  	v.visitFields(d.Type.Params, KindParameter)
   513  	v.visitFields(d.Type.Results, KindResult)
   514  	if d.Body != nil {
   515  		v.visitBlockStmt(d.Body)
   516  	}
   517  	v.popScope()
   518  }
   519  
   520  // globalsFromDecl is called in the first, and adds symbols to global scope.
   521  func (v *globalsVisitor) globalsFromGenDecl(d *ast.GenDecl) {
   522  	switch d.Tok {
   523  	case token.IMPORT:
   524  		for _, gs := range d.Specs {
   525  			s := gs.(*ast.ImportSpec)
   526  			if s.Name == nil {
   527  				str, _ := strconv.Unquote(s.Path.Value)
   528  				v.scope.add(filepath.Base(str), KindImport, s.Path.Pos())
   529  			} else if s.Name.Name != "_" {
   530  				v.scope.add(s.Name.Name, KindImport, s.Name.Pos())
   531  			}
   532  		}
   533  	case token.TYPE:
   534  		for _, gs := range d.Specs {
   535  			s := gs.(*ast.TypeSpec)
   536  			v.scope.add(s.Name.Name, KindType, s.Name.Pos())
   537  		}
   538  	case token.CONST, token.VAR:
   539  		kind := KindConst
   540  		if d.Tok == token.VAR {
   541  			kind = KindVar
   542  		}
   543  
   544  		for _, s := range d.Specs {
   545  			for _, n := range s.(*ast.ValueSpec).Names {
   546  				v.scope.add(n.Name, kind, n.Pos())
   547  			}
   548  		}
   549  	default:
   550  		v.unexpected(d.Pos())
   551  	}
   552  }
   553  
   554  // visit implements the visiting of globals. It does performs the two passes
   555  // described in the description of the globalsVisitor struct.
   556  func (v *globalsVisitor) visit() {
   557  	// Gather all symbols in the global scope. This excludes methods.
   558  	v.pushScope()
   559  	for _, gd := range v.file.Decls {
   560  		switch d := gd.(type) {
   561  		case *ast.GenDecl:
   562  			v.globalsFromGenDecl(d)
   563  		case *ast.FuncDecl:
   564  			if d.Recv == nil {
   565  				v.scope.add(d.Name.Name, KindFunction, d.Name.Pos())
   566  			}
   567  		default:
   568  			v.unexpected(gd.Pos())
   569  		}
   570  	}
   571  
   572  	// Go through the contents of the declarations.
   573  	for _, gd := range v.file.Decls {
   574  		switch d := gd.(type) {
   575  		case *ast.GenDecl:
   576  			v.visitGenDecl(d)
   577  		case *ast.FuncDecl:
   578  			v.visitFuncDecl(d)
   579  		}
   580  	}
   581  }
   582  
   583  // Visit traverses the provided AST and calls f() for each identifier that
   584  // refers to global names. The global name must be defined in the file itself.
   585  //
   586  // The function f() is allowed to modify the identifier, for example, to rename
   587  // uses of global references.
   588  func Visit(fset *token.FileSet, file *ast.File, f func(*ast.Ident, SymKind), processAnon bool) {
   589  	v := globalsVisitor{
   590  		fset:        fset,
   591  		file:        file,
   592  		f:           f,
   593  		processAnon: processAnon,
   594  	}
   595  
   596  	v.visit()
   597  }