github.com/jd-ly/tools@v0.5.7/internal/apidiff/apidiff_test.go (about)

     1  package apidiff
     2  
     3  import (
     4  	"bufio"
     5  	"fmt"
     6  	"go/types"
     7  	"io/ioutil"
     8  	"os"
     9  	"path/filepath"
    10  	"reflect"
    11  	"sort"
    12  	"strings"
    13  	"testing"
    14  
    15  	"github.com/jd-ly/tools/go/packages"
    16  	"github.com/jd-ly/tools/internal/testenv"
    17  )
    18  
    19  func TestChanges(t *testing.T) {
    20  	dir, err := ioutil.TempDir("", "apidiff_test")
    21  	if err != nil {
    22  		t.Fatal(err)
    23  	}
    24  	dir = filepath.Join(dir, "go")
    25  	wanti, wantc := splitIntoPackages(t, dir)
    26  	defer os.RemoveAll(dir)
    27  	sort.Strings(wanti)
    28  	sort.Strings(wantc)
    29  
    30  	oldpkg, err := load(t, "apidiff/old", dir)
    31  	if err != nil {
    32  		t.Fatal(err)
    33  	}
    34  	newpkg, err := load(t, "apidiff/new", dir)
    35  	if err != nil {
    36  		t.Fatal(err)
    37  	}
    38  
    39  	report := Changes(oldpkg.Types, newpkg.Types)
    40  
    41  	got := report.messages(false)
    42  	if !reflect.DeepEqual(got, wanti) {
    43  		t.Errorf("incompatibles: got %v\nwant %v\n", got, wanti)
    44  	}
    45  	got = report.messages(true)
    46  	if !reflect.DeepEqual(got, wantc) {
    47  		t.Errorf("compatibles: got %v\nwant %v\n", got, wantc)
    48  	}
    49  }
    50  
    51  func splitIntoPackages(t *testing.T, dir string) (incompatibles, compatibles []string) {
    52  	// Read the input file line by line.
    53  	// Write a line into the old or new package,
    54  	// dependent on comments.
    55  	// Also collect expected messages.
    56  	f, err := os.Open("testdata/tests.go")
    57  	if err != nil {
    58  		t.Fatal(err)
    59  	}
    60  	defer f.Close()
    61  
    62  	if err := os.MkdirAll(filepath.Join(dir, "src", "apidiff"), 0700); err != nil {
    63  		t.Fatal(err)
    64  	}
    65  	if err := ioutil.WriteFile(filepath.Join(dir, "src", "apidiff", "go.mod"), []byte("module apidiff\n"), 0666); err != nil {
    66  		t.Fatal(err)
    67  	}
    68  
    69  	oldd := filepath.Join(dir, "src/apidiff/old")
    70  	newd := filepath.Join(dir, "src/apidiff/new")
    71  	if err := os.MkdirAll(oldd, 0700); err != nil {
    72  		t.Fatal(err)
    73  	}
    74  	if err := os.Mkdir(newd, 0700); err != nil && !os.IsExist(err) {
    75  		t.Fatal(err)
    76  	}
    77  
    78  	oldf, err := os.Create(filepath.Join(oldd, "old.go"))
    79  	if err != nil {
    80  		t.Fatal(err)
    81  	}
    82  	newf, err := os.Create(filepath.Join(newd, "new.go"))
    83  	if err != nil {
    84  		t.Fatal(err)
    85  	}
    86  
    87  	wl := func(f *os.File, line string) {
    88  		if _, err := fmt.Fprintln(f, line); err != nil {
    89  			t.Fatal(err)
    90  		}
    91  	}
    92  	writeBoth := func(line string) { wl(oldf, line); wl(newf, line) }
    93  	writeln := writeBoth
    94  	s := bufio.NewScanner(f)
    95  	for s.Scan() {
    96  		line := s.Text()
    97  		tl := strings.TrimSpace(line)
    98  		switch {
    99  		case tl == "// old":
   100  			writeln = func(line string) { wl(oldf, line) }
   101  		case tl == "// new":
   102  			writeln = func(line string) { wl(newf, line) }
   103  		case tl == "// both":
   104  			writeln = writeBoth
   105  		case strings.HasPrefix(tl, "// i "):
   106  			incompatibles = append(incompatibles, strings.TrimSpace(tl[4:]))
   107  		case strings.HasPrefix(tl, "// c "):
   108  			compatibles = append(compatibles, strings.TrimSpace(tl[4:]))
   109  		default:
   110  			writeln(line)
   111  		}
   112  	}
   113  	if s.Err() != nil {
   114  		t.Fatal(s.Err())
   115  	}
   116  	return
   117  }
   118  
   119  func load(t *testing.T, importPath, goPath string) (*packages.Package, error) {
   120  	testenv.NeedsGoPackages(t)
   121  
   122  	cfg := &packages.Config{
   123  		Mode: packages.LoadTypes,
   124  	}
   125  	if goPath != "" {
   126  		cfg.Env = append(os.Environ(), "GOPATH="+goPath)
   127  		cfg.Dir = filepath.Join(goPath, "src", filepath.FromSlash(importPath))
   128  	}
   129  	pkgs, err := packages.Load(cfg, importPath)
   130  	if err != nil {
   131  		return nil, err
   132  	}
   133  	if len(pkgs[0].Errors) > 0 {
   134  		return nil, pkgs[0].Errors[0]
   135  	}
   136  	return pkgs[0], nil
   137  }
   138  
   139  func TestExportedFields(t *testing.T) {
   140  	pkg, err := load(t, "github.com/jd-ly/tools/internal/apidiff/testdata/exported_fields", "")
   141  	if err != nil {
   142  		t.Fatal(err)
   143  	}
   144  	typeof := func(name string) types.Type {
   145  		return pkg.Types.Scope().Lookup(name).Type()
   146  	}
   147  
   148  	s := typeof("S")
   149  	su := s.(*types.Named).Underlying().(*types.Struct)
   150  
   151  	ef := exportedSelectableFields(su)
   152  	wants := []struct {
   153  		name string
   154  		typ  types.Type
   155  	}{
   156  		{"A1", typeof("A1")},
   157  		{"D", types.Typ[types.Bool]},
   158  		{"E", types.Typ[types.Int]},
   159  		{"F", typeof("F")},
   160  		{"S", types.NewPointer(s)},
   161  	}
   162  
   163  	if got, want := len(ef), len(wants); got != want {
   164  		t.Errorf("got %d fields, want %d\n%+v", got, want, ef)
   165  	}
   166  	for _, w := range wants {
   167  		if got := ef[w.name]; got != nil && !types.Identical(got.Type(), w.typ) {
   168  			t.Errorf("%s: got %v, want %v", w.name, got.Type(), w.typ)
   169  		}
   170  	}
   171  }