github.com/0xKiwi/rules_go@v0.24.3/go/tools/fetch_repo/fetch_repo_test.go (about)

     1  package main
     2  
     3  import (
     4  	"os"
     5  	"reflect"
     6  	"testing"
     7  
     8  	"golang.org/x/tools/go/vcs"
     9  )
    10  
    11  var (
    12  	root = &vcs.RepoRoot{
    13  		VCS:  vcs.ByCmd("git"),
    14  		Repo: "https://github.com/bazeltest/rules_go",
    15  		Root: "github.com/bazeltest/rules_go",
    16  	}
    17  )
    18  
    19  func TestMain(m *testing.M) {
    20  	// Replace vcs.RepoRootForImportPath to disable any network calls.
    21  	repoRootForImportPath = func(_ string, _ bool) (*vcs.RepoRoot, error) {
    22  		return root, nil
    23  	}
    24  	os.Exit(m.Run())
    25  }
    26  
    27  func TestGetRepoRoot(t *testing.T) {
    28  	for _, tc := range []struct {
    29  		label      string
    30  		remote     string
    31  		cmd        string
    32  		importpath string
    33  		r          *vcs.RepoRoot
    34  	}{
    35  		{
    36  			label:      "all",
    37  			remote:     "https://github.com/bazeltest/rules_go",
    38  			cmd:        "git",
    39  			importpath: "github.com/bazeltest/rules_go",
    40  			r:          root,
    41  		},
    42  		{
    43  			label:      "different remote",
    44  			remote:     "https://example.com/rules_go",
    45  			cmd:        "git",
    46  			importpath: "github.com/bazeltest/rules_go",
    47  			r: &vcs.RepoRoot{
    48  				VCS:  vcs.ByCmd("git"),
    49  				Repo: "https://example.com/rules_go",
    50  				Root: "github.com/bazeltest/rules_go",
    51  			},
    52  		},
    53  		{
    54  			label:      "only importpath",
    55  			importpath: "github.com/bazeltest/rules_go",
    56  			r:          root,
    57  		},
    58  	} {
    59  		r, err := getRepoRoot(tc.remote, tc.cmd, tc.importpath)
    60  		if err != nil {
    61  			t.Errorf("[%s] %v", tc.label, err)
    62  		}
    63  		if !reflect.DeepEqual(r, tc.r) {
    64  			t.Errorf("[%s] Expected %+v, got %+v", tc.label, tc.r, r)
    65  		}
    66  	}
    67  }
    68  
    69  func TestGetRepoRoot_error(t *testing.T) {
    70  	for _, tc := range []struct {
    71  		label      string
    72  		remote     string
    73  		cmd        string
    74  		importpath string
    75  	}{
    76  		{
    77  			label:  "importpath as remote",
    78  			remote: "github.com/bazeltest/rules_go",
    79  		},
    80  		{
    81  			label:      "missing vcs",
    82  			remote:     "https://github.com/bazeltest/rules_go",
    83  			importpath: "github.com/bazeltest/rules_go",
    84  		},
    85  		{
    86  			label:      "missing remote",
    87  			cmd:        "git",
    88  			importpath: "github.com/bazeltest/rules_go",
    89  		},
    90  	} {
    91  		r, err := getRepoRoot(tc.remote, tc.cmd, tc.importpath)
    92  		if err == nil {
    93  			t.Errorf("[%s] expected error. Got %+v", tc.label, r)
    94  		}
    95  	}
    96  }