github.com/golang/mock@v1.6.0/mockgen/mockgen_test.go (about)

     1  package main
     2  
     3  import (
     4  	"fmt"
     5  	"io/ioutil"
     6  	"os"
     7  	"path/filepath"
     8  	"reflect"
     9  	"regexp"
    10  	"strings"
    11  	"testing"
    12  
    13  	"github.com/golang/mock/mockgen/model"
    14  )
    15  
    16  func TestMakeArgString(t *testing.T) {
    17  	testCases := []struct {
    18  		argNames  []string
    19  		argTypes  []string
    20  		argString string
    21  	}{
    22  		{
    23  			argNames:  nil,
    24  			argTypes:  nil,
    25  			argString: "",
    26  		},
    27  		{
    28  			argNames:  []string{"arg0"},
    29  			argTypes:  []string{"int"},
    30  			argString: "arg0 int",
    31  		},
    32  		{
    33  			argNames:  []string{"arg0", "arg1"},
    34  			argTypes:  []string{"int", "bool"},
    35  			argString: "arg0 int, arg1 bool",
    36  		},
    37  		{
    38  			argNames:  []string{"arg0", "arg1"},
    39  			argTypes:  []string{"int", "int"},
    40  			argString: "arg0, arg1 int",
    41  		},
    42  		{
    43  			argNames:  []string{"arg0", "arg1", "arg2"},
    44  			argTypes:  []string{"bool", "int", "int"},
    45  			argString: "arg0 bool, arg1, arg2 int",
    46  		},
    47  		{
    48  			argNames:  []string{"arg0", "arg1", "arg2"},
    49  			argTypes:  []string{"int", "bool", "int"},
    50  			argString: "arg0 int, arg1 bool, arg2 int",
    51  		},
    52  		{
    53  			argNames:  []string{"arg0", "arg1", "arg2"},
    54  			argTypes:  []string{"int", "int", "bool"},
    55  			argString: "arg0, arg1 int, arg2 bool",
    56  		},
    57  		{
    58  			argNames:  []string{"arg0", "arg1", "arg2"},
    59  			argTypes:  []string{"int", "int", "int"},
    60  			argString: "arg0, arg1, arg2 int",
    61  		},
    62  		{
    63  			argNames:  []string{"arg0", "arg1", "arg2", "arg3"},
    64  			argTypes:  []string{"bool", "int", "int", "int"},
    65  			argString: "arg0 bool, arg1, arg2, arg3 int",
    66  		},
    67  		{
    68  			argNames:  []string{"arg0", "arg1", "arg2", "arg3"},
    69  			argTypes:  []string{"int", "bool", "int", "int"},
    70  			argString: "arg0 int, arg1 bool, arg2, arg3 int",
    71  		},
    72  		{
    73  			argNames:  []string{"arg0", "arg1", "arg2", "arg3"},
    74  			argTypes:  []string{"int", "int", "bool", "int"},
    75  			argString: "arg0, arg1 int, arg2 bool, arg3 int",
    76  		},
    77  		{
    78  			argNames:  []string{"arg0", "arg1", "arg2", "arg3"},
    79  			argTypes:  []string{"int", "int", "int", "bool"},
    80  			argString: "arg0, arg1, arg2 int, arg3 bool",
    81  		},
    82  		{
    83  			argNames:  []string{"arg0", "arg1", "arg2", "arg3", "arg4"},
    84  			argTypes:  []string{"bool", "int", "int", "int", "bool"},
    85  			argString: "arg0 bool, arg1, arg2, arg3 int, arg4 bool",
    86  		},
    87  		{
    88  			argNames:  []string{"arg0", "arg1", "arg2", "arg3", "arg4"},
    89  			argTypes:  []string{"int", "bool", "int", "int", "bool"},
    90  			argString: "arg0 int, arg1 bool, arg2, arg3 int, arg4 bool",
    91  		},
    92  		{
    93  			argNames:  []string{"arg0", "arg1", "arg2", "arg3", "arg4"},
    94  			argTypes:  []string{"int", "int", "bool", "int", "bool"},
    95  			argString: "arg0, arg1 int, arg2 bool, arg3 int, arg4 bool",
    96  		},
    97  		{
    98  			argNames:  []string{"arg0", "arg1", "arg2", "arg3", "arg4"},
    99  			argTypes:  []string{"int", "int", "int", "bool", "bool"},
   100  			argString: "arg0, arg1, arg2 int, arg3, arg4 bool",
   101  		},
   102  		{
   103  			argNames:  []string{"arg0", "arg1", "arg2", "arg3", "arg4"},
   104  			argTypes:  []string{"int", "int", "bool", "bool", "int"},
   105  			argString: "arg0, arg1 int, arg2, arg3 bool, arg4 int",
   106  		},
   107  	}
   108  
   109  	for i, tc := range testCases {
   110  		t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) {
   111  			s := makeArgString(tc.argNames, tc.argTypes)
   112  			if s != tc.argString {
   113  				t.Errorf("result == %q, want %q", s, tc.argString)
   114  			}
   115  		})
   116  	}
   117  }
   118  
   119  func TestNewIdentifierAllocator(t *testing.T) {
   120  	a := newIdentifierAllocator([]string{"taken1", "taken2"})
   121  	if len(a) != 2 {
   122  		t.Fatalf("expected 2 items, got %v", len(a))
   123  	}
   124  
   125  	_, ok := a["taken1"]
   126  	if !ok {
   127  		t.Errorf("allocator doesn't contain 'taken1': %#v", a)
   128  	}
   129  
   130  	_, ok = a["taken2"]
   131  	if !ok {
   132  		t.Errorf("allocator doesn't contain 'taken2': %#v", a)
   133  	}
   134  }
   135  
   136  func allocatorContainsIdentifiers(a identifierAllocator, ids []string) bool {
   137  	if len(a) != len(ids) {
   138  		return false
   139  	}
   140  
   141  	for _, id := range ids {
   142  		_, ok := a[id]
   143  		if !ok {
   144  			return false
   145  		}
   146  	}
   147  
   148  	return true
   149  }
   150  
   151  func TestIdentifierAllocator_allocateIdentifier(t *testing.T) {
   152  	a := newIdentifierAllocator([]string{"taken"})
   153  
   154  	t2 := a.allocateIdentifier("taken_2")
   155  	if t2 != "taken_2" {
   156  		t.Fatalf("expected 'taken_2', got %q", t2)
   157  	}
   158  	expected := []string{"taken", "taken_2"}
   159  	if !allocatorContainsIdentifiers(a, expected) {
   160  		t.Fatalf("allocator doesn't contain the expected items - allocator: %#v, expected items: %#v", a, expected)
   161  	}
   162  
   163  	t3 := a.allocateIdentifier("taken")
   164  	if t3 != "taken_3" {
   165  		t.Fatalf("expected 'taken_3', got %q", t3)
   166  	}
   167  	expected = []string{"taken", "taken_2", "taken_3"}
   168  	if !allocatorContainsIdentifiers(a, expected) {
   169  		t.Fatalf("allocator doesn't contain the expected items - allocator: %#v, expected items: %#v", a, expected)
   170  	}
   171  
   172  	t4 := a.allocateIdentifier("taken")
   173  	if t4 != "taken_4" {
   174  		t.Fatalf("expected 'taken_4', got %q", t4)
   175  	}
   176  	expected = []string{"taken", "taken_2", "taken_3", "taken_4"}
   177  	if !allocatorContainsIdentifiers(a, expected) {
   178  		t.Fatalf("allocator doesn't contain the expected items - allocator: %#v, expected items: %#v", a, expected)
   179  	}
   180  
   181  	id := a.allocateIdentifier("id")
   182  	if id != "id" {
   183  		t.Fatalf("expected 'id', got %q", id)
   184  	}
   185  	expected = []string{"taken", "taken_2", "taken_3", "taken_4", "id"}
   186  	if !allocatorContainsIdentifiers(a, expected) {
   187  		t.Fatalf("allocator doesn't contain the expected items - allocator: %#v, expected items: %#v", a, expected)
   188  	}
   189  }
   190  
   191  func TestGenerateMockInterface_Helper(t *testing.T) {
   192  	for _, test := range []struct {
   193  		Name       string
   194  		Identifier string
   195  		HelperLine string
   196  		Methods    []*model.Method
   197  	}{
   198  		{Name: "mock", Identifier: "MockSomename", HelperLine: "m.ctrl.T.Helper()"},
   199  		{Name: "recorder", Identifier: "MockSomenameMockRecorder", HelperLine: "mr.mock.ctrl.T.Helper()"},
   200  		{
   201  			Name:       "mock identifier conflict",
   202  			Identifier: "MockSomename",
   203  			HelperLine: "m_2.ctrl.T.Helper()",
   204  			Methods: []*model.Method{
   205  				{
   206  					Name: "MethodA",
   207  					In: []*model.Parameter{
   208  						{
   209  							Name: "m",
   210  							Type: &model.NamedType{Type: "int"},
   211  						},
   212  					},
   213  				},
   214  			},
   215  		},
   216  		{
   217  			Name:       "recorder identifier conflict",
   218  			Identifier: "MockSomenameMockRecorder",
   219  			HelperLine: "mr_2.mock.ctrl.T.Helper()",
   220  			Methods: []*model.Method{
   221  				{
   222  					Name: "MethodA",
   223  					In: []*model.Parameter{
   224  						{
   225  							Name: "mr",
   226  							Type: &model.NamedType{Type: "int"},
   227  						},
   228  					},
   229  				},
   230  			},
   231  		},
   232  	} {
   233  		t.Run(test.Name, func(t *testing.T) {
   234  			g := generator{}
   235  
   236  			if len(test.Methods) == 0 {
   237  				test.Methods = []*model.Method{
   238  					{Name: "MethodA"},
   239  					{Name: "MethodB"},
   240  				}
   241  			}
   242  
   243  			intf := &model.Interface{Name: "Somename"}
   244  			for _, m := range test.Methods {
   245  				intf.AddMethod(m)
   246  			}
   247  
   248  			if err := g.GenerateMockInterface(intf, "somepackage"); err != nil {
   249  				t.Fatal(err)
   250  			}
   251  
   252  			lines := strings.Split(g.buf.String(), "\n")
   253  
   254  			// T.Helper() should be the first line
   255  			for _, method := range test.Methods {
   256  				if strings.TrimSpace(lines[findMethod(t, test.Identifier, method.Name, lines)+1]) != test.HelperLine {
   257  					t.Fatalf("method %s.%s did not declare itself a Helper method", test.Identifier, method.Name)
   258  				}
   259  			}
   260  		})
   261  	}
   262  }
   263  
   // findMethod returns the index in lines of the declaration of method
// methodName on receiver type identifier, e.g. `func (m *MockFoo) Bar(`.
// It fails the test when no such line exists.
//
// Both names are escaped with regexp.QuoteMeta. The receiver type must start
// at a word boundary and be followed by ")". The method name must be followed
// by its opening parenthesis. As a result, looking up "Bar" does not match an
// earlier "BarBaz", and "MockFoo" does not match "MockFooMockRecorder".
func findMethod(t *testing.T, identifier, methodName string, lines []string) int {
	t.Helper()
	r := regexp.MustCompile(fmt.Sprintf(`func\s+\([^)]*\b%s\)\s*%s\s*\(`,
		regexp.QuoteMeta(identifier), regexp.QuoteMeta(methodName)))
	for i, line := range lines {
		if r.MatchString(line) {
			return i
		}
	}

	t.Fatalf("unable to find 'func (m %s) %s'", identifier, methodName)
	panic("unreachable") // t.Fatalf does not return; satisfies the int result.
}
   276  
   277  func TestGetArgNames(t *testing.T) {
   278  	for _, testCase := range []struct {
   279  		name     string
   280  		method   *model.Method
   281  		expected []string
   282  	}{
   283  		{
   284  			name: "NamedArg",
   285  			method: &model.Method{
   286  				In: []*model.Parameter{
   287  					{
   288  						Name: "firstArg",
   289  						Type: &model.NamedType{Type: "int"},
   290  					},
   291  					{
   292  						Name: "secondArg",
   293  						Type: &model.NamedType{Type: "string"},
   294  					},
   295  				},
   296  			},
   297  			expected: []string{"firstArg", "secondArg"},
   298  		},
   299  		{
   300  			name: "NotNamedArg",
   301  			method: &model.Method{
   302  				In: []*model.Parameter{
   303  					{
   304  						Name: "",
   305  						Type: &model.NamedType{Type: "int"},
   306  					},
   307  					{
   308  						Name: "",
   309  						Type: &model.NamedType{Type: "string"},
   310  					},
   311  				},
   312  			},
   313  			expected: []string{"arg0", "arg1"},
   314  		},
   315  		{
   316  			name: "MixedNameArg",
   317  			method: &model.Method{
   318  				In: []*model.Parameter{
   319  					{
   320  						Name: "firstArg",
   321  						Type: &model.NamedType{Type: "int"},
   322  					},
   323  					{
   324  						Name: "_",
   325  						Type: &model.NamedType{Type: "string"},
   326  					},
   327  				},
   328  			},
   329  			expected: []string{"firstArg", "arg1"},
   330  		},
   331  	} {
   332  		t.Run(testCase.name, func(t *testing.T) {
   333  			g := generator{}
   334  
   335  			result := g.getArgNames(testCase.method)
   336  			if !reflect.DeepEqual(result, testCase.expected) {
   337  				t.Fatalf("expected %s, got %s", result, testCase.expected)
   338  			}
   339  		})
   340  	}
   341  }
   342  
   343  func Test_createPackageMap(t *testing.T) {
   344  	tests := []struct {
   345  		name            string
   346  		importPath      string
   347  		wantPackageName string
   348  		wantOK          bool
   349  	}{
   350  		{"golang package", "context", "context", true},
   351  		{"third party", "golang.org/x/tools/present", "present", true},
   352  	}
   353  	var importPaths []string
   354  	for _, t := range tests {
   355  		importPaths = append(importPaths, t.importPath)
   356  	}
   357  	packages := createPackageMap(importPaths)
   358  	for _, tt := range tests {
   359  		t.Run(tt.name, func(t *testing.T) {
   360  			gotPackageName, gotOk := packages[tt.importPath]
   361  			if gotPackageName != tt.wantPackageName {
   362  				t.Errorf("createPackageMap() gotPackageName = %v, wantPackageName = %v", gotPackageName, tt.wantPackageName)
   363  			}
   364  			if gotOk != tt.wantOK {
   365  				t.Errorf("createPackageMap() gotOk = %v, wantOK = %v", gotOk, tt.wantOK)
   366  			}
   367  		})
   368  	}
   369  }
   370  
   371  func TestParsePackageImport_FallbackGoPath(t *testing.T) {
   372  	goPath, err := ioutil.TempDir("", "gopath")
   373  	if err != nil {
   374  		t.Error(err)
   375  	}
   376  	defer func() {
   377  		if err = os.RemoveAll(goPath); err != nil {
   378  			t.Error(err)
   379  		}
   380  	}()
   381  	srcDir := filepath.Join(goPath, "src/example.com/foo")
   382  	err = os.MkdirAll(srcDir, 0755)
   383  	if err != nil {
   384  		t.Error(err)
   385  	}
   386  	key := "GOPATH"
   387  	value := goPath
   388  	if err := os.Setenv(key, value); err != nil {
   389  		t.Fatalf("unable to set environment variable %q to %q: %v", key, value, err)
   390  	}
   391  	key = "GO111MODULE"
   392  	value = "on"
   393  	if err := os.Setenv(key, value); err != nil {
   394  		t.Fatalf("unable to set environment variable %q to %q: %v", key, value, err)
   395  	}
   396  	pkgPath, err := parsePackageImport(srcDir)
   397  	expected := "example.com/foo"
   398  	if pkgPath != expected {
   399  		t.Errorf("expect %s, got %s", expected, pkgPath)
   400  	}
   401  }
   402  
   403  func TestParsePackageImport_FallbackMultiGoPath(t *testing.T) {
   404  	var goPathList []string
   405  
   406  	// first gopath
   407  	goPath, err := ioutil.TempDir("", "gopath1")
   408  	if err != nil {
   409  		t.Error(err)
   410  	}
   411  	goPathList = append(goPathList, goPath)
   412  	defer func() {
   413  		if err = os.RemoveAll(goPath); err != nil {
   414  			t.Error(err)
   415  		}
   416  	}()
   417  	srcDir := filepath.Join(goPath, "src/example.com/foo")
   418  	err = os.MkdirAll(srcDir, 0755)
   419  	if err != nil {
   420  		t.Error(err)
   421  	}
   422  
   423  	// second gopath
   424  	goPath, err = ioutil.TempDir("", "gopath2")
   425  	if err != nil {
   426  		t.Error(err)
   427  	}
   428  	goPathList = append(goPathList, goPath)
   429  	defer func() {
   430  		if err = os.RemoveAll(goPath); err != nil {
   431  			t.Error(err)
   432  		}
   433  	}()
   434  
   435  	goPaths := strings.Join(goPathList, string(os.PathListSeparator))
   436  	key := "GOPATH"
   437  	value := goPaths
   438  	if err := os.Setenv(key, value); err != nil {
   439  		t.Fatalf("unable to set environment variable %q to %q: %v", key, value, err)
   440  	}
   441  	key = "GO111MODULE"
   442  	value = "on"
   443  	if err := os.Setenv(key, value); err != nil {
   444  		t.Fatalf("unable to set environment variable %q to %q: %v", key, value, err)
   445  	}
   446  	pkgPath, err := parsePackageImport(srcDir)
   447  	expected := "example.com/foo"
   448  	if pkgPath != expected {
   449  		t.Errorf("expect %s, got %s", expected, pkgPath)
   450  	}
   451  }