gopkg.in/alecthomas/gometalinter.v3@v3.0.0/_linters/src/github.com/stripe/safesql/safesql.go (about) 1 // Command safesql is a tool for performing static analysis on programs to 2 // ensure that SQL injection attacks are not possible. It does this by ensuring 3 // package database/sql is only used with compile-time constant queries. 4 package main 5 6 import ( 7 "flag" 8 "fmt" 9 "go/build" 10 "go/types" 11 "os" 12 "path/filepath" 13 "strings" 14 15 "golang.org/x/tools/go/callgraph" 16 "golang.org/x/tools/go/loader" 17 "golang.org/x/tools/go/pointer" 18 "golang.org/x/tools/go/ssa" 19 "golang.org/x/tools/go/ssa/ssautil" 20 ) 21 22 func main() { 23 var verbose, quiet bool 24 flag.BoolVar(&verbose, "v", false, "Verbose mode") 25 flag.BoolVar(&quiet, "q", false, "Only print on failure") 26 flag.Usage = func() { 27 fmt.Fprintf(os.Stderr, "Usage: %s [-q] [-v] package1 [package2 ...]\n", os.Args[0]) 28 flag.PrintDefaults() 29 } 30 31 flag.Parse() 32 pkgs := flag.Args() 33 if len(pkgs) == 0 { 34 flag.Usage() 35 os.Exit(2) 36 } 37 38 c := loader.Config{ 39 FindPackage: FindPackage, 40 } 41 c.Import("database/sql") 42 for _, pkg := range pkgs { 43 c.Import(pkg) 44 } 45 p, err := c.Load() 46 if err != nil { 47 fmt.Printf("error loading packages %v: %v\n", pkgs, err) 48 os.Exit(2) 49 } 50 s := ssautil.CreateProgram(p, 0) 51 s.Build() 52 53 qms := FindQueryMethods(p.Package("database/sql").Pkg, s) 54 if verbose { 55 fmt.Println("database/sql functions that accept queries:") 56 for _, m := range qms { 57 fmt.Printf("- %s (param %d)\n", m.Func, m.Param) 58 } 59 fmt.Println() 60 } 61 62 mains := FindMains(p, s) 63 if len(mains) == 0 { 64 fmt.Println("Did not find any commands (i.e., main functions).") 65 os.Exit(2) 66 } 67 68 res, err := pointer.Analyze(&pointer.Config{ 69 Mains: mains, 70 BuildCallGraph: true, 71 }) 72 if err != nil { 73 fmt.Printf("error performing pointer analysis: %v\n", err) 74 os.Exit(2) 75 } 76 77 bad := FindNonConstCalls(res.CallGraph, qms) 78 if len(bad) == 0 { 79 if !quiet { 80 fmt.Println(`You're safe from SQL injection! Yay \o/`) 81 } 82 return 83 } 84 85 fmt.Printf("Found %d potentially unsafe SQL statements:\n", len(bad)) 86 for _, ci := range bad { 87 pos := p.Fset.Position(ci.Pos()) 88 fmt.Printf("- %s\n", pos) 89 } 90 fmt.Println("Please ensure that all SQL queries you use are compile-time constants.") 91 fmt.Println("You should always use parameterized queries or prepared statements") 92 fmt.Println("instead of building queries from strings.") 93 os.Exit(1) 94 } 95 96 // QueryMethod represents a method on a type which has a string parameter named 97 // "query". 98 type QueryMethod struct { 99 Func *types.Func 100 SSA *ssa.Function 101 ArgCount int 102 Param int 103 } 104 105 // FindQueryMethods locates all methods in the given package (assumed to be 106 // package database/sql) with a string parameter named "query". 107 func FindQueryMethods(sql *types.Package, ssa *ssa.Program) []*QueryMethod { 108 methods := make([]*QueryMethod, 0) 109 scope := sql.Scope() 110 for _, name := range scope.Names() { 111 o := scope.Lookup(name) 112 if !o.Exported() { 113 continue 114 } 115 if _, ok := o.(*types.TypeName); !ok { 116 continue 117 } 118 n := o.Type().(*types.Named) 119 for i := 0; i < n.NumMethods(); i++ { 120 m := n.Method(i) 121 if !m.Exported() { 122 continue 123 } 124 s := m.Type().(*types.Signature) 125 if num, ok := FuncHasQuery(s); ok { 126 methods = append(methods, &QueryMethod{ 127 Func: m, 128 SSA: ssa.FuncValue(m), 129 ArgCount: s.Params().Len(), 130 Param: num, 131 }) 132 } 133 } 134 } 135 return methods 136 } 137 138 var stringType types.Type = types.Typ[types.String] 139 140 // FuncHasQuery returns the offset of the string parameter named "query", or 141 // none if no such parameter exists. 142 func FuncHasQuery(s *types.Signature) (offset int, ok bool) { 143 params := s.Params() 144 for i := 0; i < params.Len(); i++ { 145 v := params.At(i) 146 if v.Name() == "query" && v.Type() == stringType { 147 return i, true 148 } 149 } 150 return 0, false 151 } 152 153 // FindMains returns the set of all packages loaded into the given 154 // loader.Program which contain main functions 155 func FindMains(p *loader.Program, s *ssa.Program) []*ssa.Package { 156 ips := p.InitialPackages() 157 mains := make([]*ssa.Package, 0, len(ips)) 158 for _, info := range ips { 159 ssaPkg := s.Package(info.Pkg) 160 if ssaPkg.Func("main") != nil { 161 mains = append(mains, ssaPkg) 162 } 163 } 164 return mains 165 } 166 167 // FindNonConstCalls returns the set of callsites of the given set of methods 168 // for which the "query" parameter is not a compile-time constant. 169 func FindNonConstCalls(cg *callgraph.Graph, qms []*QueryMethod) []ssa.CallInstruction { 170 cg.DeleteSyntheticNodes() 171 172 // package database/sql has a couple helper functions which are thin 173 // wrappers around other sensitive functions. Instead of handling the 174 // general case by tracing down callsites of wrapper functions 175 // recursively, let's just whitelist the functions we're already 176 // tracking, since it happens to be good enough for our use case. 177 okFuncs := make(map[*ssa.Function]struct{}, len(qms)) 178 for _, m := range qms { 179 okFuncs[m.SSA] = struct{}{} 180 } 181 182 bad := make([]ssa.CallInstruction, 0) 183 for _, m := range qms { 184 node := cg.CreateNode(m.SSA) 185 for _, edge := range node.In { 186 if _, ok := okFuncs[edge.Site.Parent()]; ok { 187 continue 188 } 189 cc := edge.Site.Common() 190 args := cc.Args 191 // The first parameter is occasionally the receiver. 192 if len(args) == m.ArgCount+1 { 193 args = args[1:] 194 } else if len(args) != m.ArgCount { 195 panic("arg count mismatch") 196 } 197 v := args[m.Param] 198 if _, ok := v.(*ssa.Const); !ok { 199 bad = append(bad, edge.Site) 200 } 201 } 202 } 203 204 return bad 205 } 206 207 // Deal with GO15VENDOREXPERIMENT 208 func FindPackage(ctxt *build.Context, path, dir string, mode build.ImportMode) (*build.Package, error) { 209 if !useVendor { 210 return ctxt.Import(path, dir, mode) 211 } 212 213 // First, walk up the filesystem from dir looking for vendor directories 214 var vendorDir string 215 for tmp := dir; vendorDir == "" && tmp != "/"; tmp = filepath.Dir(tmp) { 216 dname := filepath.Join(tmp, "vendor", filepath.FromSlash(path)) 217 fd, err := os.Open(dname) 218 if err != nil { 219 continue 220 } 221 // Directories are only valid if they contain at least one file 222 // with suffix ".go" (this also ensures that the file descriptor 223 // we have is in fact a directory) 224 names, err := fd.Readdirnames(-1) 225 if err != nil { 226 continue 227 } 228 for _, name := range names { 229 if strings.HasSuffix(name, ".go") { 230 vendorDir = filepath.ToSlash(dname) 231 break 232 } 233 } 234 } 235 236 if vendorDir != "" { 237 pkg, err := ctxt.ImportDir(vendorDir, mode) 238 if err != nil { 239 return nil, err 240 } 241 // Go tries to derive a valid import path for the package, but 242 // it's wrong (it includes "/vendor/"). Overwrite it here. 243 pkg.ImportPath = path 244 return pkg, nil 245 } 246 247 return ctxt.Import(path, dir, mode) 248 }