github.com/jd-ly/tools@v0.5.7/internal/apidiff/apidiff_test.go (about) 1 package apidiff 2 3 import ( 4 "bufio" 5 "fmt" 6 "go/types" 7 "io/ioutil" 8 "os" 9 "path/filepath" 10 "reflect" 11 "sort" 12 "strings" 13 "testing" 14 15 "github.com/jd-ly/tools/go/packages" 16 "github.com/jd-ly/tools/internal/testenv" 17 ) 18 19 func TestChanges(t *testing.T) { 20 dir, err := ioutil.TempDir("", "apidiff_test") 21 if err != nil { 22 t.Fatal(err) 23 } 24 dir = filepath.Join(dir, "go") 25 wanti, wantc := splitIntoPackages(t, dir) 26 defer os.RemoveAll(dir) 27 sort.Strings(wanti) 28 sort.Strings(wantc) 29 30 oldpkg, err := load(t, "apidiff/old", dir) 31 if err != nil { 32 t.Fatal(err) 33 } 34 newpkg, err := load(t, "apidiff/new", dir) 35 if err != nil { 36 t.Fatal(err) 37 } 38 39 report := Changes(oldpkg.Types, newpkg.Types) 40 41 got := report.messages(false) 42 if !reflect.DeepEqual(got, wanti) { 43 t.Errorf("incompatibles: got %v\nwant %v\n", got, wanti) 44 } 45 got = report.messages(true) 46 if !reflect.DeepEqual(got, wantc) { 47 t.Errorf("compatibles: got %v\nwant %v\n", got, wantc) 48 } 49 } 50 51 func splitIntoPackages(t *testing.T, dir string) (incompatibles, compatibles []string) { 52 // Read the input file line by line. 53 // Write a line into the old or new package, 54 // dependent on comments. 55 // Also collect expected messages. 56 f, err := os.Open("testdata/tests.go") 57 if err != nil { 58 t.Fatal(err) 59 } 60 defer f.Close() 61 62 if err := os.MkdirAll(filepath.Join(dir, "src", "apidiff"), 0700); err != nil { 63 t.Fatal(err) 64 } 65 if err := ioutil.WriteFile(filepath.Join(dir, "src", "apidiff", "go.mod"), []byte("module apidiff\n"), 0666); err != nil { 66 t.Fatal(err) 67 } 68 69 oldd := filepath.Join(dir, "src/apidiff/old") 70 newd := filepath.Join(dir, "src/apidiff/new") 71 if err := os.MkdirAll(oldd, 0700); err != nil { 72 t.Fatal(err) 73 } 74 if err := os.Mkdir(newd, 0700); err != nil && !os.IsExist(err) { 75 t.Fatal(err) 76 } 77 78 oldf, err := os.Create(filepath.Join(oldd, "old.go")) 79 if err != nil { 80 t.Fatal(err) 81 } 82 newf, err := os.Create(filepath.Join(newd, "new.go")) 83 if err != nil { 84 t.Fatal(err) 85 } 86 87 wl := func(f *os.File, line string) { 88 if _, err := fmt.Fprintln(f, line); err != nil { 89 t.Fatal(err) 90 } 91 } 92 writeBoth := func(line string) { wl(oldf, line); wl(newf, line) } 93 writeln := writeBoth 94 s := bufio.NewScanner(f) 95 for s.Scan() { 96 line := s.Text() 97 tl := strings.TrimSpace(line) 98 switch { 99 case tl == "// old": 100 writeln = func(line string) { wl(oldf, line) } 101 case tl == "// new": 102 writeln = func(line string) { wl(newf, line) } 103 case tl == "// both": 104 writeln = writeBoth 105 case strings.HasPrefix(tl, "// i "): 106 incompatibles = append(incompatibles, strings.TrimSpace(tl[4:])) 107 case strings.HasPrefix(tl, "// c "): 108 compatibles = append(compatibles, strings.TrimSpace(tl[4:])) 109 default: 110 writeln(line) 111 } 112 } 113 if s.Err() != nil { 114 t.Fatal(s.Err()) 115 } 116 return 117 } 118 119 func load(t *testing.T, importPath, goPath string) (*packages.Package, error) { 120 testenv.NeedsGoPackages(t) 121 122 cfg := &packages.Config{ 123 Mode: packages.LoadTypes, 124 } 125 if goPath != "" { 126 cfg.Env = append(os.Environ(), "GOPATH="+goPath) 127 cfg.Dir = filepath.Join(goPath, "src", filepath.FromSlash(importPath)) 128 } 129 pkgs, err := packages.Load(cfg, importPath) 130 if err != nil { 131 return nil, err 132 } 133 if len(pkgs[0].Errors) > 0 { 134 return nil, pkgs[0].Errors[0] 135 } 136 return pkgs[0], nil 137 } 138 139 func TestExportedFields(t *testing.T) { 140 pkg, err := load(t, "github.com/jd-ly/tools/internal/apidiff/testdata/exported_fields", "") 141 if err != nil { 142 t.Fatal(err) 143 } 144 typeof := func(name string) types.Type { 145 return pkg.Types.Scope().Lookup(name).Type() 146 } 147 148 s := typeof("S") 149 su := s.(*types.Named).Underlying().(*types.Struct) 150 151 ef := exportedSelectableFields(su) 152 wants := []struct { 153 name string 154 typ types.Type 155 }{ 156 {"A1", typeof("A1")}, 157 {"D", types.Typ[types.Bool]}, 158 {"E", types.Typ[types.Int]}, 159 {"F", typeof("F")}, 160 {"S", types.NewPointer(s)}, 161 } 162 163 if got, want := len(ef), len(wants); got != want { 164 t.Errorf("got %d fields, want %d\n%+v", got, want, ef) 165 } 166 for _, w := range wants { 167 if got := ef[w.name]; got != nil && !types.Identical(got.Type(), w.typ) { 168 t.Errorf("%s: got %v, want %v", w.name, got.Type(), w.typ) 169 } 170 } 171 }