github.com/samiam2013/sqlvet@v0.0.0-20221210043606-d72f678fc0aa/pkg/vet/gosource_test.go (about)

     1  package vet_test
     2  
     3  import (
     4  	"io/ioutil"
     5  	"os"
     6  	"path/filepath"
     7  	"sort"
     8  	"testing"
     9  
    10  	"github.com/houqp/gtest"
    11  	log "github.com/sirupsen/logrus"
    12  	"github.com/stretchr/testify/assert"
    13  
    14  	"github.com/samiam2013/sqlvet/pkg/vet"
    15  )
    16  
    17  type GoSourceTmpDir struct{}
    18  
    19  func (s GoSourceTmpDir) Construct(t *testing.T, fixtures struct{}) (string, string) {
    20  	dir, err := ioutil.TempDir("", "gosource-tmpdir")
    21  	assert.NoError(t, err)
    22  
    23  	modpath := filepath.Join(dir, "go.mod")
    24  	err = ioutil.WriteFile(modpath, []byte(`
    25  module github.com/houqp/sqlvettest
    26  	`), 0644)
    27  	assert.NoError(t, err)
    28  
    29  	return dir, dir
    30  }
    31  
    32  func (s GoSourceTmpDir) Destruct(t *testing.T, dir string) {
    33  	os.RemoveAll(dir)
    34  }
    35  
    36  func init() {
    37  	gtest.MustRegisterFixture("GoSourceTmpDir", &GoSourceTmpDir{}, gtest.ScopeSubTest)
    38  }
    39  
    40  type GoSourceTests struct{}
    41  
    42  func (s *GoSourceTests) Setup(t *testing.T) {
    43  	log.SetLevel(log.TraceLevel)
    44  }
    45  func (s *GoSourceTests) Teardown(t *testing.T)   {}
    46  func (s *GoSourceTests) BeforeEach(t *testing.T) {}
    47  func (s *GoSourceTests) AfterEach(t *testing.T)  {}
    48  
    49  func (s *GoSourceTests) SubTestInvalidSyntax(t *testing.T, fixtures struct {
    50  	TmpDir string `fixture:"GoSourceTmpDir"`
    51  }) {
    52  	dir := fixtures.TmpDir
    53  
    54  	fpath := filepath.Join(dir, "main.go")
    55  	err := ioutil.WriteFile(fpath, []byte(`
    56  package main
    57  
    58  func main() {
    59  	return 1
    60  }
    61  `), 0644)
    62  	assert.NoError(t, err)
    63  
    64  	_, err = vet.CheckDir(vet.VetContext{}, dir, "", nil)
    65  	assert.Error(t, err)
    66  }
    67  
    68  func (s *GoSourceTests) SubTestSkipNoneDbQueryCall(t *testing.T, fixtures struct {
    69  	TmpDir string `fixture:"GoSourceTmpDir"`
    70  }) {
    71  	dir := fixtures.TmpDir
    72  
    73  	source := []byte(`
    74  package main
    75  
    76  type Parameter struct {}
    77  
    78  func (Parameter) Query(s string) error {
    79  	return nil
    80  }
    81  
    82  func NewParam() *Parameter {
    83  	return &Parameter{}
    84  }
    85  
    86  func main() {
    87  	f := Parameter{}
    88  	f.Query("user_id")
    89  
    90  	func() {
    91  		func() {
    92  			// access from outside of closure body scope
    93  			f.Query("user_id")
    94  
    95  			// from function call
    96  			// scoped within closure
    97  			flocal := NewParam()
    98  			flocal.Query("book_id")
    99  		}()
   100  	}()
   101  }
   102  	`)
   103  
   104  	fpath := filepath.Join(dir, "main.go")
   105  	err := ioutil.WriteFile(fpath, source, 0644)
   106  	assert.NoError(t, err)
   107  
   108  	queries, err := vet.CheckDir(vet.VetContext{}, dir, "", nil)
   109  	assert.NoError(t, err)
   110  	assert.Equal(t, 0, len(queries))
   111  }
   112  
   113  func (s *GoSourceTests) SubTestPkgDatabaseSql(t *testing.T, fixtures struct {
   114  	TmpDir string `fixture:"GoSourceTmpDir"`
   115  }) {
   116  	dir := fixtures.TmpDir
   117  
   118  	source := []byte(`
   119  package main
   120  
   121  import (
   122  	"context"
   123  	"database/sql"
   124  	"fmt"
   125  )
   126  
   127  func main() {
   128  	db, _ := sql.Open("mysql", "user:password@tcp(127.0.0.1:3306)/hello")
   129  	db.Query("SELECT 1")
   130  
   131  	// sqlvet: ignore
   132  	db.Query("SELECT 2")
   133  
   134  	db.Query("SELECT 3") //sqlvet: ignore
   135  
   136  	db.Query(
   137  		"SELECT 4",
   138  	) //sqlvet:ignore
   139  
   140  	db.Query("SELECT 5")
   141  	//   sqlvet:ignore
   142  
   143  	// context aware methods
   144  	ctx := context.Background()
   145  	db.QueryRowContext(ctx, "SELECT 5")
   146  
   147  	tx, _ := db.Begin()
   148  	tx.ExecContext(ctx, "SELECT 6")
   149  
   150  	// unsafe string
   151  	var userInput string
   152  	tx.Query(fmt.Sprintf("SELECT %s", userInput))
   153  
   154  	// string concat
   155  	tx.Exec("SELECT " + "7")
   156  	staticUserId := "id"
   157  	tx.Exec("SELECT " + staticUserId + " FROM foo")
   158  }
   159  	`)
   160  
   161  	fpath := filepath.Join(dir, "main.go")
   162  	err := ioutil.WriteFile(fpath, source, 0644)
   163  	assert.NoError(t, err)
   164  
   165  	queries, err := vet.CheckDir(vet.VetContext{}, dir, "", nil)
   166  	if err != nil {
   167  		t.Fatalf("Failed to load package: %s", err.Error())
   168  		return
   169  	}
   170  	assert.Equal(t, 6, len(queries))
   171  	sort.Slice(queries, func(i, j int) bool {
   172  		return queries[i].Position.Offset < queries[j].Position.Offset
   173  	})
   174  
   175  	assert.NoError(t, queries[0].Err)
   176  	assert.Equal(t, "SELECT 1", queries[0].Query)
   177  
   178  	assert.NoError(t, queries[1].Err)
   179  	assert.Equal(t, "SELECT 5", queries[1].Query)
   180  
   181  	assert.NoError(t, queries[2].Err)
   182  	assert.Equal(t, "SELECT 6", queries[2].Query)
   183  
   184  	// unsafe string
   185  	assert.Error(t, queries[3].Err)
   186  
   187  	// string concat
   188  	assert.NoError(t, queries[4].Err)
   189  	assert.Equal(t, "SELECT 7", queries[4].Query)
   190  	assert.NoError(t, queries[5].Err)
   191  	assert.Equal(t, "SELECT id FROM foo", queries[5].Query)
   192  }
   193  
   194  // run sqlvet from parent dir
   195  func (s *GoSourceTests) SubTestCheckRelativeDir(t *testing.T, fixtures struct {
   196  	TmpDir string `fixture:"GoSourceTmpDir"`
   197  }) {
   198  	dir := fixtures.TmpDir
   199  
   200  	source := []byte(`
   201  package main
   202  
   203  func main() {
   204  }
   205  	`)
   206  
   207  	fpath := filepath.Join(dir, "main.go")
   208  	err := ioutil.WriteFile(fpath, source, 0644)
   209  	assert.NoError(t, err)
   210  
   211  	cwd, err := os.Getwd()
   212  	assert.NoError(t, err)
   213  	parentDir := filepath.Dir(dir)
   214  	os.Chdir(parentDir)
   215  	defer os.Chdir(cwd)
   216  
   217  	queries, err := vet.CheckDir(vet.VetContext{}, filepath.Base(dir), "", nil)
   218  	if err != nil {
   219  		t.Fatalf("Failed to load package: %s", err.Error())
   220  		return
   221  	}
   222  	assert.Equal(t, 0, len(queries))
   223  }
   224  
   225  func TestGoSource(t *testing.T) {
   226  	gtest.RunSubTests(t, &GoSourceTests{})
   227  }
   228  
   229  func (s *GoSourceTests) SubTestQueryParam(t *testing.T, fixtures struct {
   230  	TmpDir string `fixture:"GoSourceTmpDir"`
   231  }) {
   232  	dir := fixtures.TmpDir
   233  
   234  	source := []byte(`
   235  package main
   236  
   237  import (
   238  	"context"
   239  	"database/sql"
   240  )
   241  
   242  func main() {
   243  	db, _ := sql.Open("mysql", "user:password@tcp(127.0.0.1:3306)/hello")
   244  
   245  	db.Query("SELECT 2 FROM foo WHERE id=$1", 1)
   246  
   247  	db.Exec("UPDATE foo SET id = $1", 10)
   248  
   249  	ctx := context.Background()
   250  	tx, _ := db.Begin()
   251  	tx.ExecContext(ctx, "INSERT INTO foo (id, value) VALUES ($1, $2)", 1, "hello")
   252  
   253  	db.Query("SELECT 2 FROM foo WHERE id=$1 OR value=$1", 1)
   254  }
   255  	`)
   256  
   257  	fpath := filepath.Join(dir, "main.go")
   258  	err := ioutil.WriteFile(fpath, source, 0644)
   259  	assert.NoError(t, err)
   260  
   261  	queries, err := vet.CheckDir(vet.VetContext{}, dir, "", nil)
   262  	if err != nil {
   263  		t.Fatalf("Failed to load package: %s", err.Error())
   264  		return
   265  	}
   266  	assert.Equal(t, 4, len(queries))
   267  	sort.Slice(queries, func(i, j int) bool {
   268  		return queries[i].Position.Offset < queries[j].Position.Offset
   269  	})
   270  
   271  	assert.NoError(t, queries[0].Err)
   272  	assert.Equal(t, "SELECT 2 FROM foo WHERE id=$1", queries[0].Query)
   273  	assert.Equal(t, 1, queries[0].ParameterArgCount)
   274  
   275  	assert.NoError(t, queries[1].Err)
   276  	assert.Equal(t, "UPDATE foo SET id = $1", queries[1].Query)
   277  	assert.Equal(t, 1, queries[1].ParameterArgCount)
   278  
   279  	assert.NoError(t, queries[2].Err)
   280  	assert.Equal(t, "INSERT INTO foo (id, value) VALUES ($1, $2)", queries[2].Query)
   281  	assert.Equal(t, 2, queries[2].ParameterArgCount)
   282  
   283  	assert.NoError(t, queries[3].Err)
   284  	assert.Equal(t, "SELECT 2 FROM foo WHERE id=$1 OR value=$1", queries[3].Query)
   285  	assert.Equal(t, 1, queries[3].ParameterArgCount)
   286  }
   287  
   288  func (s *GoSourceTests) SubTestBuildFlags(t *testing.T, fixtures struct {
   289  	TmpDir string `fixture:"GoSourceTmpDir"`
   290  }) {
   291  	dir := fixtures.TmpDir
   292  
   293  	source := []byte(`
   294  //+build myBuildTag
   295  
   296  package main
   297  
   298  import (
   299  	"fmt"
   300  )
   301  
   302  func main() {
   303  	fmt.Printf("Hello World\n")
   304  }
   305  `)
   306  
   307  	fpath := filepath.Join(dir, "main.go")
   308  	err := ioutil.WriteFile(fpath, source, 0644)
   309  	assert.NoError(t, err)
   310  
   311  	_, err = vet.CheckDir(vet.VetContext{}, dir, "", nil)
   312  	assert.Error(t, err)
   313  
   314  	_, err = vet.CheckDir(vet.VetContext{}, dir, "-tags myBuildTag", nil)
   315  	assert.NoError(t, err)
   316  }