github.com/samiam2013/sqlvet@v0.0.0-20221210043606-d72f678fc0aa/pkg/vet/gosource.go (about) 1 package vet 2 3 import ( 4 "errors" 5 "fmt" 6 "go/ast" 7 "go/constant" 8 "go/token" 9 "go/types" 10 "os" 11 "path/filepath" 12 "reflect" 13 "sort" 14 "strings" 15 16 "golang.org/x/tools/go/callgraph" 17 "golang.org/x/tools/go/packages" 18 "golang.org/x/tools/go/pointer" 19 "golang.org/x/tools/go/ssa" 20 "golang.org/x/tools/go/ssa/ssautil" 21 22 log "github.com/sirupsen/logrus" 23 24 "github.com/samiam2013/sqlvet/pkg/parseutil" 25 ) 26 27 var ( 28 ErrQueryArgUnsupportedType = errors.New("unexpected query arg type") 29 ErrQueryArgUnsafe = errors.New("potentially unsafe query string") 30 ErrQueryArgTODO = errors.New("TODO: support this type") 31 ) 32 33 type QuerySite struct { 34 Called string 35 Position token.Position 36 Query string 37 ParameterArgCount int 38 Err error 39 } 40 41 type MatchedSqlFunc struct { 42 SSA *ssa.Function 43 QueryArgPos int 44 } 45 46 type SqlFuncMatchRule struct { 47 FuncName string `toml:"func_name"` 48 // zero indexed 49 QueryArgPos int `toml:"query_arg_pos"` 50 QueryArgName string `toml:"query_arg_name"` 51 } 52 53 type SqlFuncMatcher struct { 54 PkgPath string `toml:"pkg_path"` 55 Rules []SqlFuncMatchRule `toml:"rules"` 56 57 pkg *packages.Package 58 } 59 60 func (s *SqlFuncMatcher) SetGoPackage(p *packages.Package) { 61 s.pkg = p 62 } 63 64 func (s *SqlFuncMatcher) PackageImported() bool { 65 return s.pkg != nil 66 } 67 68 func (s *SqlFuncMatcher) IterPackageExportedFuncs(cb func(*types.Func)) { 69 scope := s.pkg.Types.Scope() 70 for _, scopeName := range scope.Names() { 71 obj := scope.Lookup(scopeName) 72 if !obj.Exported() { 73 continue 74 } 75 76 fobj, ok := obj.(*types.Func) 77 if ok { 78 cb(fobj) 79 } else { 80 // check for exported struct methods 81 switch otype := obj.Type().(type) { 82 case *types.Signature: 83 case *types.Named: 84 for i := 0; i < otype.NumMethods(); i++ { 85 m := otype.Method(i) 86 if !m.Exported() { 87 continue 88 } 89 cb(m) 90 } 91 case *types.Basic: 92 default: 93 log.Debugf("Skipped pkg scope: %s (%s)", otype, reflect.TypeOf(otype)) 94 } 95 } 96 } 97 } 98 99 func (s *SqlFuncMatcher) MatchSqlFuncs(prog *ssa.Program) []MatchedSqlFunc { 100 sqlfuncs := []MatchedSqlFunc{} 101 102 s.IterPackageExportedFuncs(func(fobj *types.Func) { 103 for _, rule := range s.Rules { 104 if rule.FuncName != "" && fobj.Name() == rule.FuncName { 105 sqlfuncs = append(sqlfuncs, MatchedSqlFunc{ 106 SSA: prog.FuncValue(fobj), 107 QueryArgPos: rule.QueryArgPos, 108 }) 109 // callable matched one rule, no need to go through the rest 110 break 111 } 112 113 if rule.QueryArgName != "" { 114 sigParams := fobj.Type().(*types.Signature).Params() 115 if sigParams.Len()-1 < rule.QueryArgPos { 116 continue 117 } 118 param := sigParams.At(rule.QueryArgPos) 119 if param.Name() != rule.QueryArgName { 120 continue 121 } 122 sqlfuncs = append(sqlfuncs, MatchedSqlFunc{ 123 SSA: prog.FuncValue(fobj), 124 QueryArgPos: rule.QueryArgPos, 125 }) 126 // callable matched one rule, no need to go through the rest 127 break 128 } 129 } 130 }) 131 132 return sqlfuncs 133 } 134 135 func handleQuery(ctx VetContext, qs *QuerySite) { 136 // TODO: apply named query resolution based on v.X type and v.Sel.Name 137 // e.g. for sqlx, only apply to NamedExec and NamedQuery 138 qs.Query, _, qs.Err = parseutil.CompileNamedQuery( 139 []byte(qs.Query), parseutil.BindType("postgres")) 140 if qs.Err != nil { 141 return 142 } 143 144 var queryParams []QueryParam 145 queryParams, qs.Err = ValidateSqlQuery(ctx, qs.Query) 146 147 if qs.Err != nil { 148 return 149 } 150 151 // query string is valid, now validate parameter args if exists 152 if qs.ParameterArgCount < len(queryParams) { 153 // qs.Err = fmt.Errorf( 154 // "Query expects %d parameters, but received %d from function call", 155 // len(queryParams), qs.ParameterArgCount, 156 // ) 157 } 158 } 159 160 func getMatchers(extraMatchers []SqlFuncMatcher) []*SqlFuncMatcher { 161 matchers := []*SqlFuncMatcher{ 162 { 163 PkgPath: "github.com/jmoiron/sqlx", 164 Rules: []SqlFuncMatchRule{ 165 {QueryArgName: "query"}, 166 {QueryArgName: "sql"}, 167 // for methods with Context suffix 168 {QueryArgName: "query", QueryArgPos: 1}, 169 {QueryArgName: "sql", QueryArgPos: 1}, 170 {QueryArgName: "query", QueryArgPos: 2}, 171 {QueryArgName: "sql", QueryArgPos: 2}, 172 }, 173 }, 174 { 175 PkgPath: "database/sql", 176 Rules: []SqlFuncMatchRule{ 177 {QueryArgName: "query"}, 178 {QueryArgName: "sql"}, 179 // for methods with Context suffix 180 {QueryArgName: "query", QueryArgPos: 1}, 181 {QueryArgName: "sql", QueryArgPos: 1}, 182 }, 183 }, 184 { 185 PkgPath: "github.com/jinzhu/gorm", 186 Rules: []SqlFuncMatchRule{ 187 {QueryArgName: "sql"}, 188 }, 189 }, 190 // TODO: xorm uses vararg, which is not supported yet 191 // &SqlFuncMatcher{ 192 // PkgPath: "xorm.io/xorm", 193 // Rules: []SqlFuncMatchRule{ 194 // {FuncName: "SQL"}, 195 // {FuncName: "Sql"}, 196 // {FuncName: "Exec"}, 197 // {FuncName: "Query"}, 198 // {FuncName: "QueryInterface"}, 199 // {FuncName: "QueryString"}, 200 // {FuncName: "QuerySliceString"}, 201 // }, 202 // }, 203 { 204 PkgPath: "go-gorp/gorp", 205 Rules: []SqlFuncMatchRule{ 206 {QueryArgName: "query"}, 207 }, 208 }, 209 { 210 PkgPath: "gopkg.in/gorp.v1", 211 Rules: []SqlFuncMatchRule{ 212 {QueryArgName: "query"}, 213 }, 214 }, 215 } 216 if extraMatchers != nil { 217 for _, m := range extraMatchers { 218 tmpm := m 219 matchers = append(matchers, &tmpm) 220 } 221 } 222 223 return matchers 224 } 225 226 func loadGoPackages(dir string, buildFlags string) ([]*packages.Package, error) { 227 cfg := &packages.Config{ 228 Mode: packages.NeedName | 229 packages.NeedFiles | 230 packages.NeedImports | 231 packages.NeedDeps | 232 packages.NeedTypes | 233 packages.NeedSyntax | 234 packages.NeedTypesInfo, 235 Dir: dir, 236 Env: append(os.Environ(), "GO111MODULE=auto"), 237 } 238 if buildFlags != "" { 239 cfg.BuildFlags = strings.Split(buildFlags, " ") 240 } 241 dirAbs, err := filepath.Abs(dir) 242 if err != nil { 243 return nil, fmt.Errorf("Invalid path: %w", err) 244 } 245 pkgPath := dirAbs + "/..." 246 pkgs, err := packages.Load(cfg, pkgPath) 247 if err != nil { 248 return nil, err 249 } 250 // return early if any syntax error 251 for _, pkg := range pkgs { 252 if len(pkg.Errors) > 0 { 253 return nil, fmt.Errorf("Failed to load package, %w", pkg.Errors[0]) 254 } 255 } 256 return pkgs, nil 257 } 258 259 func extractQueryStrFromSsaValue(argVal ssa.Value) (string, error) { 260 queryStr := "" 261 262 switch queryArg := argVal.(type) { 263 case *ssa.Const: 264 queryStr = constant.StringVal(queryArg.Value) 265 case *ssa.Phi: 266 // TODO: resolve all phi options 267 // for _, edge := range queryArg.Edges { 268 // } 269 log.Debug("TODO(callgraph) support ssa.Phi") 270 return "", ErrQueryArgTODO 271 case *ssa.BinOp: 272 // only support string concat 273 switch queryArg.Op { 274 case token.ADD: 275 lstr, err := extractQueryStrFromSsaValue(queryArg.X) 276 if err != nil { 277 return "", err 278 } 279 rstr, err := extractQueryStrFromSsaValue(queryArg.Y) 280 if err != nil { 281 return "", err 282 } 283 queryStr = lstr + rstr 284 default: 285 return "", ErrQueryArgUnsupportedType 286 } 287 case *ssa.Parameter: 288 // query call is wrapped in a helper function, query string is passed 289 // in as function parameter 290 // TODO: need to trace the caller or add wrapper function to 291 // matcher config 292 return "", ErrQueryArgTODO 293 case *ssa.Extract: 294 // query string is from one of the multi return values 295 // need to figure out how to trace string from function returns 296 return "", ErrQueryArgTODO 297 case *ssa.Call: 298 // return value from a function call 299 // TODO: trace caller function 300 return "", ErrQueryArgUnsafe 301 case *ssa.MakeInterface: 302 // query function takes interface as input 303 // check to see if interface is converted from a string 304 switch interfaceFrom := queryArg.X.(type) { 305 case *ssa.Const: 306 queryStr = constant.StringVal(interfaceFrom.Value) 307 default: 308 return "", ErrQueryArgUnsupportedType 309 } 310 case *ssa.Slice: 311 // function takes var arg as input 312 313 // Type() returns string if the type of X was string, otherwise a 314 // *types.Slice with the same element type as X. 315 if _, ok := queryArg.Type().(*types.Slice); ok { 316 log.Debug("TODO(callgraph) support slice for vararg") 317 } 318 return "", ErrQueryArgTODO 319 default: 320 return "", ErrQueryArgUnsupportedType 321 } 322 323 return queryStr, nil 324 } 325 326 func shouldIgnoreNode(ignoreNodes []ast.Node, callSitePos token.Pos) bool { 327 if len(ignoreNodes) == 0 { 328 return false 329 } 330 331 if callSitePos < ignoreNodes[0].Pos() { 332 return false 333 } 334 335 if callSitePos > ignoreNodes[len(ignoreNodes)-1].End() { 336 return false 337 } 338 339 for _, n := range ignoreNodes { 340 if callSitePos < n.End() && callSitePos > n.Pos() { 341 return true 342 } 343 } 344 345 return false 346 } 347 348 func iterCallGraphNodeCallees(ctx VetContext, cgNode *callgraph.Node, prog *ssa.Program, sqlfunc MatchedSqlFunc, ignoreNodes []ast.Node) []*QuerySite { 349 queries := []*QuerySite{} 350 351 for _, inEdge := range cgNode.In { 352 callerFunc := inEdge.Caller.Func 353 if callerFunc.Pkg == nil { 354 // skip calls from dependencies 355 continue 356 } 357 358 callSite := inEdge.Site 359 callSitePos := callSite.Pos() 360 if shouldIgnoreNode(ignoreNodes, callSitePos) { 361 continue 362 } 363 364 callSitePosition := prog.Fset.Position(callSitePos) 365 log.Debugf("Validating %s @ %s", sqlfunc.SSA, callSitePosition) 366 367 callArgs := callSite.Common().Args 368 369 absArgPos := sqlfunc.QueryArgPos 370 if callSite.Common().IsInvoke() { 371 // interface method invocation. 372 // In this mode, Value is the interface value and Method is the 373 // interface's abstract method. Note: an abstract method may be 374 // shared by multiple interfaces due to embedding; Value.Type() 375 // provides the specific interface used for this call. 376 } else { 377 // "call" mode: when Method is nil (!IsInvoke), a CallCommon 378 // represents an ordinary function call of the value in Value, 379 // which may be a *Builtin, a *Function or any other value of 380 // kind 'func'. 381 if sqlfunc.SSA.Signature.Recv() != nil { 382 // it's a struct method call, plus 1 to take receiver into 383 // account 384 absArgPos += 1 385 } 386 } 387 queryArg := callArgs[absArgPos] 388 389 qs := &QuerySite{ 390 Called: inEdge.Callee.Func.Name(), 391 Position: callSitePosition, 392 Err: nil, 393 } 394 395 if len(callArgs) > absArgPos+1 { 396 // query function accepts query parameters 397 paramArg := callArgs[absArgPos+1] 398 // only support query param as variadic argument for now 399 switch params := paramArg.(type) { 400 case *ssa.Const: 401 // likely nil 402 case *ssa.Slice: 403 sliceType := params.X.Type() 404 switch t := sliceType.(type) { 405 case *types.Pointer: 406 elem := t.Elem() 407 switch e := elem.(type) { 408 case *types.Array: 409 // query parameters are passed in as vararg: an array 410 // of interface 411 qs.ParameterArgCount = int(e.Len()) 412 } 413 } 414 } 415 } 416 417 qs.Query, qs.Err = extractQueryStrFromSsaValue(queryArg) 418 if qs.Err != nil { 419 switch qs.Err { 420 case ErrQueryArgUnsupportedType: 421 log.WithFields(log.Fields{ 422 "type": reflect.TypeOf(queryArg), 423 "pos": prog.Fset.Position(callSite.Pos()), 424 "caller": callerFunc, 425 "callerPkg": callerFunc.Pkg, 426 }).Debug(fmt.Errorf("unsupported type in callgraph: %w", qs.Err)) 427 case ErrQueryArgTODO: 428 log.WithFields(log.Fields{ 429 "type": reflect.TypeOf(queryArg), 430 "pos": prog.Fset.Position(callSite.Pos()), 431 "caller": callerFunc, 432 "callerPkg": callerFunc.Pkg, 433 }).Debug(fmt.Errorf("TODO(callgraph) %w", qs.Err)) 434 // skip to be supported query type 435 continue 436 default: 437 queries = append(queries, qs) 438 continue 439 } 440 } 441 442 if qs.Query == "" { 443 continue 444 } 445 handleQuery(ctx, qs) 446 queries = append(queries, qs) 447 } 448 449 return queries 450 } 451 452 func getSortedIgnoreNodes(pkgs []*packages.Package) []ast.Node { 453 ignoreNodes := []ast.Node{} 454 455 for _, p := range pkgs { 456 for _, s := range p.Syntax { 457 cmap := ast.NewCommentMap(p.Fset, s, s.Comments) 458 for node, cglist := range cmap { 459 for _, cg := range cglist { 460 // Remove `//` and spaces from comment line to get the 461 // actual comment text. We can't use cg.Text() directly 462 // here due to change introduced in 463 // https://github.com/golang/go/issues/37974 464 ctext := cg.List[0].Text 465 if !strings.HasPrefix(ctext, "//") { 466 continue 467 } 468 ctext = strings.TrimSpace(ctext[2:]) 469 470 anno, err := ParseComment(ctext) 471 if err != nil { 472 continue 473 } 474 if anno.Ignore { 475 ignoreNodes = append(ignoreNodes, node) 476 log.Tracef("Ignore ast node from %d to %d", node.Pos(), node.End()) 477 } 478 } 479 } 480 } 481 } 482 483 sort.Slice(ignoreNodes, func(i, j int) bool { 484 return ignoreNodes[i].Pos() < ignoreNodes[j].Pos() 485 }) 486 487 return ignoreNodes 488 } 489 490 func CheckDir(ctx VetContext, dir, buildFlags string, extraMatchers []SqlFuncMatcher) ([]*QuerySite, error) { 491 _, err := os.Stat(filepath.Join(dir, "go.mod")) 492 if os.IsNotExist(err) { 493 return nil, errors.New("sqlvet only supports projects using go modules for now.") 494 } 495 496 pkgs, err := loadGoPackages(dir, buildFlags) 497 if err != nil { 498 return nil, err 499 } 500 log.Debugf("Loaded %d packages: %s", len(pkgs), pkgs) 501 502 ignoreNodes := getSortedIgnoreNodes(pkgs) 503 log.Debugf("Identified %d queries to ignore", len(ignoreNodes)) 504 505 // check to see if loaded packages imported any package that matches our rules 506 matchers := getMatchers(extraMatchers) 507 log.Debugf("Loaded %d matchers, checking imported SQL packages...", len(matchers)) 508 for _, matcher := range matchers { 509 for _, p := range pkgs { 510 v, ok := p.Imports[matcher.PkgPath] 511 if !ok { 512 continue 513 } 514 // package is imported by at least of the loaded packages 515 matcher.SetGoPackage(v) 516 log.Debugf("\t%s imported", matcher.PkgPath) 517 break 518 } 519 } 520 521 prog, ssaPkgs := ssautil.Packages(pkgs, 0) 522 log.Debug("Performaing whole-program analysis...") 523 prog.Build() 524 525 // find ssa.Function for matched sqlfuncs from program 526 sqlfuncs := []MatchedSqlFunc{} 527 for _, matcher := range matchers { 528 if !matcher.PackageImported() { 529 // if package is not imported, then no sqlfunc should be matched 530 continue 531 } 532 sqlfuncs = append(sqlfuncs, matcher.MatchSqlFuncs(prog)...) 533 } 534 log.Debugf("Matched %d sqlfuncs", len(sqlfuncs)) 535 536 log.Debugf("Locating main packages from %d packages.", len(ssaPkgs)) 537 mains := ssautil.MainPackages(ssaPkgs) 538 539 log.Debug("Building call graph...") 540 anaRes, err := pointer.Analyze(&pointer.Config{ 541 Mains: mains, 542 BuildCallGraph: true, 543 }) 544 545 if err != nil { 546 return nil, err 547 } 548 549 queries := []*QuerySite{} 550 551 cg := anaRes.CallGraph 552 for _, sqlfunc := range sqlfuncs { 553 cgNode := cg.CreateNode(sqlfunc.SSA) 554 queries = append( 555 queries, 556 iterCallGraphNodeCallees(ctx, cgNode, prog, sqlfunc, ignoreNodes)...) 557 } 558 559 return queries, nil 560 }