github.com/amarpal/go-tools@v0.0.0-20240422043104-40142f59f616/staticcheck/sa5009/sa5009.go (about)

     1  package sa5009
     2  
     3  import (
     4  	"fmt"
     5  	"go/constant"
     6  	"go/types"
     7  
     8  	"github.com/amarpal/go-tools/analysis/callcheck"
     9  	"github.com/amarpal/go-tools/analysis/lint"
    10  	"github.com/amarpal/go-tools/go/ir"
    11  	"github.com/amarpal/go-tools/go/ir/irutil"
    12  	"github.com/amarpal/go-tools/go/types/typeutil"
    13  	"github.com/amarpal/go-tools/internal/passes/buildir"
    14  	"github.com/amarpal/go-tools/knowledge"
    15  	"github.com/amarpal/go-tools/printf"
    16  
    17  	"golang.org/x/tools/go/analysis"
    18  )
    19  
    20  var SCAnalyzer = lint.InitializeAnalyzer(&lint.Analyzer{
    21  	Analyzer: &analysis.Analyzer{
    22  		Name:     "SA5009",
    23  		Requires: []*analysis.Analyzer{buildir.Analyzer},
    24  		Run:      callcheck.Analyzer(rules),
    25  	},
    26  	Doc: &lint.Documentation{
    27  		Title:    `Invalid Printf call`,
    28  		Since:    "2019.2",
    29  		Severity: lint.SeverityError,
    30  		MergeIf:  lint.MergeIfAny,
    31  	},
    32  })
    33  
    34  var Analyzer = SCAnalyzer.Analyzer
    35  
    36  // TODO(dh): detect printf wrappers
    37  var rules = map[string]callcheck.Check{
    38  	"fmt.Errorf":                  func(call *callcheck.Call) { check(call, 0, 1) },
    39  	"fmt.Printf":                  func(call *callcheck.Call) { check(call, 0, 1) },
    40  	"fmt.Sprintf":                 func(call *callcheck.Call) { check(call, 0, 1) },
    41  	"fmt.Fprintf":                 func(call *callcheck.Call) { check(call, 1, 2) },
    42  	"golang.org/x/xerrors.Errorf": func(call *callcheck.Call) { check(call, 0, 1) },
    43  }
    44  
    45  type verbFlag int
    46  
    47  const (
    48  	isInt verbFlag = 1 << iota
    49  	isBool
    50  	isFP
    51  	isString
    52  	isPointer
    53  	// Verbs that accept "pseudo pointers" will sometimes dereference
    54  	// non-nil pointers. For example, %x on a non-nil *struct will print the
    55  	// individual fields, but on a nil pointer it will print the address.
    56  	isPseudoPointer
    57  	isSlice
    58  	isAny
    59  	noRecurse
    60  )
    61  
    62  var verbs = [...]verbFlag{
    63  	'b': isPseudoPointer | isInt | isFP,
    64  	'c': isInt,
    65  	'd': isPseudoPointer | isInt,
    66  	'e': isFP,
    67  	'E': isFP,
    68  	'f': isFP,
    69  	'F': isFP,
    70  	'g': isFP,
    71  	'G': isFP,
    72  	'o': isPseudoPointer | isInt,
    73  	'O': isPseudoPointer | isInt,
    74  	'p': isSlice | isPointer | noRecurse,
    75  	'q': isInt | isString,
    76  	's': isString,
    77  	't': isBool,
    78  	'T': isAny,
    79  	'U': isInt,
    80  	'v': isAny,
    81  	'X': isPseudoPointer | isInt | isFP | isString,
    82  	'x': isPseudoPointer | isInt | isFP | isString,
    83  }
    84  
    85  func check(call *callcheck.Call, fIdx, vIdx int) {
    86  	f := call.Args[fIdx]
    87  	var args []ir.Value
    88  	switch v := call.Args[vIdx].Value.Value.(type) {
    89  	case *ir.Slice:
    90  		var ok bool
    91  		args, ok = irutil.Vararg(v)
    92  		if !ok {
    93  			// We don't know what the actual arguments to the function are
    94  			return
    95  		}
    96  	case *ir.Const:
    97  		// nil, i.e. no arguments
    98  	default:
    99  		// We don't know what the actual arguments to the function are
   100  		return
   101  	}
   102  	checkImpl(f, f.Value.Value, args)
   103  }
   104  
   105  func checkImpl(carg *callcheck.Argument, f ir.Value, args []ir.Value) {
   106  	var msCache *typeutil.MethodSetCache
   107  	if f.Parent() != nil {
   108  		msCache = &f.Parent().Prog.MethodSets
   109  	}
   110  
   111  	elem := func(T types.Type, verb rune) ([]types.Type, bool) {
   112  		if verbs[verb]&noRecurse != 0 {
   113  			return []types.Type{T}, false
   114  		}
   115  		switch T := T.(type) {
   116  		case *types.Slice:
   117  			if verbs[verb]&isSlice != 0 {
   118  				return []types.Type{T}, false
   119  			}
   120  			if verbs[verb]&isString != 0 && typeutil.IsType(T.Elem().Underlying(), "byte") {
   121  				return []types.Type{T}, false
   122  			}
   123  			return []types.Type{T.Elem()}, true
   124  		case *types.Map:
   125  			key := T.Key()
   126  			val := T.Elem()
   127  			return []types.Type{key, val}, true
   128  		case *types.Struct:
   129  			out := make([]types.Type, 0, T.NumFields())
   130  			for i := 0; i < T.NumFields(); i++ {
   131  				out = append(out, T.Field(i).Type())
   132  			}
   133  			return out, true
   134  		case *types.Array:
   135  			return []types.Type{T.Elem()}, true
   136  		default:
   137  			return []types.Type{T}, false
   138  		}
   139  	}
   140  	isInfo := func(T types.Type, info types.BasicInfo) bool {
   141  		basic, ok := T.Underlying().(*types.Basic)
   142  		return ok && basic.Info()&info != 0
   143  	}
   144  
   145  	isFormatter := func(T types.Type, ms *types.MethodSet) bool {
   146  		sel := ms.Lookup(nil, "Format")
   147  		if sel == nil {
   148  			return false
   149  		}
   150  		fn, ok := sel.Obj().(*types.Func)
   151  		if !ok {
   152  			// should be unreachable
   153  			return false
   154  		}
   155  		sig := fn.Type().(*types.Signature)
   156  		if sig.Params().Len() != 2 {
   157  			return false
   158  		}
   159  		// TODO(dh): check the types of the arguments for more
   160  		// precision
   161  		if sig.Results().Len() != 0 {
   162  			return false
   163  		}
   164  		return true
   165  	}
   166  
   167  	var seen typeutil.Map[struct{}]
   168  	var checkType func(verb rune, T types.Type, top bool) bool
   169  	checkType = func(verb rune, T types.Type, top bool) bool {
   170  		if top {
   171  			seen = typeutil.Map[struct{}]{}
   172  		}
   173  		if _, ok := seen.At(T); ok {
   174  			return true
   175  		}
   176  		seen.Set(T, struct{}{})
   177  		if int(verb) >= len(verbs) {
   178  			// Unknown verb
   179  			return true
   180  		}
   181  
   182  		flags := verbs[verb]
   183  		if flags == 0 {
   184  			// Unknown verb
   185  			return true
   186  		}
   187  
   188  		ms := msCache.MethodSet(T)
   189  		if isFormatter(T, ms) {
   190  			// the value is responsible for formatting itself
   191  			return true
   192  		}
   193  
   194  		if flags&isString != 0 && (types.Implements(T, knowledge.Interfaces["fmt.Stringer"]) || types.Implements(T, knowledge.Interfaces["error"])) {
   195  			// Check for stringer early because we're about to dereference
   196  			return true
   197  		}
   198  
   199  		T = T.Underlying()
   200  		if flags&(isPointer|isPseudoPointer) == 0 && top {
   201  			T = typeutil.Dereference(T)
   202  		}
   203  		if flags&isPseudoPointer != 0 && top {
   204  			t := typeutil.Dereference(T)
   205  			if _, ok := t.Underlying().(*types.Struct); ok {
   206  				T = t
   207  			}
   208  		}
   209  
   210  		if _, ok := T.(*types.Interface); ok {
   211  			// We don't know what's in the interface
   212  			return true
   213  		}
   214  
   215  		var info types.BasicInfo
   216  		if flags&isInt != 0 {
   217  			info |= types.IsInteger
   218  		}
   219  		if flags&isBool != 0 {
   220  			info |= types.IsBoolean
   221  		}
   222  		if flags&isFP != 0 {
   223  			info |= types.IsFloat | types.IsComplex
   224  		}
   225  		if flags&isString != 0 {
   226  			info |= types.IsString
   227  		}
   228  
   229  		if info != 0 && isInfo(T, info) {
   230  			return true
   231  		}
   232  
   233  		if flags&isString != 0 {
   234  			isStringyElem := func(typ types.Type) bool {
   235  				if typ, ok := typ.Underlying().(*types.Basic); ok {
   236  					return typ.Kind() == types.Byte
   237  				}
   238  				return false
   239  			}
   240  			switch T := T.(type) {
   241  			case *types.Slice:
   242  				if isStringyElem(T.Elem()) {
   243  					return true
   244  				}
   245  			case *types.Array:
   246  				if isStringyElem(T.Elem()) {
   247  					return true
   248  				}
   249  			}
   250  			if types.Implements(T, knowledge.Interfaces["fmt.Stringer"]) || types.Implements(T, knowledge.Interfaces["error"]) {
   251  				return true
   252  			}
   253  		}
   254  
   255  		if flags&isPointer != 0 && typeutil.IsPointerLike(T) {
   256  			return true
   257  		}
   258  		if flags&isPseudoPointer != 0 {
   259  			switch U := T.Underlying().(type) {
   260  			case *types.Pointer:
   261  				if !top {
   262  					return true
   263  				}
   264  
   265  				if _, ok := U.Elem().Underlying().(*types.Struct); !ok {
   266  					// TODO(dh): can this condition ever be false? For
   267  					// *T, if T is a struct, we'll already have
   268  					// dereferenced it, meaning the *types.Pointer
   269  					// branch couldn't have been taken. For T that
   270  					// aren't structs, this condition will always
   271  					// evaluate to true.
   272  					return true
   273  				}
   274  			case *types.Chan, *types.Signature:
   275  				// Channels and functions are always treated as
   276  				// pointers and never recursed into.
   277  				return true
   278  			case *types.Basic:
   279  				if U.Kind() == types.UnsafePointer {
   280  					return true
   281  				}
   282  			case *types.Interface:
   283  				// we will already have bailed if the type is an
   284  				// interface.
   285  				panic("unreachable")
   286  			default:
   287  				// other pointer-like types, such as maps or slices,
   288  				// will be printed element-wise.
   289  			}
   290  		}
   291  
   292  		if flags&isSlice != 0 {
   293  			if _, ok := T.(*types.Slice); ok {
   294  				return true
   295  			}
   296  		}
   297  
   298  		if flags&isAny != 0 {
   299  			return true
   300  		}
   301  
   302  		elems, ok := elem(T.Underlying(), verb)
   303  		if !ok {
   304  			return false
   305  		}
   306  		for _, elem := range elems {
   307  			if !checkType(verb, elem, false) {
   308  				return false
   309  			}
   310  		}
   311  
   312  		return true
   313  	}
   314  
   315  	k, ok := irutil.Flatten(f).(*ir.Const)
   316  	if !ok {
   317  		return
   318  	}
   319  	actions, err := printf.Parse(constant.StringVal(k.Value))
   320  	if err != nil {
   321  		carg.Invalid("couldn't parse format string")
   322  		return
   323  	}
   324  
   325  	ptr := 1
   326  	hasExplicit := false
   327  
   328  	checkStar := func(verb printf.Verb, star printf.Argument) bool {
   329  		if star, ok := star.(printf.Star); ok {
   330  			idx := 0
   331  			if star.Index == -1 {
   332  				idx = ptr
   333  				ptr++
   334  			} else {
   335  				hasExplicit = true
   336  				idx = star.Index
   337  				ptr = star.Index + 1
   338  			}
   339  			if idx == 0 {
   340  				carg.Invalid(fmt.Sprintf("Printf format %s reads invalid arg 0; indices are 1-based", verb.Raw))
   341  				return false
   342  			}
   343  			if idx > len(args) {
   344  				carg.Invalid(
   345  					fmt.Sprintf("Printf format %s reads arg #%d, but call has only %d args",
   346  						verb.Raw, idx, len(args)))
   347  				return false
   348  			}
   349  			if arg, ok := args[idx-1].(*ir.MakeInterface); ok {
   350  				if !isInfo(arg.X.Type(), types.IsInteger) {
   351  					carg.Invalid(fmt.Sprintf("Printf format %s reads non-int arg #%d as argument of *", verb.Raw, idx))
   352  				}
   353  			}
   354  		}
   355  		return true
   356  	}
   357  
   358  	// We only report one problem per format string. Making a
   359  	// mistake with an index tends to invalidate all future
   360  	// implicit indices.
   361  	for _, action := range actions {
   362  		verb, ok := action.(printf.Verb)
   363  		if !ok {
   364  			continue
   365  		}
   366  
   367  		if !checkStar(verb, verb.Width) || !checkStar(verb, verb.Precision) {
   368  			return
   369  		}
   370  
   371  		off := ptr
   372  		if verb.Value != -1 {
   373  			hasExplicit = true
   374  			off = verb.Value
   375  		}
   376  		if off > len(args) {
   377  			carg.Invalid(
   378  				fmt.Sprintf("Printf format %s reads arg #%d, but call has only %d args",
   379  					verb.Raw, off, len(args)))
   380  			return
   381  		} else if verb.Value == 0 && verb.Letter != '%' {
   382  			carg.Invalid(fmt.Sprintf("Printf format %s reads invalid arg 0; indices are 1-based", verb.Raw))
   383  			return
   384  		} else if off != 0 {
   385  			arg, ok := args[off-1].(*ir.MakeInterface)
   386  			if ok {
   387  				if !checkType(verb.Letter, arg.X.Type(), true) {
   388  					carg.Invalid(fmt.Sprintf("Printf format %s has arg #%d of wrong type %s",
   389  						verb.Raw, ptr, args[ptr-1].(*ir.MakeInterface).X.Type()))
   390  					return
   391  				}
   392  			}
   393  		}
   394  
   395  		switch verb.Value {
   396  		case -1:
   397  			// Consume next argument
   398  			ptr++
   399  		case 0:
   400  			// Don't consume any arguments
   401  		default:
   402  			ptr = verb.Value + 1
   403  		}
   404  	}
   405  
   406  	if !hasExplicit && ptr <= len(args) {
   407  		carg.Invalid(fmt.Sprintf("Printf call needs %d args but has %d args", ptr-1, len(args)))
   408  	}
   409  }