github.com/ninadingole/gotest-ls@v0.0.3/pkg/list.go (about)

     1  // Package pkg contains the core logic of the gotest-ls tool which finds all the go test files.
     2  // and using ast package, it lists all the tests in the given files.
     3  package pkg
     4  
     5  import (
     6  	"fmt"
     7  	"go/ast"
     8  	"go/parser"
     9  	"go/token"
    10  	"io/fs"
    11  	"path/filepath"
    12  	"sort"
    13  	"strings"
    14  )
    15  
    16  // testType represents the type of test function.
    17  type testType int
    18  
    19  const (
    20  	testTypeNone testType = iota
    21  	testTypeSubTest
    22  	testTypeTableTest
    23  )
    24  
    25  // TestDetail is a struct that contains the details of a single test.
    26  // It contains the name of the test, the line number, the file name, the relative path and the absolute path.
    27  // It also contains the token position (token.Pos) of the test in the file.
    28  type TestDetail struct {
    29  	Name         string    `json:"name"`
    30  	FileName     string    `json:"fileName"`
    31  	RelativePath string    `json:"relativePath"`
    32  	AbsolutePath string    `json:"absolutePath"`
    33  	Line         int       `json:"line"`
    34  	Pos          token.Pos `json:"pos"`
    35  }
    36  
    37  // subTestDetail returns the testname and the position of the subtest in the file.
    38  type subTestDetail struct {
    39  	name string
    40  	pos  token.Pos
    41  }
    42  
    43  // List returns all the go test files in the given directories or a given file.
    44  // It returns an error if the given directories are invalid.
    45  // It returns an empty slice if no tests are found.
    46  // The returned slice is sorted by the test name.
    47  func List(fileOrDirs []string) ([]TestDetail, error) {
    48  	files, err := loadFiles(fileOrDirs)
    49  	if err != nil {
    50  		return nil, err
    51  	}
    52  
    53  	tests, err := listTests(files)
    54  	if err != nil {
    55  		return nil, err
    56  	}
    57  
    58  	return tests, nil
    59  }
    60  
    61  // loadFiles loads all the go files in the given paths.
    62  func loadFiles(dirs []string) (map[string][]string, error) {
    63  	testFiles := make(map[string][]string)
    64  
    65  	for _, dir := range dirs {
    66  		err := filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error {
    67  			if err != nil {
    68  				return err
    69  			}
    70  
    71  			if !d.IsDir() && filepath.Ext(path) == ".go" && strings.HasSuffix(path, "_test.go") {
    72  				testFiles[dir] = append(testFiles[dir], path)
    73  			}
    74  
    75  			return nil
    76  		})
    77  		if err != nil {
    78  			return nil, err
    79  		}
    80  	}
    81  
    82  	return testFiles, nil
    83  }
    84  
    85  // listTests lists all the tests in the given go test files.
    86  func listTests(files map[string][]string) ([]TestDetail, error) { //nolint: gocognit
    87  	var tests []TestDetail
    88  
    89  	for dir, testFiles := range files {
    90  		for _, testFile := range testFiles {
    91  			set := token.NewFileSet()
    92  
    93  			parseFile, err := parser.ParseFile(set, testFile, nil, parser.ParseComments)
    94  			if err != nil {
    95  				return nil, err
    96  			}
    97  
    98  			for _, obj := range parseFile.Scope.Objects {
    99  				if obj.Kind == ast.Fun {
   100  					if isGolangTest(obj) {
   101  						isSubTest := false
   102  
   103  						if fnDecl, ok := obj.Decl.(*ast.FuncDecl); ok {
   104  							for i, v := range fnDecl.Body.List {
   105  								switch identifyTestType(v) {
   106  								case testTypeSubTest:
   107  									isSubTest = true
   108  
   109  									if test := findSubTestName(v); test != nil {
   110  										tests = append(tests, buildTestDetail(obj, test.name, dir, testFile, set, test.pos))
   111  									}
   112  
   113  								case testTypeTableTest:
   114  									isSubTest = true
   115  									testNameFieldInStruct := findTableTestNameField(v)
   116  
   117  									if testNameFieldInStruct != "" {
   118  										for j := i; j > 0; j-- {
   119  											if ttDetails := parseTableTestStructsIfAny(fnDecl.Body.List[j], testNameFieldInStruct); ttDetails != nil {
   120  												for _, ttDetail := range ttDetails {
   121  													tests = append(tests, buildTestDetail(obj, ttDetail.name, dir, testFile, set, ttDetail.pos))
   122  												}
   123  											}
   124  										}
   125  									}
   126  								case testTypeNone:
   127  									continue
   128  								}
   129  							}
   130  						}
   131  
   132  						if !isSubTest {
   133  							tests = append(tests, buildTestDetail(obj, "", dir, testFile, set, obj.Pos()))
   134  						}
   135  					}
   136  				}
   137  			}
   138  		}
   139  	}
   140  
   141  	// sort the tests by name
   142  	sort.Slice(tests, func(i, j int) bool {
   143  		return strings.Compare(tests[i].Name, tests[j].Name) < 0
   144  	})
   145  
   146  	return tests, nil
   147  }
   148  
   149  // isGolangTest checks if the function name starts with golang test standards
   150  // it checks for `Test`, `Example` or `Benchmark` prefixes in a function name.
   151  // Other than test functions all the other functions are ignored.
   152  func isGolangTest(obj *ast.Object) bool {
   153  	return strings.HasPrefix(obj.Name, "Test") ||
   154  		strings.HasPrefix(obj.Name, "Example") ||
   155  		strings.HasPrefix(obj.Name, "Benchmark")
   156  }
   157  
   158  // identifyTestType identifies the type of the test based on the given ast node.
   159  // it looks for `t.Run` function in the test function body. If the test contains subtests then it returns
   160  // testTypeSubTest. If the test contains table tests then it returns testTypeTableTest.
   161  // Otherwise, it returns testTypeNone.
   162  func identifyTestType(v ast.Stmt) testType {
   163  	if expr, ok := v.(*ast.ExprStmt); ok {
   164  		if callExpr, ok := expr.X.(*ast.CallExpr); ok {
   165  			if selectorExpr, ok := callExpr.Fun.(*ast.SelectorExpr); ok {
   166  				if selectorExpr.Sel.Name == "Run" {
   167  					return testTypeSubTest
   168  				}
   169  			}
   170  		}
   171  	}
   172  
   173  	if expr, ok := v.(*ast.RangeStmt); ok {
   174  		for _, v := range expr.Body.List {
   175  			if typ := identifyTestType(v); typ == testTypeSubTest {
   176  				return testTypeTableTest
   177  			}
   178  		}
   179  	}
   180  
   181  	return testTypeNone
   182  }
   183  
   184  // findSubTestName finds the name of the subtest in the given ast node.
   185  // it looks for `t.Run` function in the test function body. If the test contains subtests then it returns
   186  // the name of the subtest.
   187  // A test would look like this in the source code.
   188  //
   189  //	func Test_subTestPattern(t *testing.T) {
   190  //		t.Parallel()
   191  //
   192  //		msg := "Hello, world!"
   193  //
   194  //		t.Run("subtest", func(t *testing.T) {
   195  //			t.Parallel()
   196  //			t.Log(msg)
   197  //		})
   198  //
   199  //		t.Run("subtest 2", func(t *testing.T) {
   200  //			t.Parallel()
   201  //			t.Log("This is a subtest")
   202  //		})
   203  //	}
   204  func findSubTestName(v ast.Stmt) *subTestDetail {
   205  	if expr, ok := v.(*ast.ExprStmt); ok {
   206  		if callExpr, ok := expr.X.(*ast.CallExpr); ok {
   207  			if basic, ok := callExpr.Args[0].(*ast.BasicLit); ok {
   208  				return &subTestDetail{
   209  					name: basic.Value,
   210  					pos:  callExpr.Pos(),
   211  				}
   212  			}
   213  		}
   214  	}
   215  
   216  	return nil
   217  }
   218  
   219  // buildTestDetail returns the TestDetail object with the information received from the given parameters.
   220  func buildTestDetail(
   221  	obj *ast.Object,
   222  	name string,
   223  	dir string,
   224  	file string,
   225  	set *token.FileSet,
   226  	pos token.Pos,
   227  ) TestDetail {
   228  	fileAbsPath, err := filepath.Abs(file)
   229  	if err != nil {
   230  		panic(fmt.Errorf("failed to get absolute path of file %s: %w", file, err))
   231  	}
   232  
   233  	fileName := filepath.Base(file)
   234  
   235  	relativePath, err := filepath.Rel(filepath.Dir(dir), file)
   236  	if err != nil {
   237  		panic(fmt.Errorf("failed to get relative path of file %s: %w", file, err))
   238  	}
   239  
   240  	detail := TestDetail{
   241  		Name:         obj.Name,
   242  		FileName:     fileName,
   243  		RelativePath: relativePath,
   244  		AbsolutePath: fileAbsPath,
   245  		Line:         set.Position(pos).Line,
   246  		Pos:          pos,
   247  	}
   248  
   249  	if name != "" {
   250  		detail.Name = fmt.Sprintf("%s/%s", obj.Name,
   251  			strings.ReplaceAll(strings.ReplaceAll(name, "\"", ""), " ", "_"))
   252  	}
   253  
   254  	return detail
   255  }
   256  
   257  // findTableTestNameField returns the name of the field in the table test struct which contains the test name.
   258  // it looks for the field used in `t.Run` inside the for-loop of a table test and returns the name of the parameter
   259  // from the struct that is used to populate the test name.
   260  // A typical table test range function would look like this in the source code.
   261  //
   262  //	for _, tt := range tests {
   263  //			tt := tt
   264  //			t.Run(tt.name, func(t *testing.T) {
   265  //				t.Parallel()
   266  //
   267  //				if got := tt.calc(); got != tt.want {
   268  //					t.Errorf("got %d, want %d", got, tt.want)
   269  //				}
   270  //			})
   271  //		}
   272  func findTableTestNameField(v ast.Stmt) string {
   273  	if rangeStmt, ok := v.(*ast.RangeStmt); ok {
   274  		for _, stmt := range rangeStmt.Body.List {
   275  			if exprStmt, ok := stmt.(*ast.ExprStmt); ok {
   276  				if callExpr, ok := exprStmt.X.(*ast.CallExpr); ok {
   277  					if selectorExpr, ok := callExpr.Fun.(*ast.SelectorExpr); ok {
   278  						if ident, ok := selectorExpr.X.(*ast.Ident); ok {
   279  							if ident.Name == "t" && selectorExpr.Sel.Name == "Run" {
   280  								if sExpr, ok := callExpr.Args[0].(*ast.SelectorExpr); ok {
   281  									return strings.ReplaceAll(sExpr.Sel.Name, "\"", "")
   282  								}
   283  							}
   284  						}
   285  					}
   286  				}
   287  			}
   288  		}
   289  	}
   290  
   291  	return ""
   292  }
   293  
   294  // parseTableTestStructsIfAny parses the struct array in the table test and returns the value of the field that
   295  // will be passed to `t.Run` function when the test is run.
   296  func parseTableTestStructsIfAny(v ast.Stmt, fieldName string) []subTestDetail {
   297  	var values []subTestDetail
   298  
   299  	if assignStmt, ok := v.(*ast.AssignStmt); ok {
   300  		for _, expr := range assignStmt.Rhs {
   301  			if cmpsLit, ok := expr.(*ast.CompositeLit); ok {
   302  				for _, elt := range cmpsLit.Elts {
   303  					if compositeLit, ok := elt.(*ast.CompositeLit); ok {
   304  						for _, elt := range compositeLit.Elts {
   305  							if kvExpr, ok := elt.(*ast.KeyValueExpr); ok {
   306  								if key, ok := kvExpr.Key.(*ast.Ident); ok {
   307  									if key.Name == fieldName {
   308  										if value, ok := kvExpr.Value.(*ast.BasicLit); ok {
   309  											values = append(values,
   310  												subTestDetail{
   311  													name: value.Value,
   312  													pos:  key.Pos(),
   313  												})
   314  										}
   315  									}
   316  								}
   317  							}
   318  						}
   319  					}
   320  				}
   321  			}
   322  		}
   323  	}
   324  
   325  	return values
   326  }