github.com/graybobo/golang.org-package-offline-cache@v0.0.0-20200626051047-6608995c132f/x/tools/imports/fix_test.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package imports
     6  
     7  import (
     8  	"flag"
     9  	"go/build"
    10  	"io/ioutil"
    11  	"os"
    12  	"path/filepath"
    13  	"sync"
    14  	"testing"
    15  )
    16  
    17  var only = flag.String("only", "", "If non-empty, the fix test to run")
    18  
// tests is the fixture table driven by TestFixImports. For each entry, Process
// is called on in (with the file name name+".go") and its output must equal
// out byte-for-byte. A single entry can be selected with -only=NAME.
//
// TestFixImports replaces findImport with a stub backed by its simplePkgs map,
// so packages that these fixtures expect to be added are presumably resolved
// only from that map.
var tests = []struct {
	name    string
	in, out string
}{
	// Adding an import to an existing parenthesized import
	{
		name: "factored_imports_add",
		in: `package foo
import (
  "fmt"
)
func bar() {
var b bytes.Buffer
fmt.Println(b.String())
}
`,
		out: `package foo

import (
	"bytes"
	"fmt"
)

func bar() {
	var b bytes.Buffer
	fmt.Println(b.String())
}
`,
	},

	// Adding an import to an existing parenthesized import,
	// verifying it goes into the first section.
	{
		name: "factored_imports_add_first_sec",
		in: `package foo
import (
  "fmt"

  "appengine"
)
func bar() {
var b bytes.Buffer
_ = appengine.IsDevServer
fmt.Println(b.String())
}
`,
		out: `package foo

import (
	"bytes"
	"fmt"

	"appengine"
)

func bar() {
	var b bytes.Buffer
	_ = appengine.IsDevServer
	fmt.Println(b.String())
}
`,
	},

	// Adding an import to an existing parenthesized import,
	// verifying it goes into the first section. (test 2)
	{
		name: "factored_imports_add_first_sec_2",
		in: `package foo
import (
  "fmt"

  "appengine"
)
func bar() {
_ = math.NaN
_ = fmt.Sprintf
_ = appengine.IsDevServer
}
`,
		out: `package foo

import (
	"fmt"
	"math"

	"appengine"
)

func bar() {
	_ = math.NaN
	_ = fmt.Sprintf
	_ = appengine.IsDevServer
}
`,
	},

	// Adding a new import line, without parens
	{
		name: "add_import_section",
		in: `package foo
func bar() {
var b bytes.Buffer
}
`,
		out: `package foo

import "bytes"

func bar() {
	var b bytes.Buffer
}
`,
	},

	// Adding two new imports, which should make a parenthesized import decl.
	{
		name: "add_import_paren_section",
		in: `package foo
func bar() {
_, _ := bytes.Buffer, zip.NewReader
}
`,
		out: `package foo

import (
	"archive/zip"
	"bytes"
)

func bar() {
	_, _ := bytes.Buffer, zip.NewReader
}
`,
	},

	// Make sure we don't add things twice
	{
		name: "no_double_add",
		in: `package foo
func bar() {
_, _ := bytes.Buffer, bytes.NewReader
}
`,
		out: `package foo

import "bytes"

func bar() {
	_, _ := bytes.Buffer, bytes.NewReader
}
`,
	},

	// Remove unused imports, 1 of a factored block
	{
		name: "remove_unused_1_of_2",
		in: `package foo
import (
"bytes"
"fmt"
)

func bar() {
_, _ := bytes.Buffer, bytes.NewReader
}
`,
		out: `package foo

import "bytes"

func bar() {
	_, _ := bytes.Buffer, bytes.NewReader
}
`,
	},

	// Remove unused imports, 2 of 2
	{
		name: "remove_unused_2_of_2",
		in: `package foo
import (
"bytes"
"fmt"
)

func bar() {
}
`,
		out: `package foo

func bar() {
}
`,
	},

	// Remove unused imports, 1 of 1
	{
		name: "remove_unused_1_of_1",
		in: `package foo

import "fmt"

func bar() {
}
`,
		out: `package foo

func bar() {
}
`,
	},

	// Don't remove blank (_) imports even though nothing references them;
	// they are still sorted.
	{
		name: "dont_remove_empty_imports",
		in: `package foo
import (
_ "image/png"
_ "image/jpeg"
)
`,
		out: `package foo

import (
	_ "image/jpeg"
	_ "image/png"
)
`,
	},

	// Don't remove dot imports.
	{
		name: "dont_remove_dot_imports",
		in: `package foo
import (
. "foo"
. "bar"
)
`,
		out: `package foo

import (
	. "bar"
	. "foo"
)
`,
	},

	// Skip refs the parser can resolve.
	{
		name: "skip_resolved_refs",
		in: `package foo

func f() {
	type t struct{ Println func(string) }
	fmt := t{Println: func(string) {}}
	fmt.Println("foo")
}
`,
		out: `package foo

func f() {
	type t struct{ Println func(string) }
	fmt := t{Println: func(string) {}}
	fmt.Println("foo")
}
`,
	},

	// Do not add a package we already have a resolution for.
	{
		name: "skip_template",
		in: `package foo

import "html/template"

func f() { t = template.New("sometemplate") }
`,
		out: `package foo

import "html/template"

func f() { t = template.New("sometemplate") }
`,
	},

	// Don't touch cgo
	{
		name: "cgo",
		in: `package foo

/*
#include <foo.h>
*/
import "C"
`,
		out: `package foo

/*
#include <foo.h>
*/
import "C"
`,
	},

	// Put some things in their own section
	{
		name: "make_sections",
		in: `package foo

import (
"os"
)

func foo () {
_, _ = os.Args, fmt.Println
_, _ = appengine.FooSomething, user.Current
}
`,
		out: `package foo

import (
	"fmt"
	"os"

	"appengine"
	"appengine/user"
)

func foo() {
	_, _ = os.Args, fmt.Println
	_, _ = appengine.FooSomething, user.Current
}
`,
	},

	// Delete existing empty import block
	{
		name: "delete_empty_import_block",
		in: `package foo

import ()
`,
		out: `package foo
`,
	},

	// Use existing empty import block
	{
		name: "use_empty_import_block",
		in: `package foo

import ()

func f() {
	_ = fmt.Println
}
`,
		out: `package foo

import "fmt"

func f() {
	_ = fmt.Println
}
`,
	},

	// Blank line before adding new section.
	{
		name: "blank_line_before_new_group",
		in: `package foo

import (
	"fmt"
	"net"
)

func f() {
	_ = net.Dial
	_ = fmt.Printf
	_ = snappy.Foo
}
`,
		out: `package foo

import (
	"fmt"
	"net"

	"code.google.com/p/snappy-go/snappy"
)

func f() {
	_ = net.Dial
	_ = fmt.Printf
	_ = snappy.Foo
}
`,
	},

	// Blank line between standard library and third-party stuff.
	{
		name: "blank_line_separating_std_and_third_party",
		in: `package foo

import (
	"code.google.com/p/snappy-go/snappy"
	"fmt"
	"net"
)

func f() {
	_ = net.Dial
	_ = fmt.Printf
	_ = snappy.Foo
}
`,
		out: `package foo

import (
	"fmt"
	"net"

	"code.google.com/p/snappy-go/snappy"
)

func f() {
	_ = net.Dial
	_ = fmt.Printf
	_ = snappy.Foo
}
`,
	},

	// golang.org/issue/6884
	{
		name: "issue 6884",
		in: `package main

// A comment
func main() {
	fmt.Println("Hello, world")
}
`,
		out: `package main

import "fmt"

// A comment
func main() {
	fmt.Println("Hello, world")
}
`,
	},

	// golang.org/issue/7132
	{
		name: "issue 7132",
		in: `package main

import (
"fmt"

"gu"
"github.com/foo/bar"
)

var (
a = bar.a
b = gu.a
c = fmt.Printf
)
`,
		out: `package main

import (
	"fmt"

	"gu"

	"github.com/foo/bar"
)

var (
	a = bar.a
	b = gu.a
	c = fmt.Printf
)
`,
	},

	// The local name differs from the last element of the import path. The
	// findImport stub in TestFixImports maps "str" to "strings" and reports
	// a rename for it, so a named import is added.
	{
		name: "renamed package",
		in: `package main

var _ = str.HasPrefix
`,
		out: `package main

import str "strings"

var _ = str.HasPrefix
`,
	},

	// Source fragment without a package clause (TestFixImports sets
	// Options.Fragment). A fragment that declares func main comes back as a
	// complete file with "package main".
	{
		name: "fragment with main",
		in:   `func main(){fmt.Println("Hello, world")}`,
		out: `package main

import "fmt"

func main() { fmt.Println("Hello, world") }
`,
	},

	// Without func main the fragment stays a fragment: the expected output
	// has no package clause and no trailing newline.
	{
		name: "fragment without main",
		in:   `func notmain(){fmt.Println("Hello, world")}`,
		out: `import "fmt"

func notmain() { fmt.Println("Hello, world") }`,
	},

	// Remove the first import within a 2nd/3rd/4th/etc. section.
	// golang.org/issue/7679
	{
		name: "issue 7679",
		in: `package main

import (
	"fmt"

	"github.com/foo/bar"
	"github.com/foo/qux"
)

func main() {
	var _ = fmt.Println
	//var _ = bar.A
	var _ = qux.B
}
`,
		out: `package main

import (
	"fmt"

	"github.com/foo/qux"
)

func main() {
	var _ = fmt.Println
	//var _ = bar.A
	var _ = qux.B
}
`,
	},

	// Blank line can be added before all types of import declarations.
	// golang.org/issue/7866
	{
		name: "issue 7866",
		in: `package main

import (
	"fmt"
	renamed_bar "github.com/foo/bar"

	. "github.com/foo/baz"
	"io"

	_ "github.com/foo/qux"
	"strings"
)

func main() {
	_, _, _, _, _ = fmt.Errorf, io.Copy, strings.Contains, renamed_bar.A, B
}
`,
		out: `package main

import (
	"fmt"

	renamed_bar "github.com/foo/bar"

	"io"

	. "github.com/foo/baz"

	"strings"

	_ "github.com/foo/qux"
)

func main() {
	_, _, _, _, _ = fmt.Errorf, io.Copy, strings.Contains, renamed_bar.A, B
}
`,
	},

	// Non-idempotent comment formatting
	// golang.org/issue/8035
	{
		name: "issue 8035",
		in: `package main

import (
	"fmt"                     // A
	"go/ast"                  // B
	_ "launchpad.net/gocheck" // C
)

func main() { _, _ = fmt.Print, ast.Walk }
`,
		out: `package main

import (
	"fmt"    // A
	"go/ast" // B

	_ "launchpad.net/gocheck" // C
)

func main() { _, _ = fmt.Print, ast.Walk }
`,
	},

	// Failure to delete all duplicate imports
	// golang.org/issue/8459
	{
		name: "issue 8459",
		in: `package main

import (
	"fmt"
	"log"
	"log"
	"math"
)

func main() { fmt.Println("pi:", math.Pi) }
`,
		out: `package main

import (
	"fmt"
	"math"
)

func main() { fmt.Println("pi:", math.Pi) }
`,
	},

	// Too aggressive prefix matching
	// golang.org/issue/9961
	{
		name: "issue 9961",
		in: `package p

import (
	"zip"

	"rsc.io/p"
)

var (
	_ = fmt.Print
	_ = zip.Store
	_ p.P
	_ = regexp.Compile
)
`,
		out: `package p

import (
	"fmt"
	"regexp"
	"zip"

	"rsc.io/p"
)

var (
	_ = fmt.Print
	_ = zip.Store
	_ p.P
	_ = regexp.Compile
)
`,
	},

	// Unused named import is mistaken for unnamed import
	// golang.org/issue/8149
	{
		name: "issue 8149",
		in: `package main

import foo "fmt"

func main() { fmt.Println() }
`,
		out: `package main

import "fmt"

func main() { fmt.Println() }
`,
	},
}
   731  
   732  func TestFixImports(t *testing.T) {
   733  	simplePkgs := map[string]string{
   734  		"appengine": "appengine",
   735  		"bytes":     "bytes",
   736  		"fmt":       "fmt",
   737  		"math":      "math",
   738  		"os":        "os",
   739  		"p":         "rsc.io/p",
   740  		"regexp":    "regexp",
   741  		"snappy":    "code.google.com/p/snappy-go/snappy",
   742  		"str":       "strings",
   743  		"user":      "appengine/user",
   744  		"zip":       "archive/zip",
   745  	}
   746  	findImport = func(pkgName string, symbols map[string]bool) (string, bool, error) {
   747  		return simplePkgs[pkgName], pkgName == "str", nil
   748  	}
   749  
   750  	options := &Options{
   751  		TabWidth:  8,
   752  		TabIndent: true,
   753  		Comments:  true,
   754  		Fragment:  true,
   755  	}
   756  
   757  	for _, tt := range tests {
   758  		if *only != "" && tt.name != *only {
   759  			continue
   760  		}
   761  		buf, err := Process(tt.name+".go", []byte(tt.in), options)
   762  		if err != nil {
   763  			t.Errorf("error on %q: %v", tt.name, err)
   764  			continue
   765  		}
   766  		if got := string(buf); got != tt.out {
   767  			t.Errorf("results diff on %q\nGOT:\n%s\nWANT:\n%s\n", tt.name, got, tt.out)
   768  		}
   769  	}
   770  }
   771  
   772  func TestFindImportGoPath(t *testing.T) {
   773  	goroot, err := ioutil.TempDir("", "goimports-")
   774  	if err != nil {
   775  		t.Fatal(err)
   776  	}
   777  	defer os.RemoveAll(goroot)
   778  
   779  	pkgIndexOnce = sync.Once{}
   780  
   781  	origStdlib := stdlib
   782  	defer func() {
   783  		stdlib = origStdlib
   784  	}()
   785  	stdlib = nil
   786  
   787  	// Test against imaginary bits/bytes package in std lib
   788  	bytesDir := filepath.Join(goroot, "src", "pkg", "bits", "bytes")
   789  	for _, tag := range build.Default.ReleaseTags {
   790  		// Go 1.4 rearranged the GOROOT tree to remove the "pkg" path component.
   791  		if tag == "go1.4" {
   792  			bytesDir = filepath.Join(goroot, "src", "bits", "bytes")
   793  		}
   794  	}
   795  	if err := os.MkdirAll(bytesDir, 0755); err != nil {
   796  		t.Fatal(err)
   797  	}
   798  	bytesSrcPath := filepath.Join(bytesDir, "bytes.go")
   799  	bytesPkgPath := "bits/bytes"
   800  	bytesSrc := []byte(`package bytes
   801  
   802  type Buffer2 struct {}
   803  `)
   804  	if err := ioutil.WriteFile(bytesSrcPath, bytesSrc, 0775); err != nil {
   805  		t.Fatal(err)
   806  	}
   807  	oldGOROOT := build.Default.GOROOT
   808  	oldGOPATH := build.Default.GOPATH
   809  	build.Default.GOROOT = goroot
   810  	build.Default.GOPATH = ""
   811  	defer func() {
   812  		build.Default.GOROOT = oldGOROOT
   813  		build.Default.GOPATH = oldGOPATH
   814  	}()
   815  
   816  	got, rename, err := findImportGoPath("bytes", map[string]bool{"Buffer2": true})
   817  	if err != nil {
   818  		t.Fatal(err)
   819  	}
   820  	if got != bytesPkgPath || rename {
   821  		t.Errorf(`findImportGoPath("bytes", Buffer2 ...)=%q, %t, want "%s", false`, got, rename, bytesPkgPath)
   822  	}
   823  
   824  	got, rename, err = findImportGoPath("bytes", map[string]bool{"Missing": true})
   825  	if err != nil {
   826  		t.Fatal(err)
   827  	}
   828  	if got != "" || rename {
   829  		t.Errorf(`findImportGoPath("bytes", Missing ...)=%q, %t, want "", false`, got, rename)
   830  	}
   831  }
   832  
   833  func TestFindImportStdlib(t *testing.T) {
   834  	tests := []struct {
   835  		pkg     string
   836  		symbols []string
   837  		want    string
   838  	}{
   839  		{"http", []string{"Get"}, "net/http"},
   840  		{"http", []string{"Get", "Post"}, "net/http"},
   841  		{"http", []string{"Get", "Foo"}, ""},
   842  		{"bytes", []string{"Buffer"}, "bytes"},
   843  		{"ioutil", []string{"Discard"}, "io/ioutil"},
   844  	}
   845  	for _, tt := range tests {
   846  		got, rename, ok := findImportStdlib(tt.pkg, strSet(tt.symbols))
   847  		if (got != "") != ok {
   848  			t.Error("findImportStdlib return value inconsistent")
   849  		}
   850  		if got != tt.want || rename {
   851  			t.Errorf("findImportStdlib(%q, %q) = %q, %t; want %q, false", tt.pkg, tt.symbols, got, rename, tt.want)
   852  		}
   853  	}
   854  }
   855  
   856  func strSet(ss []string) map[string]bool {
   857  	m := make(map[string]bool)
   858  	for _, s := range ss {
   859  		m[s] = true
   860  	}
   861  	return m
   862  }