github.com/samiam2013/sqlvet@v0.0.0-20221210043606-d72f678fc0aa/pkg/vet/gosource_test.go (about) 1 package vet_test 2 3 import ( 4 "io/ioutil" 5 "os" 6 "path/filepath" 7 "sort" 8 "testing" 9 10 "github.com/houqp/gtest" 11 log "github.com/sirupsen/logrus" 12 "github.com/stretchr/testify/assert" 13 14 "github.com/samiam2013/sqlvet/pkg/vet" 15 ) 16 17 type GoSourceTmpDir struct{} 18 19 func (s GoSourceTmpDir) Construct(t *testing.T, fixtures struct{}) (string, string) { 20 dir, err := ioutil.TempDir("", "gosource-tmpdir") 21 assert.NoError(t, err) 22 23 modpath := filepath.Join(dir, "go.mod") 24 err = ioutil.WriteFile(modpath, []byte(` 25 module github.com/houqp/sqlvettest 26 `), 0644) 27 assert.NoError(t, err) 28 29 return dir, dir 30 } 31 32 func (s GoSourceTmpDir) Destruct(t *testing.T, dir string) { 33 os.RemoveAll(dir) 34 } 35 36 func init() { 37 gtest.MustRegisterFixture("GoSourceTmpDir", &GoSourceTmpDir{}, gtest.ScopeSubTest) 38 } 39 40 type GoSourceTests struct{} 41 42 func (s *GoSourceTests) Setup(t *testing.T) { 43 log.SetLevel(log.TraceLevel) 44 } 45 func (s *GoSourceTests) Teardown(t *testing.T) {} 46 func (s *GoSourceTests) BeforeEach(t *testing.T) {} 47 func (s *GoSourceTests) AfterEach(t *testing.T) {} 48 49 func (s *GoSourceTests) SubTestInvalidSyntax(t *testing.T, fixtures struct { 50 TmpDir string `fixture:"GoSourceTmpDir"` 51 }) { 52 dir := fixtures.TmpDir 53 54 fpath := filepath.Join(dir, "main.go") 55 err := ioutil.WriteFile(fpath, []byte(` 56 package main 57 58 func main() { 59 return 1 60 } 61 `), 0644) 62 assert.NoError(t, err) 63 64 _, err = vet.CheckDir(vet.VetContext{}, dir, "", nil) 65 assert.Error(t, err) 66 } 67 68 func (s *GoSourceTests) SubTestSkipNoneDbQueryCall(t *testing.T, fixtures struct { 69 TmpDir string `fixture:"GoSourceTmpDir"` 70 }) { 71 dir := fixtures.TmpDir 72 73 source := []byte(` 74 package main 75 76 type Parameter struct {} 77 78 func (Parameter) Query(s string) error { 79 return nil 80 } 81 82 func NewParam() *Parameter { 83 return &Parameter{} 84 } 85 86 func main() { 87 f := Parameter{} 88 f.Query("user_id") 89 90 func() { 91 func() { 92 // access from outside of closure body scope 93 f.Query("user_id") 94 95 // from function call 96 // scoped within closure 97 flocal := NewParam() 98 flocal.Query("book_id") 99 }() 100 }() 101 } 102 `) 103 104 fpath := filepath.Join(dir, "main.go") 105 err := ioutil.WriteFile(fpath, source, 0644) 106 assert.NoError(t, err) 107 108 queries, err := vet.CheckDir(vet.VetContext{}, dir, "", nil) 109 assert.NoError(t, err) 110 assert.Equal(t, 0, len(queries)) 111 } 112 113 func (s *GoSourceTests) SubTestPkgDatabaseSql(t *testing.T, fixtures struct { 114 TmpDir string `fixture:"GoSourceTmpDir"` 115 }) { 116 dir := fixtures.TmpDir 117 118 source := []byte(` 119 package main 120 121 import ( 122 "context" 123 "database/sql" 124 "fmt" 125 ) 126 127 func main() { 128 db, _ := sql.Open("mysql", "user:password@tcp(127.0.0.1:3306)/hello") 129 db.Query("SELECT 1") 130 131 // sqlvet: ignore 132 db.Query("SELECT 2") 133 134 db.Query("SELECT 3") //sqlvet: ignore 135 136 db.Query( 137 "SELECT 4", 138 ) //sqlvet:ignore 139 140 db.Query("SELECT 5") 141 // sqlvet:ignore 142 143 // context aware methods 144 ctx := context.Background() 145 db.QueryRowContext(ctx, "SELECT 5") 146 147 tx, _ := db.Begin() 148 tx.ExecContext(ctx, "SELECT 6") 149 150 // unsafe string 151 var userInput string 152 tx.Query(fmt.Sprintf("SELECT %s", userInput)) 153 154 // string concat 155 tx.Exec("SELECT " + "7") 156 staticUserId := "id" 157 tx.Exec("SELECT " + staticUserId + " FROM foo") 158 } 159 `) 160 161 fpath := filepath.Join(dir, "main.go") 162 err := ioutil.WriteFile(fpath, source, 0644) 163 assert.NoError(t, err) 164 165 queries, err := vet.CheckDir(vet.VetContext{}, dir, "", nil) 166 if err != nil { 167 t.Fatalf("Failed to load package: %s", err.Error()) 168 return 169 } 170 assert.Equal(t, 6, len(queries)) 171 sort.Slice(queries, func(i, j int) bool { 172 return queries[i].Position.Offset < queries[j].Position.Offset 173 }) 174 175 assert.NoError(t, queries[0].Err) 176 assert.Equal(t, "SELECT 1", queries[0].Query) 177 178 assert.NoError(t, queries[1].Err) 179 assert.Equal(t, "SELECT 5", queries[1].Query) 180 181 assert.NoError(t, queries[2].Err) 182 assert.Equal(t, "SELECT 6", queries[2].Query) 183 184 // unsafe string 185 assert.Error(t, queries[3].Err) 186 187 // string concat 188 assert.NoError(t, queries[4].Err) 189 assert.Equal(t, "SELECT 7", queries[4].Query) 190 assert.NoError(t, queries[5].Err) 191 assert.Equal(t, "SELECT id FROM foo", queries[5].Query) 192 } 193 194 // run sqlvet from parent dir 195 func (s *GoSourceTests) SubTestCheckRelativeDir(t *testing.T, fixtures struct { 196 TmpDir string `fixture:"GoSourceTmpDir"` 197 }) { 198 dir := fixtures.TmpDir 199 200 source := []byte(` 201 package main 202 203 func main() { 204 } 205 `) 206 207 fpath := filepath.Join(dir, "main.go") 208 err := ioutil.WriteFile(fpath, source, 0644) 209 assert.NoError(t, err) 210 211 cwd, err := os.Getwd() 212 assert.NoError(t, err) 213 parentDir := filepath.Dir(dir) 214 os.Chdir(parentDir) 215 defer os.Chdir(cwd) 216 217 queries, err := vet.CheckDir(vet.VetContext{}, filepath.Base(dir), "", nil) 218 if err != nil { 219 t.Fatalf("Failed to load package: %s", err.Error()) 220 return 221 } 222 assert.Equal(t, 0, len(queries)) 223 } 224 225 func TestGoSource(t *testing.T) { 226 gtest.RunSubTests(t, &GoSourceTests{}) 227 } 228 229 func (s *GoSourceTests) SubTestQueryParam(t *testing.T, fixtures struct { 230 TmpDir string `fixture:"GoSourceTmpDir"` 231 }) { 232 dir := fixtures.TmpDir 233 234 source := []byte(` 235 package main 236 237 import ( 238 "context" 239 "database/sql" 240 ) 241 242 func main() { 243 db, _ := sql.Open("mysql", "user:password@tcp(127.0.0.1:3306)/hello") 244 245 db.Query("SELECT 2 FROM foo WHERE id=$1", 1) 246 247 db.Exec("UPDATE foo SET id = $1", 10) 248 249 ctx := context.Background() 250 tx, _ := db.Begin() 251 tx.ExecContext(ctx, "INSERT INTO foo (id, value) VALUES ($1, $2)", 1, "hello") 252 253 db.Query("SELECT 2 FROM foo WHERE id=$1 OR value=$1", 1) 254 } 255 `) 256 257 fpath := filepath.Join(dir, "main.go") 258 err := ioutil.WriteFile(fpath, source, 0644) 259 assert.NoError(t, err) 260 261 queries, err := vet.CheckDir(vet.VetContext{}, dir, "", nil) 262 if err != nil { 263 t.Fatalf("Failed to load package: %s", err.Error()) 264 return 265 } 266 assert.Equal(t, 4, len(queries)) 267 sort.Slice(queries, func(i, j int) bool { 268 return queries[i].Position.Offset < queries[j].Position.Offset 269 }) 270 271 assert.NoError(t, queries[0].Err) 272 assert.Equal(t, "SELECT 2 FROM foo WHERE id=$1", queries[0].Query) 273 assert.Equal(t, 1, queries[0].ParameterArgCount) 274 275 assert.NoError(t, queries[1].Err) 276 assert.Equal(t, "UPDATE foo SET id = $1", queries[1].Query) 277 assert.Equal(t, 1, queries[1].ParameterArgCount) 278 279 assert.NoError(t, queries[2].Err) 280 assert.Equal(t, "INSERT INTO foo (id, value) VALUES ($1, $2)", queries[2].Query) 281 assert.Equal(t, 2, queries[2].ParameterArgCount) 282 283 assert.NoError(t, queries[3].Err) 284 assert.Equal(t, "SELECT 2 FROM foo WHERE id=$1 OR value=$1", queries[3].Query) 285 assert.Equal(t, 1, queries[3].ParameterArgCount) 286 } 287 288 func (s *GoSourceTests) SubTestBuildFlags(t *testing.T, fixtures struct { 289 TmpDir string `fixture:"GoSourceTmpDir"` 290 }) { 291 dir := fixtures.TmpDir 292 293 source := []byte(` 294 //+build myBuildTag 295 296 package main 297 298 import ( 299 "fmt" 300 ) 301 302 func main() { 303 fmt.Printf("Hello World\n") 304 } 305 `) 306 307 fpath := filepath.Join(dir, "main.go") 308 err := ioutil.WriteFile(fpath, source, 0644) 309 assert.NoError(t, err) 310 311 _, err = vet.CheckDir(vet.VetContext{}, dir, "", nil) 312 assert.Error(t, err) 313 314 _, err = vet.CheckDir(vet.VetContext{}, dir, "-tags myBuildTag", nil) 315 assert.NoError(t, err) 316 }