github.com/ninadingole/gotest-ls@v0.0.3/pkg/list.go (about) 1 // Package pkg contains the core logic of the gotest-ls tool which finds all the go test files. 2 // and using ast package, it lists all the tests in the given files. 3 package pkg 4 5 import ( 6 "fmt" 7 "go/ast" 8 "go/parser" 9 "go/token" 10 "io/fs" 11 "path/filepath" 12 "sort" 13 "strings" 14 ) 15 16 // testType represents the type of test function. 17 type testType int 18 19 const ( 20 testTypeNone testType = iota 21 testTypeSubTest 22 testTypeTableTest 23 ) 24 25 // TestDetail is a struct that contains the details of a single test. 26 // It contains the name of the test, the line number, the file name, the relative path and the absolute path. 27 // It also contains the token position (token.Pos) of the test in the file. 28 type TestDetail struct { 29 Name string `json:"name"` 30 FileName string `json:"fileName"` 31 RelativePath string `json:"relativePath"` 32 AbsolutePath string `json:"absolutePath"` 33 Line int `json:"line"` 34 Pos token.Pos `json:"pos"` 35 } 36 37 // subTestDetail returns the testname and the position of the subtest in the file. 38 type subTestDetail struct { 39 name string 40 pos token.Pos 41 } 42 43 // List returns all the go test files in the given directories or a given file. 44 // It returns an error if the given directories are invalid. 45 // It returns an empty slice if no tests are found. 46 // The returned slice is sorted by the test name. 47 func List(fileOrDirs []string) ([]TestDetail, error) { 48 files, err := loadFiles(fileOrDirs) 49 if err != nil { 50 return nil, err 51 } 52 53 tests, err := listTests(files) 54 if err != nil { 55 return nil, err 56 } 57 58 return tests, nil 59 } 60 61 // loadFiles loads all the go files in the given paths. 62 func loadFiles(dirs []string) (map[string][]string, error) { 63 testFiles := make(map[string][]string) 64 65 for _, dir := range dirs { 66 err := filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { 67 if err != nil { 68 return err 69 } 70 71 if !d.IsDir() && filepath.Ext(path) == ".go" && strings.HasSuffix(path, "_test.go") { 72 testFiles[dir] = append(testFiles[dir], path) 73 } 74 75 return nil 76 }) 77 if err != nil { 78 return nil, err 79 } 80 } 81 82 return testFiles, nil 83 } 84 85 // listTests lists all the tests in the given go test files. 86 func listTests(files map[string][]string) ([]TestDetail, error) { //nolint: gocognit 87 var tests []TestDetail 88 89 for dir, testFiles := range files { 90 for _, testFile := range testFiles { 91 set := token.NewFileSet() 92 93 parseFile, err := parser.ParseFile(set, testFile, nil, parser.ParseComments) 94 if err != nil { 95 return nil, err 96 } 97 98 for _, obj := range parseFile.Scope.Objects { 99 if obj.Kind == ast.Fun { 100 if isGolangTest(obj) { 101 isSubTest := false 102 103 if fnDecl, ok := obj.Decl.(*ast.FuncDecl); ok { 104 for i, v := range fnDecl.Body.List { 105 switch identifyTestType(v) { 106 case testTypeSubTest: 107 isSubTest = true 108 109 if test := findSubTestName(v); test != nil { 110 tests = append(tests, buildTestDetail(obj, test.name, dir, testFile, set, test.pos)) 111 } 112 113 case testTypeTableTest: 114 isSubTest = true 115 testNameFieldInStruct := findTableTestNameField(v) 116 117 if testNameFieldInStruct != "" { 118 for j := i; j > 0; j-- { 119 if ttDetails := parseTableTestStructsIfAny(fnDecl.Body.List[j], testNameFieldInStruct); ttDetails != nil { 120 for _, ttDetail := range ttDetails { 121 tests = append(tests, buildTestDetail(obj, ttDetail.name, dir, testFile, set, ttDetail.pos)) 122 } 123 } 124 } 125 } 126 case testTypeNone: 127 continue 128 } 129 } 130 } 131 132 if !isSubTest { 133 tests = append(tests, buildTestDetail(obj, "", dir, testFile, set, obj.Pos())) 134 } 135 } 136 } 137 } 138 } 139 } 140 141 // sort the tests by name 142 sort.Slice(tests, func(i, j int) bool { 143 return strings.Compare(tests[i].Name, tests[j].Name) < 0 144 }) 145 146 return tests, nil 147 } 148 149 // isGolangTest checks if the function name starts with golang test standards 150 // it checks for `Test`, `Example` or `Benchmark` prefixes in a function name. 151 // Other than test functions all the other functions are ignored. 152 func isGolangTest(obj *ast.Object) bool { 153 return strings.HasPrefix(obj.Name, "Test") || 154 strings.HasPrefix(obj.Name, "Example") || 155 strings.HasPrefix(obj.Name, "Benchmark") 156 } 157 158 // identifyTestType identifies the type of the test based on the given ast node. 159 // it looks for `t.Run` function in the test function body. If the test contains subtests then it returns 160 // testTypeSubTest. If the test contains table tests then it returns testTypeTableTest. 161 // Otherwise, it returns testTypeNone. 162 func identifyTestType(v ast.Stmt) testType { 163 if expr, ok := v.(*ast.ExprStmt); ok { 164 if callExpr, ok := expr.X.(*ast.CallExpr); ok { 165 if selectorExpr, ok := callExpr.Fun.(*ast.SelectorExpr); ok { 166 if selectorExpr.Sel.Name == "Run" { 167 return testTypeSubTest 168 } 169 } 170 } 171 } 172 173 if expr, ok := v.(*ast.RangeStmt); ok { 174 for _, v := range expr.Body.List { 175 if typ := identifyTestType(v); typ == testTypeSubTest { 176 return testTypeTableTest 177 } 178 } 179 } 180 181 return testTypeNone 182 } 183 184 // findSubTestName finds the name of the subtest in the given ast node. 185 // it looks for `t.Run` function in the test function body. If the test contains subtests then it returns 186 // the name of the subtest. 187 // A test would look like this in the source code. 188 // 189 // func Test_subTestPattern(t *testing.T) { 190 // t.Parallel() 191 // 192 // msg := "Hello, world!" 193 // 194 // t.Run("subtest", func(t *testing.T) { 195 // t.Parallel() 196 // t.Log(msg) 197 // }) 198 // 199 // t.Run("subtest 2", func(t *testing.T) { 200 // t.Parallel() 201 // t.Log("This is a subtest") 202 // }) 203 // } 204 func findSubTestName(v ast.Stmt) *subTestDetail { 205 if expr, ok := v.(*ast.ExprStmt); ok { 206 if callExpr, ok := expr.X.(*ast.CallExpr); ok { 207 if basic, ok := callExpr.Args[0].(*ast.BasicLit); ok { 208 return &subTestDetail{ 209 name: basic.Value, 210 pos: callExpr.Pos(), 211 } 212 } 213 } 214 } 215 216 return nil 217 } 218 219 // buildTestDetail returns the TestDetail object with the information received from the given parameters. 220 func buildTestDetail( 221 obj *ast.Object, 222 name string, 223 dir string, 224 file string, 225 set *token.FileSet, 226 pos token.Pos, 227 ) TestDetail { 228 fileAbsPath, err := filepath.Abs(file) 229 if err != nil { 230 panic(fmt.Errorf("failed to get absolute path of file %s: %w", file, err)) 231 } 232 233 fileName := filepath.Base(file) 234 235 relativePath, err := filepath.Rel(filepath.Dir(dir), file) 236 if err != nil { 237 panic(fmt.Errorf("failed to get relative path of file %s: %w", file, err)) 238 } 239 240 detail := TestDetail{ 241 Name: obj.Name, 242 FileName: fileName, 243 RelativePath: relativePath, 244 AbsolutePath: fileAbsPath, 245 Line: set.Position(pos).Line, 246 Pos: pos, 247 } 248 249 if name != "" { 250 detail.Name = fmt.Sprintf("%s/%s", obj.Name, 251 strings.ReplaceAll(strings.ReplaceAll(name, "\"", ""), " ", "_")) 252 } 253 254 return detail 255 } 256 257 // findTableTestNameField returns the name of the field in the table test struct which contains the test name. 258 // it looks for the field used in `t.Run` inside the for-loop of a table test and returns the name of the parameter 259 // from the struct that is used to populate the test name. 260 // A typical table test range function would look like this in the source code. 261 // 262 // for _, tt := range tests { 263 // tt := tt 264 // t.Run(tt.name, func(t *testing.T) { 265 // t.Parallel() 266 // 267 // if got := tt.calc(); got != tt.want { 268 // t.Errorf("got %d, want %d", got, tt.want) 269 // } 270 // }) 271 // } 272 func findTableTestNameField(v ast.Stmt) string { 273 if rangeStmt, ok := v.(*ast.RangeStmt); ok { 274 for _, stmt := range rangeStmt.Body.List { 275 if exprStmt, ok := stmt.(*ast.ExprStmt); ok { 276 if callExpr, ok := exprStmt.X.(*ast.CallExpr); ok { 277 if selectorExpr, ok := callExpr.Fun.(*ast.SelectorExpr); ok { 278 if ident, ok := selectorExpr.X.(*ast.Ident); ok { 279 if ident.Name == "t" && selectorExpr.Sel.Name == "Run" { 280 if sExpr, ok := callExpr.Args[0].(*ast.SelectorExpr); ok { 281 return strings.ReplaceAll(sExpr.Sel.Name, "\"", "") 282 } 283 } 284 } 285 } 286 } 287 } 288 } 289 } 290 291 return "" 292 } 293 294 // parseTableTestStructsIfAny parses the struct array in the table test and returns the value of the field that 295 // will be passed to `t.Run` function when the test is run. 296 func parseTableTestStructsIfAny(v ast.Stmt, fieldName string) []subTestDetail { 297 var values []subTestDetail 298 299 if assignStmt, ok := v.(*ast.AssignStmt); ok { 300 for _, expr := range assignStmt.Rhs { 301 if cmpsLit, ok := expr.(*ast.CompositeLit); ok { 302 for _, elt := range cmpsLit.Elts { 303 if compositeLit, ok := elt.(*ast.CompositeLit); ok { 304 for _, elt := range compositeLit.Elts { 305 if kvExpr, ok := elt.(*ast.KeyValueExpr); ok { 306 if key, ok := kvExpr.Key.(*ast.Ident); ok { 307 if key.Name == fieldName { 308 if value, ok := kvExpr.Value.(*ast.BasicLit); ok { 309 values = append(values, 310 subTestDetail{ 311 name: value.Value, 312 pos: key.Pos(), 313 }) 314 } 315 } 316 } 317 } 318 } 319 } 320 } 321 } 322 } 323 } 324 325 return values 326 }