gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/tools/checkconst/checkconst.go (about)

     1  // Copyright 2021 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package checkconst checks constant values.
    16  //
    17  // This analyzer supports multiple annotations: checkconst, checkoffset, checksize and checkalign.
    18  // Each of these essentially checks the value of the declared constant (or the #define'ed value in
    19  // the case of an assembly file) against the value seen during analysis. If this does not match,
    20  // an error is emitted with the appropriate value for that constant/offset/size/alignment.
    21  package checkconst
    22  
    23  import (
    24  	"fmt"
    25  	"go/ast"
    26  	"go/token"
    27  	"go/types"
    28  	"io/ioutil"
    29  	"regexp"
    30  	"strconv"
    31  	"strings"
    32  
    33  	"golang.org/x/tools/go/analysis"
    34  )
    35  
    36  var (
    37  	checkconstMagic  = "\\+check(const|align|offset|size)"
    38  	checkconstRegexp = regexp.MustCompile(checkconstMagic)
    39  	constRegexp      = regexp.MustCompile("//\\s+" + checkconstMagic + "\\s+([A-Za-z0-9_\\./]+)\\s+([A-Za-z0-9_\\.]+)")
    40  	defineRegexp     = regexp.MustCompile("#define\\s+[A-Za-z0-9_]+\\s+([A-Za-z0-9_]+\\s*\\+\\s*)*([x0-9]+)\\s+//\\s+" + checkconstMagic + "\\s+([A-Za-z0-9_\\./]+)\\s+([A-Za-z0-9_\\.]+)")
    41  )
    42  
    43  // Analyzer defines the entrypoint.
    44  var Analyzer = &analysis.Analyzer{
    45  	Name: "checkconst",
    46  	Doc:  "validates basic constants",
    47  	Run:  run,
    48  	FactTypes: []analysis.Fact{
    49  		(*Constants)(nil),
    50  	},
    51  }
    52  
    53  // Constants contains all constant values.
    54  type Constants struct {
    55  	Alignments map[string]int64
    56  	Offsets    map[string]int64
    57  	Sizes      map[string]int64
    58  	Values     map[string]string
    59  }
    60  
    61  // AFact implements analysis.Fact.AFact.
    62  func (*Constants) AFact() {}
    63  
    64  // walkObject walks a local object hierarchy.
    65  func (c *Constants) walkObject(pass *analysis.Pass, parents []string, obj types.Object) {
    66  	switch x := obj.(type) {
    67  	case *types.Const:
    68  		name := strings.Join(parents, ".")
    69  		c.Values[name] = x.Val().ExactString()
    70  	case *types.PkgName:
    71  		// Don't walk to other packages.
    72  	case *types.Var:
    73  		// Add information as a field.
    74  		bestEffort(func() {
    75  			name := strings.Join(parents, ".")
    76  			c.Alignments[name] = pass.TypesSizes.Alignof(x.Type())
    77  			c.Sizes[name] = pass.TypesSizes.Sizeof(x.Type())
    78  		})
    79  	case *types.TypeName:
    80  		// Skip if just an alias, or if not underlying type, or if a
    81  		// type parameter. If it is not an alias, then it must be
    82  		// package-local.
    83  		typ := x.Type()
    84  		if typ == nil || typ.Underlying() == nil {
    85  			break
    86  		}
    87  		if _, ok := typ.(*types.TypeParam); ok {
    88  			break
    89  		}
    90  		// Add basic information.
    91  		bestEffort(func() {
    92  			name := strings.Join(parents, ".")
    93  			c.Alignments[name] = pass.TypesSizes.Alignof(typ)
    94  			c.Sizes[name] = pass.TypesSizes.Sizeof(typ)
    95  		})
    96  		// Recurse to fields if this is a definition.
    97  		if structType, ok := typ.Underlying().(*types.Struct); ok {
    98  			fields := make([]*types.Var, 0, structType.NumFields())
    99  			for i := 0; i < structType.NumFields(); i++ {
   100  				fieldObj := structType.Field(i)
   101  				fields = append(fields, fieldObj)
   102  				c.walkObject(pass, append(parents, fieldObj.Name()), fieldObj)
   103  			}
   104  			bestEffort(func() {
   105  				offsets := pass.TypesSizes.Offsetsof(fields)
   106  				for i, field := range fields {
   107  					fieldName := strings.Join(append(parents, field.Name()), ".")
   108  					c.Offsets[fieldName] = offsets[i]
   109  				}
   110  			})
   111  		}
   112  	}
   113  }
   114  
   115  // bestEffort is a panic/recover wrapper. This is used because the tools
   116  // library occasionally panics due to some type parameter use, and there is
   117  // simple or obvious way to detect these conditions. This should only be used
   118  // when absolutely necessary.
   119  func bestEffort(fn func()) {
   120  	defer func() {
   121  		recover()
   122  	}()
   123  	fn()
   124  }
   125  
   126  // walkScope recursively resolves a scope.
   127  func (c *Constants) walkScope(pass *analysis.Pass, parents []string, scope *types.Scope) {
   128  	for _, name := range scope.Names() {
   129  		c.walkObject(pass, append(parents, name), scope.Lookup(name))
   130  	}
   131  }
   132  
   133  // extractFacts finds all local facts.
   134  func extractFacts(pass *analysis.Pass) {
   135  	c := Constants{
   136  		Alignments: make(map[string]int64),
   137  		Offsets:    make(map[string]int64),
   138  		Sizes:      make(map[string]int64),
   139  		Values:     make(map[string]string),
   140  	}
   141  
   142  	// Accumulate all facts.
   143  	c.walkScope(pass, make([]string, 0, 128), pass.Pkg.Scope())
   144  	pass.ExportPackageFact(&c)
   145  }
   146  
   147  // findPackage finds the package by name.
   148  func findPackage(pkg *types.Package, pkgName string) (*types.Package, error) {
   149  	if pkgName == "." || pkgName == "" {
   150  		return pkg, nil
   151  	}
   152  
   153  	// Attempt to resolve with the full path.
   154  	for _, importedPkg := range pkg.Imports() {
   155  		if importedPkg.Path() == pkgName {
   156  			return importedPkg, nil
   157  		}
   158  	}
   159  
   160  	// Attempt to resolve using the short name.
   161  	for _, importedPkg := range pkg.Imports() {
   162  		if importedPkg.Name() == pkgName {
   163  			return importedPkg, nil
   164  		}
   165  	}
   166  
   167  	// Attempt to resolve with the full path from transitive dependencies.
   168  	//
   169  	// This is needed for referencing internal/ packages which we cannot
   170  	// directly import, but can be reached indirectly (e.g., internal/abi
   171  	// is reachable from runtime).
   172  	//
   173  	// N.B. nogo/check.importer only loads facts on direct import, so
   174  	// ImportPackageFact may fail without an explicit import. See hack in
   175  	// nogo/check.Package.
   176  	visited := map[*types.Package]struct{}{}
   177  	var visit func(pkg *types.Package) *types.Package
   178  	visit = func(pkg *types.Package) *types.Package {
   179  		if _, ok := visited[pkg]; ok {
   180  			return nil
   181  		}
   182  		visited[pkg] = struct{}{}
   183  
   184  		if pkg.Path() == pkgName {
   185  			return pkg
   186  		}
   187  
   188  		for _, importedPkg := range pkg.Imports() {
   189  			if found := visit(importedPkg); found != nil {
   190  				return found
   191  			}
   192  		}
   193  
   194  		return nil
   195  	}
   196  	for _, importedPkg := range pkg.Imports() {
   197  		if found := visit(importedPkg); found != nil {
   198  			return found, nil
   199  		}
   200  	}
   201  
   202  	return nil, fmt.Errorf("unable to locate package %q (saw %v)", pkgName, visited)
   203  }
   204  
   205  // matchRegexp performs a regexp match with a sanity check.
   206  func matchRegexp(pass *analysis.Pass, pos func() token.Pos, re *regexp.Regexp, text string) ([]string, bool) {
   207  	m := re.FindStringSubmatch(text)
   208  	if m == nil && checkconstRegexp.FindString(text) != "" {
   209  		pass.Reportf(pos(), "potentially misformed checkconst directives")
   210  	}
   211  	return m, m != nil
   212  }
   213  
   214  // buildExpected builds the expected value.
   215  func buildExpected(pass *analysis.Pass, pos func() token.Pos, factName, pkgName, objName string) (string, bool) {
   216  	// First, resolve the package.
   217  	pkg, err := findPackage(pass.Pkg, pkgName)
   218  	if err != nil {
   219  		pass.Reportf(pos(), "unable to resolve package %q: %v", pkgName, err)
   220  		return "", false
   221  	}
   222  
   223  	// Next, read the appropriate facts.
   224  	var (
   225  		c  Constants
   226  		s  string
   227  		ok bool
   228  	)
   229  	if !pass.ImportPackageFact(pkg, &c) {
   230  		pass.Reportf(pos(), "constant package facts for %q are unavailable", pkg.Path())
   231  		return "", false
   232  	}
   233  
   234  	// Finally, format appropriately.
   235  	switch factName {
   236  	case "const":
   237  		s, ok = c.Values[objName]
   238  	case "align":
   239  		if v, vOk := c.Alignments[objName]; vOk {
   240  			s, ok = fmt.Sprintf("%d", v), true
   241  		}
   242  	case "offset":
   243  		if v, vOk := c.Offsets[objName]; vOk {
   244  			s, ok = fmt.Sprintf("%d", v), true
   245  		}
   246  	case "size":
   247  		if v, vOk := c.Sizes[objName]; vOk {
   248  			s, ok = fmt.Sprintf("%d", v), true
   249  		}
   250  	}
   251  	if !ok {
   252  		pass.Reportf(pos(), "fact of type %s unavailable for %q", factName, objName)
   253  	}
   254  	return s, ok
   255  }
   256  
   257  // checkAssembly checks assembly annotations.
   258  func checkAssembly(pass *analysis.Pass) error {
   259  	for _, filename := range pass.OtherFiles {
   260  		if !strings.HasSuffix(filename, ".s") {
   261  			continue
   262  		}
   263  		content, err := ioutil.ReadFile(filename)
   264  		if err != nil {
   265  			return fmt.Errorf("unable to read assembly file: %w", err)
   266  		}
   267  		// This uses the technique to report issues for assembly files
   268  		// as described by the Go documentation:
   269  		// https://pkg.go.dev/golang.org/x/tools/go/analysis#hdr-Pass
   270  		tf := pass.Fset.AddFile(filename, -1, len(content))
   271  		tf.SetLinesForContent(content)
   272  		lines := strings.Split(string(content), "\n")
   273  		for lineNumber, lineContent := range lines {
   274  			// N.B. This is not evaluated except lazily, since it
   275  			// will generate errors to attempt to grab the position
   276  			// at the end of input. Just avoid it.
   277  			pos := func() token.Pos {
   278  				return tf.LineStart(lineNumber + 1)
   279  			}
   280  			m, ok := matchRegexp(pass, pos, defineRegexp, lineContent)
   281  			if !ok {
   282  				continue // Already reported, if needed.
   283  			}
   284  			newValue, ok := buildExpected(pass, pos, m[3], m[4], m[5])
   285  			if !ok {
   286  				continue // Already reported.
   287  			}
   288  			// Convert our internal string to the given value. This essentially
   289  			// canonicalises the literal string provided in the assembly.
   290  			v, err := strconv.ParseInt(m[2], 10, 64)
   291  			if err == nil && fmt.Sprintf("%v", v) != newValue {
   292  				pass.Reportf(pos(), "got value %v, wanted %q", v, newValue)
   293  				continue
   294  			} else if err != nil && m[2] != newValue {
   295  				pass.Reportf(pos(), "got value %q, wanted %q", m[2], newValue)
   296  				continue
   297  			}
   298  		}
   299  	}
   300  	return nil
   301  }
   302  
   303  // checkConsts walks all package-level const objects.
   304  func checkConsts(pass *analysis.Pass) error {
   305  	for _, f := range pass.Files {
   306  		for _, decl := range f.Decls {
   307  			d, ok := decl.(*ast.GenDecl)
   308  			if !ok || d.Tok != token.CONST {
   309  				continue
   310  			}
   311  			findComments := func(vs *ast.ValueSpec) []*ast.Comment {
   312  				comments := make([]*ast.Comment, 0)
   313  				if d.Doc != nil {
   314  					// Include any formally associated doc from the block.
   315  					comments = append(comments, d.Doc.List...)
   316  				}
   317  				if vs.Doc != nil {
   318  					// Include any formally associated comments from the value.
   319  					comments = append(comments, vs.Doc.List...)
   320  				}
   321  				for _, cg := range f.Comments {
   322  					for _, c := range cg.List {
   323  						// Include any comments that appear on the same line
   324  						// as the value spec itself, which are not doc comments.
   325  						specPosition := pass.Fset.Position(vs.Pos())
   326  						commentPosition := pass.Fset.Position(c.Pos())
   327  						if specPosition.Line == commentPosition.Line && specPosition.Column < commentPosition.Column {
   328  							comments = append(comments, c)
   329  						}
   330  					}
   331  				}
   332  				return comments
   333  			}
   334  			for _, spec := range d.Specs {
   335  				vs := spec.(*ast.ValueSpec)
   336  				var (
   337  					expectedValue string
   338  					expectedSet   bool
   339  				)
   340  				for _, l := range findComments(vs) {
   341  					m, ok := matchRegexp(pass, l.Pos, constRegexp, l.Text)
   342  					if !ok {
   343  						continue // Already reported, if needed.
   344  					}
   345  					newValue, ok := buildExpected(pass, l.Pos, m[1], m[2], m[3])
   346  					if ok {
   347  						if expectedSet && newValue != expectedValue {
   348  							pass.Reportf(l.Pos(), "multiple conflicting values")
   349  							continue
   350  						}
   351  						expectedValue = newValue
   352  						expectedSet = true
   353  					}
   354  				}
   355  				if !expectedSet {
   356  					continue // Nothing was set.
   357  				}
   358  				// Format the expression.
   359  				for _, valueExpr := range vs.Values {
   360  					val := pass.TypesInfo.Types[valueExpr].Value
   361  					s := fmt.Sprint(val)
   362  					if s != expectedValue {
   363  						pass.Reportf(valueExpr.Pos(), "got value %q, wanted %q", s, expectedValue)
   364  						continue
   365  					}
   366  				}
   367  			}
   368  		}
   369  	}
   370  	return nil
   371  }
   372  
   373  func run(pass *analysis.Pass) (any, error) {
   374  	// Extract all local facts. This is done against the compiled objects,
   375  	// rather than the source-level analysis, which is done below.
   376  	extractFacts(pass)
   377  
   378  	// Check the local package.
   379  	if err := checkConsts(pass); err != nil {
   380  		return nil, err
   381  	}
   382  	if err := checkAssembly(pass); err != nil {
   383  		return nil, err
   384  	}
   385  	return nil, nil
   386  }