github.com/amarpal/go-tools@v0.0.0-20240422043104-40142f59f616/staticcheck/sa4031/sa4031.go (about)

     1  package sa4031
     2  
     3  import (
     4  	"fmt"
     5  	"go/ast"
     6  	"go/token"
     7  	"go/types"
     8  	"sort"
     9  
    10  	"github.com/amarpal/go-tools/analysis/code"
    11  	"github.com/amarpal/go-tools/analysis/lint"
    12  	"github.com/amarpal/go-tools/analysis/report"
    13  	"github.com/amarpal/go-tools/go/ir"
    14  	"github.com/amarpal/go-tools/internal/passes/buildir"
    15  	"github.com/amarpal/go-tools/pattern"
    16  	"github.com/amarpal/go-tools/staticcheck/sa4022"
    17  
    18  	"golang.org/x/tools/go/analysis"
    19  	"golang.org/x/tools/go/analysis/passes/inspect"
    20  )
    21  
    22  var SCAnalyzer = lint.InitializeAnalyzer(&lint.Analyzer{
    23  	Analyzer: &analysis.Analyzer{
    24  		Name:     "SA4031",
    25  		Run:      run,
    26  		Requires: []*analysis.Analyzer{buildir.Analyzer, inspect.Analyzer},
    27  	},
    28  	Doc: &lint.Documentation{
    29  		Title:    `Checking never-nil value against nil`,
    30  		Since:    "2022.1",
    31  		Severity: lint.SeverityWarning,
    32  		MergeIf:  lint.MergeIfAny,
    33  	},
    34  })
    35  
    36  var Analyzer = SCAnalyzer.Analyzer
    37  
    38  var allocationNilCheckQ = pattern.MustParse(`(IfStmt _ cond@(BinaryExpr lhs op@(Or "==" "!=") (Builtin "nil")) _ _)`)
    39  
    40  func run(pass *analysis.Pass) (interface{}, error) {
    41  	irpkg := pass.ResultOf[buildir.Analyzer].(*buildir.IR).Pkg
    42  
    43  	var path []ast.Node
    44  	fn := func(node ast.Node, stack []ast.Node) {
    45  		m, ok := code.Match(pass, allocationNilCheckQ, node)
    46  		if !ok {
    47  			return
    48  		}
    49  		cond := m.State["cond"].(ast.Node)
    50  		if _, ok := code.Match(pass, sa4022.CheckAddressIsNilQ, cond); ok {
    51  			// Don't duplicate diagnostics reported by SA4022
    52  			return
    53  		}
    54  		lhs := m.State["lhs"].(ast.Expr)
    55  		path = path[:0]
    56  		for i := len(stack) - 1; i >= 0; i-- {
    57  			path = append(path, stack[i])
    58  		}
    59  		irfn := ir.EnclosingFunction(irpkg, path)
    60  		if irfn == nil {
    61  			// For example for functions named "_", because we don't generate IR for them.
    62  			return
    63  		}
    64  		v, isAddr := irfn.ValueForExpr(lhs)
    65  		if isAddr {
    66  			return
    67  		}
    68  
    69  		seen := map[ir.Value]struct{}{}
    70  		var values []ir.Value
    71  		var neverNil func(v ir.Value, track bool) bool
    72  		neverNil = func(v ir.Value, track bool) bool {
    73  			if _, ok := seen[v]; ok {
    74  				return true
    75  			}
    76  			seen[v] = struct{}{}
    77  			switch v := v.(type) {
    78  			case *ir.MakeClosure, *ir.Function:
    79  				if track {
    80  					values = append(values, v)
    81  				}
    82  				return true
    83  			case *ir.MakeChan, *ir.MakeMap, *ir.MakeSlice, *ir.Alloc:
    84  				if track {
    85  					values = append(values, v)
    86  				}
    87  				return true
    88  			case *ir.Slice:
    89  				if track {
    90  					values = append(values, v)
    91  				}
    92  				return neverNil(v.X, false)
    93  			case *ir.FieldAddr:
    94  				if track {
    95  					values = append(values, v)
    96  				}
    97  				return neverNil(v.X, false)
    98  			case *ir.Sigma:
    99  				return neverNil(v.X, true)
   100  			case *ir.Phi:
   101  				for _, e := range v.Edges {
   102  					if !neverNil(e, true) {
   103  						return false
   104  					}
   105  				}
   106  				return true
   107  			default:
   108  				return false
   109  			}
   110  		}
   111  
   112  		if !neverNil(v, true) {
   113  			return
   114  		}
   115  
   116  		var qualifier string
   117  		if op := m.State["op"].(token.Token); op == token.EQL {
   118  			qualifier = "never"
   119  		} else {
   120  			qualifier = "always"
   121  		}
   122  		fallback := fmt.Sprintf("this nil check is %s true", qualifier)
   123  
   124  		sort.Slice(values, func(i, j int) bool { return values[i].Pos() < values[j].Pos() })
   125  
   126  		if ident, ok := m.State["lhs"].(*ast.Ident); ok {
   127  			if _, ok := pass.TypesInfo.ObjectOf(ident).(*types.Var); ok {
   128  				var opts []report.Option
   129  				if v.Parent() == irfn {
   130  					if len(values) == 1 {
   131  						opts = append(opts, report.Related(values[0], fmt.Sprintf("this is the value of %s", ident.Name)))
   132  					} else {
   133  						for _, vv := range values {
   134  							opts = append(opts, report.Related(vv, fmt.Sprintf("this is one of the value of %s", ident.Name)))
   135  						}
   136  					}
   137  				}
   138  
   139  				switch v.(type) {
   140  				case *ir.MakeClosure, *ir.Function:
   141  					report.Report(pass, cond, "the checked variable contains a function and is never nil; did you mean to call it?", opts...)
   142  				default:
   143  					report.Report(pass, cond, fallback, opts...)
   144  				}
   145  			} else {
   146  				if _, ok := v.(*ir.Function); ok {
   147  					report.Report(pass, cond, "functions are never nil; did you mean to call it?")
   148  				} else {
   149  					report.Report(pass, cond, fallback)
   150  				}
   151  			}
   152  		} else {
   153  			if _, ok := v.(*ir.Function); ok {
   154  				report.Report(pass, cond, "functions are never nil; did you mean to call it?")
   155  			} else {
   156  				report.Report(pass, cond, fallback)
   157  			}
   158  		}
   159  	}
   160  	code.PreorderStack(pass, fn, (*ast.IfStmt)(nil))
   161  	return nil, nil
   162  }