github.com/go-toolsmith/pkgload@v1.2.3-0.20240512101226-e704f55998f1/pkgload_test.go (about)

     1  package pkgload
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"go/build"
     7  	"go/token"
     8  	"os"
     9  	"path/filepath"
    10  	"strings"
    11  	"testing"
    12  
    13  	"golang.org/x/tools/go/packages"
    14  )
    15  
    16  func TestVisitUnits(t *testing.T) {
    17  	tests := []struct {
    18  		path string
    19  		desc string
    20  	}{
    21  		{"./testdata/all_included", "Base+Test+ExternalTest+TestBinary"},
    22  		{"./testdata/base_only", "Base"},
    23  		{"./testdata/base_with_ext_tests", "Base+ExternalTest+TestBinary"},
    24  		{"./testdata/base_with_tests", "Base+Test+TestBinary"},
    25  		{"./testdata/empty", ""},
    26  		{"./testdata/main_only", "Base"},
    27  		{"./testdata/main_with_tests", "Base+Test+TestBinary"},
    28  
    29  		{"./testdata/horrors/base-only/blah.test", "Base"},
    30  		{"./testdata/horrors/base-only/blah_test", "Base"},
    31  		{"./testdata/horrors/base-with-tests/blah.test", "Base+Test+TestBinary"},
    32  		{"./testdata/horrors/base-with-tests/blah_test", "Base+Test+TestBinary"},
    33  		{"./testdata/horrors/test_data", "ExternalTest"},
    34  	}
    35  
    36  	checkFields := func(desc string, u *Unit) error {
    37  		for _, key := range strings.Split(desc, "+") {
    38  			switch key {
    39  			case "Base":
    40  				if u.Base == nil {
    41  					return errors.New("Base is missing")
    42  				}
    43  			case "Test":
    44  				if u.Test == nil {
    45  					return errors.New("Test is missing")
    46  				}
    47  			case "ExternalTest":
    48  				if u.ExternalTest == nil {
    49  					return errors.New("ExternalTest is missing")
    50  				}
    51  			case "TestBinary":
    52  				if u.TestBinary == nil {
    53  					return errors.New("TestBinary is missing")
    54  				}
    55  			default:
    56  				panic(fmt.Sprintf("unexpected key: %v", key))
    57  			}
    58  		}
    59  		return nil
    60  	}
    61  
    62  	type testPackage struct {
    63  		filePath string
    64  		pkgPath  string
    65  		desc     string
    66  	}
    67  
    68  	paths := make([]string, len(tests))
    69  	testsMap := make(map[string]testPackage)
    70  	for i, test := range tests {
    71  		pkgPath := "github.com/go-toolsmith/pkgload/" + strings.TrimPrefix(test.path, "./")
    72  		paths[i] = tests[i].path
    73  		absPath, err := filepath.Abs(tests[i].path)
    74  		if err != nil {
    75  			t.Fatalf("get abs path: %v", err)
    76  		}
    77  		testsMap[pkgPath] = testPackage{
    78  			filePath: absPath,
    79  			pkgPath:  pkgPath,
    80  			desc:     test.desc,
    81  		}
    82  	}
    83  
    84  	runWithMode := func(name string, mode packages.LoadMode, fn func(*packages.Config, *testing.T)) {
    85  		t.Run(name, func(t *testing.T) {
    86  			cfg := packages.Config{Mode: mode, Tests: true, Fset: token.NewFileSet()}
    87  			fn(&cfg, t)
    88  		})
    89  	}
    90  	runWithAllModes := func(name string, fn func(*packages.Config, *testing.T)) {
    91  		runWithMode(name+"/Files", packages.LoadFiles, fn)
    92  		runWithMode(name+"/LoadImports", packages.LoadImports, fn)
    93  		runWithMode(name+"/LoadTypes", packages.LoadTypes, fn)
    94  		runWithMode(name+"/LoadSyntax", packages.LoadSyntax, fn)
    95  	}
    96  
    97  	// Check that loading GOROOT packages does not cause
    98  	// VisitUnits to panic.
    99  	runWithAllModes("loadStd", func(cfg *packages.Config, t *testing.T) {
   100  		goroot := build.Default.GOROOT
   101  		wd, err := os.Getwd()
   102  		if err != nil {
   103  			t.Skipf("can't get wd: %v", err)
   104  		}
   105  		defer func(prev string) {
   106  			if err := os.Chdir(prev); err != nil {
   107  				panic(fmt.Sprintf("can't go back: %v", err))
   108  			}
   109  		}(wd)
   110  		if err := os.Chdir(goroot); err != nil {
   111  			t.Skipf("chdir: %v", err)
   112  		}
   113  		pkgs, err := packages.Load(cfg, "std")
   114  		if err != nil {
   115  			t.Fatalf("load packages: %v", err)
   116  		}
   117  		VisitUnits(pkgs, func(u *Unit) {})
   118  	})
   119  
   120  	runWithAllModes("loadAll", func(cfg *packages.Config, t *testing.T) {
   121  		pkgs, err := packages.Load(cfg, paths...)
   122  		if err != nil {
   123  			t.Fatalf("load packages: %v", err)
   124  		}
   125  		remains := len(testsMap) - 1 // Substract the empty unit
   126  		VisitUnits(pkgs, func(u *Unit) {
   127  			p, ok := testsMap[u.NonNil().PkgPath]
   128  			if !ok {
   129  				t.Fatalf("unmatched pkg path %q", u.NonNil().PkgPath)
   130  			}
   131  			remains--
   132  			if err := checkFields(p.desc, u); err != nil {
   133  				t.Errorf("%q: check %q: %v",
   134  					u.NonNil().PkgPath, p.desc, err)
   135  			}
   136  		})
   137  		if remains != 0 {
   138  			t.Errorf("unprocessed units: %d", remains)
   139  		}
   140  	})
   141  
   142  	runWithAllModes("loadOneByOne", func(cfg *packages.Config, t *testing.T) {
   143  		for _, path := range paths {
   144  			pkgs, err := packages.Load(cfg, path)
   145  			if err != nil {
   146  				t.Fatalf("load packages: %v", err)
   147  			}
   148  			VisitUnits(pkgs, func(u *Unit) {
   149  				p, ok := testsMap[u.NonNil().PkgPath]
   150  				if !ok {
   151  					t.Fatalf("unmatched pkg path %q", u.NonNil().PkgPath)
   152  				}
   153  				if err := checkFields(p.desc, u); err != nil {
   154  					t.Errorf("%q: check %q: %v",
   155  						u.NonNil().PkgPath, p.desc, err)
   156  				}
   157  			})
   158  		}
   159  	})
   160  }