github.com/amarpal/go-tools@v0.0.0-20240422043104-40142f59f616/staticcheck/sa1031/sa1031.go (about)

     1  package sa1031
     2  
     3  import (
     4  	"go/constant"
     5  	"go/token"
     6  
     7  	"github.com/amarpal/go-tools/analysis/callcheck"
     8  	"github.com/amarpal/go-tools/analysis/lint"
     9  	"github.com/amarpal/go-tools/go/ir"
    10  	"github.com/amarpal/go-tools/go/ir/irutil"
    11  	"github.com/amarpal/go-tools/internal/passes/buildir"
    12  	"github.com/amarpal/go-tools/knowledge"
    13  
    14  	"golang.org/x/tools/go/analysis"
    15  )
    16  
    17  var SCAnalyzer = lint.InitializeAnalyzer(&lint.Analyzer{
    18  	Analyzer: &analysis.Analyzer{
    19  		Name:     "SA1031",
    20  		Requires: []*analysis.Analyzer{buildir.Analyzer},
    21  		Run:      callcheck.Analyzer(checkEncodeRules),
    22  	},
    23  	Doc: &lint.Documentation{
    24  		Title: `Overlapping byte slices passed to an encoder`,
    25  		Text: `In an encoding function of the form \'Encode(dst, src)\', \'dst\' and
    26  \'src\' were found to reference the same memory. This can result in
    27  \'src\' bytes being overwritten before they are read, when the encoder
    28  writes more than one byte per \'src\' byte.`,
    29  		Since:    "Unreleased",
    30  		Severity: lint.SeverityWarning,
    31  		MergeIf:  lint.MergeIfAny,
    32  	},
    33  })
    34  
    35  var Analyzer = SCAnalyzer.Analyzer
    36  
    37  var checkEncodeRules = map[string]callcheck.Check{
    38  	"encoding/ascii85.Encode":            checkNonOverlappingDstSrc(knowledge.Arg("encoding/ascii85.Encode.dst"), knowledge.Arg("encoding/ascii85.Encode.src")),
    39  	"(*encoding/base32.Encoding).Encode": checkNonOverlappingDstSrc(knowledge.Arg("(*encoding/base32.Encoding).Encode.dst"), knowledge.Arg("(*encoding/base32.Encoding).Encode.src")),
    40  	"(*encoding/base64.Encoding).Encode": checkNonOverlappingDstSrc(knowledge.Arg("(*encoding/base64.Encoding).Encode.dst"), knowledge.Arg("(*encoding/base64.Encoding).Encode.src")),
    41  	"encoding/hex.Encode":                checkNonOverlappingDstSrc(knowledge.Arg("encoding/hex.Encode.dst"), knowledge.Arg("encoding/hex.Encode.src")),
    42  }
    43  
    44  func checkNonOverlappingDstSrc(dstArg, srcArg int) callcheck.Check {
    45  	return func(call *callcheck.Call) {
    46  		dst := call.Args[dstArg]
    47  		src := call.Args[srcArg]
    48  		_, dstConst := irutil.Flatten(dst.Value.Value).(*ir.Const)
    49  		_, srcConst := irutil.Flatten(src.Value.Value).(*ir.Const)
    50  		if dstConst || srcConst {
    51  			// one of the arguments is nil, therefore overlap is not possible
    52  			return
    53  		}
    54  		if dst.Value == src.Value {
    55  			// simple case of f(b, b)
    56  			dst.Invalid("overlapping dst and src")
    57  			return
    58  		}
    59  		dstSlice, ok := irutil.Flatten(dst.Value.Value).(*ir.Slice)
    60  		if !ok {
    61  			return
    62  		}
    63  		srcSlice, ok := irutil.Flatten(src.Value.Value).(*ir.Slice)
    64  		if !ok {
    65  			return
    66  		}
    67  		if irutil.Flatten(dstSlice.X) != irutil.Flatten(srcSlice.X) {
    68  			// differing underlying arrays, all is well
    69  			return
    70  		}
    71  		l1 := irutil.Flatten(dstSlice.Low)
    72  		l2 := irutil.Flatten(srcSlice.Low)
    73  		c1, ok1 := l1.(*ir.Const)
    74  		c2, ok2 := l2.(*ir.Const)
    75  		if l1 == l2 || (ok1 && ok2 && constant.Compare(c1.Value, token.EQL, c2.Value)) {
    76  			// dst and src are the same slice, and have the same lower bound
    77  			dst.Invalid("overlapping dst and src")
    78  			return
    79  		}
    80  	}
    81  }