gopkg.in/alecthomas/gometalinter.v3@v3.0.0/_linters/src/github.com/stripe/safesql/safesql.go (about)

     1  // Command safesql is a tool for performing static analysis on programs to
     2  // ensure that SQL injection attacks are not possible. It does this by ensuring
     3  // package database/sql is only used with compile-time constant queries.
     4  package main
     5  
     6  import (
     7  	"flag"
     8  	"fmt"
     9  	"go/build"
    10  	"go/types"
    11  	"os"
    12  	"path/filepath"
    13  	"strings"
    14  
    15  	"golang.org/x/tools/go/callgraph"
    16  	"golang.org/x/tools/go/loader"
    17  	"golang.org/x/tools/go/pointer"
    18  	"golang.org/x/tools/go/ssa"
    19  	"golang.org/x/tools/go/ssa/ssautil"
    20  )
    21  
    22  func main() {
    23  	var verbose, quiet bool
    24  	flag.BoolVar(&verbose, "v", false, "Verbose mode")
    25  	flag.BoolVar(&quiet, "q", false, "Only print on failure")
    26  	flag.Usage = func() {
    27  		fmt.Fprintf(os.Stderr, "Usage: %s [-q] [-v] package1 [package2 ...]\n", os.Args[0])
    28  		flag.PrintDefaults()
    29  	}
    30  
    31  	flag.Parse()
    32  	pkgs := flag.Args()
    33  	if len(pkgs) == 0 {
    34  		flag.Usage()
    35  		os.Exit(2)
    36  	}
    37  
    38  	c := loader.Config{
    39  		FindPackage: FindPackage,
    40  	}
    41  	c.Import("database/sql")
    42  	for _, pkg := range pkgs {
    43  		c.Import(pkg)
    44  	}
    45  	p, err := c.Load()
    46  	if err != nil {
    47  		fmt.Printf("error loading packages %v: %v\n", pkgs, err)
    48  		os.Exit(2)
    49  	}
    50  	s := ssautil.CreateProgram(p, 0)
    51  	s.Build()
    52  
    53  	qms := FindQueryMethods(p.Package("database/sql").Pkg, s)
    54  	if verbose {
    55  		fmt.Println("database/sql functions that accept queries:")
    56  		for _, m := range qms {
    57  			fmt.Printf("- %s (param %d)\n", m.Func, m.Param)
    58  		}
    59  		fmt.Println()
    60  	}
    61  
    62  	mains := FindMains(p, s)
    63  	if len(mains) == 0 {
    64  		fmt.Println("Did not find any commands (i.e., main functions).")
    65  		os.Exit(2)
    66  	}
    67  
    68  	res, err := pointer.Analyze(&pointer.Config{
    69  		Mains:          mains,
    70  		BuildCallGraph: true,
    71  	})
    72  	if err != nil {
    73  		fmt.Printf("error performing pointer analysis: %v\n", err)
    74  		os.Exit(2)
    75  	}
    76  
    77  	bad := FindNonConstCalls(res.CallGraph, qms)
    78  	if len(bad) == 0 {
    79  		if !quiet {
    80  			fmt.Println(`You're safe from SQL injection! Yay \o/`)
    81  		}
    82  		return
    83  	}
    84  
    85  	fmt.Printf("Found %d potentially unsafe SQL statements:\n", len(bad))
    86  	for _, ci := range bad {
    87  		pos := p.Fset.Position(ci.Pos())
    88  		fmt.Printf("- %s\n", pos)
    89  	}
    90  	fmt.Println("Please ensure that all SQL queries you use are compile-time constants.")
    91  	fmt.Println("You should always use parameterized queries or prepared statements")
    92  	fmt.Println("instead of building queries from strings.")
    93  	os.Exit(1)
    94  }
    95  
    96  // QueryMethod represents a method on a type which has a string parameter named
    97  // "query".
    98  type QueryMethod struct {
    99  	Func     *types.Func
   100  	SSA      *ssa.Function
   101  	ArgCount int
   102  	Param    int
   103  }
   104  
   105  // FindQueryMethods locates all methods in the given package (assumed to be
   106  // package database/sql) with a string parameter named "query".
   107  func FindQueryMethods(sql *types.Package, ssa *ssa.Program) []*QueryMethod {
   108  	methods := make([]*QueryMethod, 0)
   109  	scope := sql.Scope()
   110  	for _, name := range scope.Names() {
   111  		o := scope.Lookup(name)
   112  		if !o.Exported() {
   113  			continue
   114  		}
   115  		if _, ok := o.(*types.TypeName); !ok {
   116  			continue
   117  		}
   118  		n := o.Type().(*types.Named)
   119  		for i := 0; i < n.NumMethods(); i++ {
   120  			m := n.Method(i)
   121  			if !m.Exported() {
   122  				continue
   123  			}
   124  			s := m.Type().(*types.Signature)
   125  			if num, ok := FuncHasQuery(s); ok {
   126  				methods = append(methods, &QueryMethod{
   127  					Func:     m,
   128  					SSA:      ssa.FuncValue(m),
   129  					ArgCount: s.Params().Len(),
   130  					Param:    num,
   131  				})
   132  			}
   133  		}
   134  	}
   135  	return methods
   136  }
   137  
   138  var stringType types.Type = types.Typ[types.String]
   139  
   140  // FuncHasQuery returns the offset of the string parameter named "query", or
   141  // none if no such parameter exists.
   142  func FuncHasQuery(s *types.Signature) (offset int, ok bool) {
   143  	params := s.Params()
   144  	for i := 0; i < params.Len(); i++ {
   145  		v := params.At(i)
   146  		if v.Name() == "query" && v.Type() == stringType {
   147  			return i, true
   148  		}
   149  	}
   150  	return 0, false
   151  }
   152  
   153  // FindMains returns the set of all packages loaded into the given
   154  // loader.Program which contain main functions
   155  func FindMains(p *loader.Program, s *ssa.Program) []*ssa.Package {
   156  	ips := p.InitialPackages()
   157  	mains := make([]*ssa.Package, 0, len(ips))
   158  	for _, info := range ips {
   159  		ssaPkg := s.Package(info.Pkg)
   160  		if ssaPkg.Func("main") != nil {
   161  			mains = append(mains, ssaPkg)
   162  		}
   163  	}
   164  	return mains
   165  }
   166  
   167  // FindNonConstCalls returns the set of callsites of the given set of methods
   168  // for which the "query" parameter is not a compile-time constant.
   169  func FindNonConstCalls(cg *callgraph.Graph, qms []*QueryMethod) []ssa.CallInstruction {
   170  	cg.DeleteSyntheticNodes()
   171  
   172  	// package database/sql has a couple helper functions which are thin
   173  	// wrappers around other sensitive functions. Instead of handling the
   174  	// general case by tracing down callsites of wrapper functions
   175  	// recursively, let's just whitelist the functions we're already
   176  	// tracking, since it happens to be good enough for our use case.
   177  	okFuncs := make(map[*ssa.Function]struct{}, len(qms))
   178  	for _, m := range qms {
   179  		okFuncs[m.SSA] = struct{}{}
   180  	}
   181  
   182  	bad := make([]ssa.CallInstruction, 0)
   183  	for _, m := range qms {
   184  		node := cg.CreateNode(m.SSA)
   185  		for _, edge := range node.In {
   186  			if _, ok := okFuncs[edge.Site.Parent()]; ok {
   187  				continue
   188  			}
   189  			cc := edge.Site.Common()
   190  			args := cc.Args
   191  			// The first parameter is occasionally the receiver.
   192  			if len(args) == m.ArgCount+1 {
   193  				args = args[1:]
   194  			} else if len(args) != m.ArgCount {
   195  				panic("arg count mismatch")
   196  			}
   197  			v := args[m.Param]
   198  			if _, ok := v.(*ssa.Const); !ok {
   199  				bad = append(bad, edge.Site)
   200  			}
   201  		}
   202  	}
   203  
   204  	return bad
   205  }
   206  
   207  // Deal with GO15VENDOREXPERIMENT
   208  func FindPackage(ctxt *build.Context, path, dir string, mode build.ImportMode) (*build.Package, error) {
   209  	if !useVendor {
   210  		return ctxt.Import(path, dir, mode)
   211  	}
   212  
   213  	// First, walk up the filesystem from dir looking for vendor directories
   214  	var vendorDir string
   215  	for tmp := dir; vendorDir == "" && tmp != "/"; tmp = filepath.Dir(tmp) {
   216  		dname := filepath.Join(tmp, "vendor", filepath.FromSlash(path))
   217  		fd, err := os.Open(dname)
   218  		if err != nil {
   219  			continue
   220  		}
   221  		// Directories are only valid if they contain at least one file
   222  		// with suffix ".go" (this also ensures that the file descriptor
   223  		// we have is in fact a directory)
   224  		names, err := fd.Readdirnames(-1)
   225  		if err != nil {
   226  			continue
   227  		}
   228  		for _, name := range names {
   229  			if strings.HasSuffix(name, ".go") {
   230  				vendorDir = filepath.ToSlash(dname)
   231  				break
   232  			}
   233  		}
   234  	}
   235  
   236  	if vendorDir != "" {
   237  		pkg, err := ctxt.ImportDir(vendorDir, mode)
   238  		if err != nil {
   239  			return nil, err
   240  		}
   241  		// Go tries to derive a valid import path for the package, but
   242  		// it's wrong (it includes "/vendor/"). Overwrite it here.
   243  		pkg.ImportPath = path
   244  		return pkg, nil
   245  	}
   246  
   247  	return ctxt.Import(path, dir, mode)
   248  }