github.com/samiam2013/sqlvet@v0.0.0-20221210043606-d72f678fc0aa/pkg/vet/gosource.go (about)

     1  package vet
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"go/ast"
     7  	"go/constant"
     8  	"go/token"
     9  	"go/types"
    10  	"os"
    11  	"path/filepath"
    12  	"reflect"
    13  	"sort"
    14  	"strings"
    15  
    16  	"golang.org/x/tools/go/callgraph"
    17  	"golang.org/x/tools/go/packages"
    18  	"golang.org/x/tools/go/pointer"
    19  	"golang.org/x/tools/go/ssa"
    20  	"golang.org/x/tools/go/ssa/ssautil"
    21  
    22  	log "github.com/sirupsen/logrus"
    23  
    24  	"github.com/samiam2013/sqlvet/pkg/parseutil"
    25  )
    26  
// Sentinel errors describing why a query string could not be extracted from
// the SSA value passed as the query argument.
var (
	// ErrQueryArgUnsupportedType is returned when the query argument is an SSA
	// value of a kind that is not handled (including binary operations other
	// than string concatenation).
	ErrQueryArgUnsupportedType = errors.New("unexpected query arg type")
	// ErrQueryArgUnsafe is returned when the query argument is the result of
	// a function call, so its content is not known statically.
	ErrQueryArgUnsafe          = errors.New("potentially unsafe query string")
	// ErrQueryArgTODO is returned for SSA value kinds that are recognized but
	// not resolved yet (phi nodes, parameters, extracts and slices).
	ErrQueryArgTODO            = errors.New("TODO: support this type")
)
    32  
// QuerySite describes a single call to a matched SQL function together with
// the outcome of checking the query passed at that call.
type QuerySite struct {
	// Called is the name of the callee function.
	Called            string
	// Position is the source position of the call site.
	Position          token.Position
	// Query is the query string extracted from the call argument; handleQuery
	// replaces it with the compiled form of the named query.
	Query             string
	// ParameterArgCount is the number of query parameters passed as variadic
	// arguments at the call site; it stays 0 when that cannot be determined.
	ParameterArgCount int
	// Err is the error raised while extracting or validating the query, if
	// any.
	Err               error
}
    40  
// MatchedSqlFunc is a function that matched one of the matcher rules, paired
// with the position of its query string parameter.
type MatchedSqlFunc struct {
	// SSA is the SSA function looked up for the matched function object.
	SSA         *ssa.Function
	// QueryArgPos is the zero indexed position of the query parameter in the
	// function signature; the method receiver is not counted.
	QueryArgPos int
}
    45  
// SqlFuncMatchRule selects functions that take a SQL query string, either by
// function name or by the name of the parameter at QueryArgPos.
type SqlFuncMatchRule struct {
	// FuncName, when non-empty, is compared with the function name.
	FuncName string `toml:"func_name"`
	// zero indexed
	QueryArgPos  int    `toml:"query_arg_pos"`
	// QueryArgName, when non-empty, is compared with the name of the
	// parameter found at QueryArgPos.
	QueryArgName string `toml:"query_arg_name"`
}
    52  
// SqlFuncMatcher groups the rules used to find SQL query functions exposed by
// one package.
type SqlFuncMatcher struct {
	// PkgPath is the import path used to look the package up among the
	// imports of the loaded packages.
	PkgPath string             `toml:"pkg_path"`
	// Rules are tried in order; the first rule that matches a function wins.
	Rules   []SqlFuncMatchRule `toml:"rules"`

	// pkg is the loaded package for PkgPath; it stays nil until
	// SetGoPackage is called.
	pkg *packages.Package
}
    59  
// SetGoPackage attaches the loaded package p to the matcher.
func (s *SqlFuncMatcher) SetGoPackage(p *packages.Package) {
	s.pkg = p
}
    63  
// PackageImported reports whether a package has been attached to the matcher
// with SetGoPackage.
func (s *SqlFuncMatcher) PackageImported() bool {
	return s.pkg != nil
}
    67  
    68  func (s *SqlFuncMatcher) IterPackageExportedFuncs(cb func(*types.Func)) {
    69  	scope := s.pkg.Types.Scope()
    70  	for _, scopeName := range scope.Names() {
    71  		obj := scope.Lookup(scopeName)
    72  		if !obj.Exported() {
    73  			continue
    74  		}
    75  
    76  		fobj, ok := obj.(*types.Func)
    77  		if ok {
    78  			cb(fobj)
    79  		} else {
    80  			// check for exported struct methods
    81  			switch otype := obj.Type().(type) {
    82  			case *types.Signature:
    83  			case *types.Named:
    84  				for i := 0; i < otype.NumMethods(); i++ {
    85  					m := otype.Method(i)
    86  					if !m.Exported() {
    87  						continue
    88  					}
    89  					cb(m)
    90  				}
    91  			case *types.Basic:
    92  			default:
    93  				log.Debugf("Skipped pkg scope: %s (%s)", otype, reflect.TypeOf(otype))
    94  			}
    95  		}
    96  	}
    97  }
    98  
    99  func (s *SqlFuncMatcher) MatchSqlFuncs(prog *ssa.Program) []MatchedSqlFunc {
   100  	sqlfuncs := []MatchedSqlFunc{}
   101  
   102  	s.IterPackageExportedFuncs(func(fobj *types.Func) {
   103  		for _, rule := range s.Rules {
   104  			if rule.FuncName != "" && fobj.Name() == rule.FuncName {
   105  				sqlfuncs = append(sqlfuncs, MatchedSqlFunc{
   106  					SSA:         prog.FuncValue(fobj),
   107  					QueryArgPos: rule.QueryArgPos,
   108  				})
   109  				// callable matched one rule, no need to go through the rest
   110  				break
   111  			}
   112  
   113  			if rule.QueryArgName != "" {
   114  				sigParams := fobj.Type().(*types.Signature).Params()
   115  				if sigParams.Len()-1 < rule.QueryArgPos {
   116  					continue
   117  				}
   118  				param := sigParams.At(rule.QueryArgPos)
   119  				if param.Name() != rule.QueryArgName {
   120  					continue
   121  				}
   122  				sqlfuncs = append(sqlfuncs, MatchedSqlFunc{
   123  					SSA:         prog.FuncValue(fobj),
   124  					QueryArgPos: rule.QueryArgPos,
   125  				})
   126  				// callable matched one rule, no need to go through the rest
   127  				break
   128  			}
   129  		}
   130  	})
   131  
   132  	return sqlfuncs
   133  }
   134  
   135  func handleQuery(ctx VetContext, qs *QuerySite) {
   136  	// TODO: apply named query resolution based on v.X type and v.Sel.Name
   137  	// e.g. for sqlx, only apply to NamedExec and NamedQuery
   138  	qs.Query, _, qs.Err = parseutil.CompileNamedQuery(
   139  		[]byte(qs.Query), parseutil.BindType("postgres"))
   140  	if qs.Err != nil {
   141  		return
   142  	}
   143  
   144  	var queryParams []QueryParam
   145  	queryParams, qs.Err = ValidateSqlQuery(ctx, qs.Query)
   146  
   147  	if qs.Err != nil {
   148  		return
   149  	}
   150  
   151  	// query string is valid, now validate parameter args if exists
   152  	if qs.ParameterArgCount < len(queryParams) {
   153  		// qs.Err = fmt.Errorf(
   154  		// 	"Query expects %d parameters, but received %d from function call",
   155  		// 	len(queryParams), qs.ParameterArgCount,
   156  		// )
   157  	}
   158  }
   159  
   160  func getMatchers(extraMatchers []SqlFuncMatcher) []*SqlFuncMatcher {
   161  	matchers := []*SqlFuncMatcher{
   162  		{
   163  			PkgPath: "github.com/jmoiron/sqlx",
   164  			Rules: []SqlFuncMatchRule{
   165  				{QueryArgName: "query"},
   166  				{QueryArgName: "sql"},
   167  				// for methods with Context suffix
   168  				{QueryArgName: "query", QueryArgPos: 1},
   169  				{QueryArgName: "sql", QueryArgPos: 1},
   170  				{QueryArgName: "query", QueryArgPos: 2},
   171  				{QueryArgName: "sql", QueryArgPos: 2},
   172  			},
   173  		},
   174  		{
   175  			PkgPath: "database/sql",
   176  			Rules: []SqlFuncMatchRule{
   177  				{QueryArgName: "query"},
   178  				{QueryArgName: "sql"},
   179  				// for methods with Context suffix
   180  				{QueryArgName: "query", QueryArgPos: 1},
   181  				{QueryArgName: "sql", QueryArgPos: 1},
   182  			},
   183  		},
   184  		{
   185  			PkgPath: "github.com/jinzhu/gorm",
   186  			Rules: []SqlFuncMatchRule{
   187  				{QueryArgName: "sql"},
   188  			},
   189  		},
   190  		// TODO: xorm uses vararg, which is not supported yet
   191  		// &SqlFuncMatcher{
   192  		// 	PkgPath: "xorm.io/xorm",
   193  		// 	Rules: []SqlFuncMatchRule{
   194  		// 		{FuncName: "SQL"},
   195  		// 		{FuncName: "Sql"},
   196  		// 		{FuncName: "Exec"},
   197  		// 		{FuncName: "Query"},
   198  		// 		{FuncName: "QueryInterface"},
   199  		// 		{FuncName: "QueryString"},
   200  		// 		{FuncName: "QuerySliceString"},
   201  		// 	},
   202  		// },
   203  		{
   204  			PkgPath: "go-gorp/gorp",
   205  			Rules: []SqlFuncMatchRule{
   206  				{QueryArgName: "query"},
   207  			},
   208  		},
   209  		{
   210  			PkgPath: "gopkg.in/gorp.v1",
   211  			Rules: []SqlFuncMatchRule{
   212  				{QueryArgName: "query"},
   213  			},
   214  		},
   215  	}
   216  	if extraMatchers != nil {
   217  		for _, m := range extraMatchers {
   218  			tmpm := m
   219  			matchers = append(matchers, &tmpm)
   220  		}
   221  	}
   222  
   223  	return matchers
   224  }
   225  
   226  func loadGoPackages(dir string, buildFlags string) ([]*packages.Package, error) {
   227  	cfg := &packages.Config{
   228  		Mode: packages.NeedName |
   229  			packages.NeedFiles |
   230  			packages.NeedImports |
   231  			packages.NeedDeps |
   232  			packages.NeedTypes |
   233  			packages.NeedSyntax |
   234  			packages.NeedTypesInfo,
   235  		Dir: dir,
   236  		Env: append(os.Environ(), "GO111MODULE=auto"),
   237  	}
   238  	if buildFlags != "" {
   239  		cfg.BuildFlags = strings.Split(buildFlags, " ")
   240  	}
   241  	dirAbs, err := filepath.Abs(dir)
   242  	if err != nil {
   243  		return nil, fmt.Errorf("Invalid path: %w", err)
   244  	}
   245  	pkgPath := dirAbs + "/..."
   246  	pkgs, err := packages.Load(cfg, pkgPath)
   247  	if err != nil {
   248  		return nil, err
   249  	}
   250  	// return early if any syntax error
   251  	for _, pkg := range pkgs {
   252  		if len(pkg.Errors) > 0 {
   253  			return nil, fmt.Errorf("Failed to load package, %w", pkg.Errors[0])
   254  		}
   255  	}
   256  	return pkgs, nil
   257  }
   258  
   259  func extractQueryStrFromSsaValue(argVal ssa.Value) (string, error) {
   260  	queryStr := ""
   261  
   262  	switch queryArg := argVal.(type) {
   263  	case *ssa.Const:
   264  		queryStr = constant.StringVal(queryArg.Value)
   265  	case *ssa.Phi:
   266  		// TODO: resolve all phi options
   267  		// for _, edge := range queryArg.Edges {
   268  		// }
   269  		log.Debug("TODO(callgraph) support ssa.Phi")
   270  		return "", ErrQueryArgTODO
   271  	case *ssa.BinOp:
   272  		// only support string concat
   273  		switch queryArg.Op {
   274  		case token.ADD:
   275  			lstr, err := extractQueryStrFromSsaValue(queryArg.X)
   276  			if err != nil {
   277  				return "", err
   278  			}
   279  			rstr, err := extractQueryStrFromSsaValue(queryArg.Y)
   280  			if err != nil {
   281  				return "", err
   282  			}
   283  			queryStr = lstr + rstr
   284  		default:
   285  			return "", ErrQueryArgUnsupportedType
   286  		}
   287  	case *ssa.Parameter:
   288  		// query call is wrapped in a helper function, query string is passed
   289  		// in as function parameter
   290  		// TODO: need to trace the caller or add wrapper function to
   291  		// matcher config
   292  		return "", ErrQueryArgTODO
   293  	case *ssa.Extract:
   294  		// query string is from one of the multi return values
   295  		// need to figure out how to trace string from function returns
   296  		return "", ErrQueryArgTODO
   297  	case *ssa.Call:
   298  		// return value from a function call
   299  		// TODO: trace caller function
   300  		return "", ErrQueryArgUnsafe
   301  	case *ssa.MakeInterface:
   302  		// query function takes interface as input
   303  		// check to see if interface is converted from a string
   304  		switch interfaceFrom := queryArg.X.(type) {
   305  		case *ssa.Const:
   306  			queryStr = constant.StringVal(interfaceFrom.Value)
   307  		default:
   308  			return "", ErrQueryArgUnsupportedType
   309  		}
   310  	case *ssa.Slice:
   311  		// function takes var arg as input
   312  
   313  		// Type() returns string if the type of X was string, otherwise a
   314  		// *types.Slice with the same element type as X.
   315  		if _, ok := queryArg.Type().(*types.Slice); ok {
   316  			log.Debug("TODO(callgraph) support slice for vararg")
   317  		}
   318  		return "", ErrQueryArgTODO
   319  	default:
   320  		return "", ErrQueryArgUnsupportedType
   321  	}
   322  
   323  	return queryStr, nil
   324  }
   325  
   326  func shouldIgnoreNode(ignoreNodes []ast.Node, callSitePos token.Pos) bool {
   327  	if len(ignoreNodes) == 0 {
   328  		return false
   329  	}
   330  
   331  	if callSitePos < ignoreNodes[0].Pos() {
   332  		return false
   333  	}
   334  
   335  	if callSitePos > ignoreNodes[len(ignoreNodes)-1].End() {
   336  		return false
   337  	}
   338  
   339  	for _, n := range ignoreNodes {
   340  		if callSitePos < n.End() && callSitePos > n.Pos() {
   341  			return true
   342  		}
   343  	}
   344  
   345  	return false
   346  }
   347  
   348  func iterCallGraphNodeCallees(ctx VetContext, cgNode *callgraph.Node, prog *ssa.Program, sqlfunc MatchedSqlFunc, ignoreNodes []ast.Node) []*QuerySite {
   349  	queries := []*QuerySite{}
   350  
   351  	for _, inEdge := range cgNode.In {
   352  		callerFunc := inEdge.Caller.Func
   353  		if callerFunc.Pkg == nil {
   354  			// skip calls from dependencies
   355  			continue
   356  		}
   357  
   358  		callSite := inEdge.Site
   359  		callSitePos := callSite.Pos()
   360  		if shouldIgnoreNode(ignoreNodes, callSitePos) {
   361  			continue
   362  		}
   363  
   364  		callSitePosition := prog.Fset.Position(callSitePos)
   365  		log.Debugf("Validating %s @ %s", sqlfunc.SSA, callSitePosition)
   366  
   367  		callArgs := callSite.Common().Args
   368  
   369  		absArgPos := sqlfunc.QueryArgPos
   370  		if callSite.Common().IsInvoke() {
   371  			// interface method invocation.
   372  			// In this mode, Value is the interface value and Method is the
   373  			// interface's abstract method. Note: an abstract method may be
   374  			// shared by multiple interfaces due to embedding; Value.Type()
   375  			// provides the specific interface used for this call.
   376  		} else {
   377  			// "call" mode: when Method is nil (!IsInvoke), a CallCommon
   378  			// represents an ordinary function call of the value in Value,
   379  			// which may be a *Builtin, a *Function or any other value of
   380  			// kind 'func'.
   381  			if sqlfunc.SSA.Signature.Recv() != nil {
   382  				// it's a struct method call, plus 1 to take receiver into
   383  				// account
   384  				absArgPos += 1
   385  			}
   386  		}
   387  		queryArg := callArgs[absArgPos]
   388  
   389  		qs := &QuerySite{
   390  			Called:   inEdge.Callee.Func.Name(),
   391  			Position: callSitePosition,
   392  			Err:      nil,
   393  		}
   394  
   395  		if len(callArgs) > absArgPos+1 {
   396  			// query function accepts query parameters
   397  			paramArg := callArgs[absArgPos+1]
   398  			// only support query param as variadic argument for now
   399  			switch params := paramArg.(type) {
   400  			case *ssa.Const:
   401  				// likely nil
   402  			case *ssa.Slice:
   403  				sliceType := params.X.Type()
   404  				switch t := sliceType.(type) {
   405  				case *types.Pointer:
   406  					elem := t.Elem()
   407  					switch e := elem.(type) {
   408  					case *types.Array:
   409  						// query parameters are passed in as vararg: an array
   410  						// of interface
   411  						qs.ParameterArgCount = int(e.Len())
   412  					}
   413  				}
   414  			}
   415  		}
   416  
   417  		qs.Query, qs.Err = extractQueryStrFromSsaValue(queryArg)
   418  		if qs.Err != nil {
   419  			switch qs.Err {
   420  			case ErrQueryArgUnsupportedType:
   421  				log.WithFields(log.Fields{
   422  					"type":      reflect.TypeOf(queryArg),
   423  					"pos":       prog.Fset.Position(callSite.Pos()),
   424  					"caller":    callerFunc,
   425  					"callerPkg": callerFunc.Pkg,
   426  				}).Debug(fmt.Errorf("unsupported type in callgraph: %w", qs.Err))
   427  			case ErrQueryArgTODO:
   428  				log.WithFields(log.Fields{
   429  					"type":      reflect.TypeOf(queryArg),
   430  					"pos":       prog.Fset.Position(callSite.Pos()),
   431  					"caller":    callerFunc,
   432  					"callerPkg": callerFunc.Pkg,
   433  				}).Debug(fmt.Errorf("TODO(callgraph) %w", qs.Err))
   434  				// skip to be supported query type
   435  				continue
   436  			default:
   437  				queries = append(queries, qs)
   438  				continue
   439  			}
   440  		}
   441  
   442  		if qs.Query == "" {
   443  			continue
   444  		}
   445  		handleQuery(ctx, qs)
   446  		queries = append(queries, qs)
   447  	}
   448  
   449  	return queries
   450  }
   451  
   452  func getSortedIgnoreNodes(pkgs []*packages.Package) []ast.Node {
   453  	ignoreNodes := []ast.Node{}
   454  
   455  	for _, p := range pkgs {
   456  		for _, s := range p.Syntax {
   457  			cmap := ast.NewCommentMap(p.Fset, s, s.Comments)
   458  			for node, cglist := range cmap {
   459  				for _, cg := range cglist {
   460  					// Remove `//` and spaces from comment line to get the
   461  					// actual comment text. We can't use cg.Text() directly
   462  					// here due to change introduced in
   463  					// https://github.com/golang/go/issues/37974
   464  					ctext := cg.List[0].Text
   465  					if !strings.HasPrefix(ctext, "//") {
   466  						continue
   467  					}
   468  					ctext = strings.TrimSpace(ctext[2:])
   469  
   470  					anno, err := ParseComment(ctext)
   471  					if err != nil {
   472  						continue
   473  					}
   474  					if anno.Ignore {
   475  						ignoreNodes = append(ignoreNodes, node)
   476  						log.Tracef("Ignore ast node from %d to %d", node.Pos(), node.End())
   477  					}
   478  				}
   479  			}
   480  		}
   481  	}
   482  
   483  	sort.Slice(ignoreNodes, func(i, j int) bool {
   484  		return ignoreNodes[i].Pos() < ignoreNodes[j].Pos()
   485  	})
   486  
   487  	return ignoreNodes
   488  }
   489  
   490  func CheckDir(ctx VetContext, dir, buildFlags string, extraMatchers []SqlFuncMatcher) ([]*QuerySite, error) {
   491  	_, err := os.Stat(filepath.Join(dir, "go.mod"))
   492  	if os.IsNotExist(err) {
   493  		return nil, errors.New("sqlvet only supports projects using go modules for now.")
   494  	}
   495  
   496  	pkgs, err := loadGoPackages(dir, buildFlags)
   497  	if err != nil {
   498  		return nil, err
   499  	}
   500  	log.Debugf("Loaded %d packages: %s", len(pkgs), pkgs)
   501  
   502  	ignoreNodes := getSortedIgnoreNodes(pkgs)
   503  	log.Debugf("Identified %d queries to ignore", len(ignoreNodes))
   504  
   505  	// check to see if loaded packages imported any package that matches our rules
   506  	matchers := getMatchers(extraMatchers)
   507  	log.Debugf("Loaded %d matchers, checking imported SQL packages...", len(matchers))
   508  	for _, matcher := range matchers {
   509  		for _, p := range pkgs {
   510  			v, ok := p.Imports[matcher.PkgPath]
   511  			if !ok {
   512  				continue
   513  			}
   514  			// package is imported by at least of the loaded packages
   515  			matcher.SetGoPackage(v)
   516  			log.Debugf("\t%s imported", matcher.PkgPath)
   517  			break
   518  		}
   519  	}
   520  
   521  	prog, ssaPkgs := ssautil.Packages(pkgs, 0)
   522  	log.Debug("Performaing whole-program analysis...")
   523  	prog.Build()
   524  
   525  	// find ssa.Function for matched sqlfuncs from program
   526  	sqlfuncs := []MatchedSqlFunc{}
   527  	for _, matcher := range matchers {
   528  		if !matcher.PackageImported() {
   529  			// if package is not imported, then no sqlfunc should be matched
   530  			continue
   531  		}
   532  		sqlfuncs = append(sqlfuncs, matcher.MatchSqlFuncs(prog)...)
   533  	}
   534  	log.Debugf("Matched %d sqlfuncs", len(sqlfuncs))
   535  
   536  	log.Debugf("Locating main packages from %d packages.", len(ssaPkgs))
   537  	mains := ssautil.MainPackages(ssaPkgs)
   538  
   539  	log.Debug("Building call graph...")
   540  	anaRes, err := pointer.Analyze(&pointer.Config{
   541  		Mains:          mains,
   542  		BuildCallGraph: true,
   543  	})
   544  
   545  	if err != nil {
   546  		return nil, err
   547  	}
   548  
   549  	queries := []*QuerySite{}
   550  
   551  	cg := anaRes.CallGraph
   552  	for _, sqlfunc := range sqlfuncs {
   553  		cgNode := cg.CreateNode(sqlfunc.SSA)
   554  		queries = append(
   555  			queries,
   556  			iterCallGraphNodeCallees(ctx, cgNode, prog, sqlfunc, ignoreNodes)...)
   557  	}
   558  
   559  	return queries, nil
   560  }