github.com/graybobo/golang.org-package-offline-cache@v0.0.0-20200626051047-6608995c132f/x/tools/astutil/imports_test.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package astutil
     6  
     7  import (
     8  	"bytes"
     9  	"go/ast"
    10  	"go/format"
    11  	"go/parser"
    12  	"go/token"
    13  	"reflect"
    14  	"strconv"
    15  	"testing"
    16  )
    17  
    18  var fset = token.NewFileSet()
    19  
    20  func parse(t *testing.T, name, in string) *ast.File {
    21  	file, err := parser.ParseFile(fset, name, in, parser.ParseComments)
    22  	if err != nil {
    23  		t.Fatalf("%s parse: %v", name, err)
    24  	}
    25  	return file
    26  }
    27  
    28  func print(t *testing.T, name string, f *ast.File) string {
    29  	var buf bytes.Buffer
    30  	if err := format.Node(&buf, fset, f); err != nil {
    31  		t.Fatalf("%s gofmt: %v", name, err)
    32  	}
    33  	return string(buf.Bytes())
    34  }
    35  
    36  type test struct {
    37  	name       string
    38  	renamedPkg string
    39  	pkg        string
    40  	in         string
    41  	out        string
    42  	broken     bool // known broken
    43  }
    44  
    45  var addTests = []test{
    46  	{
    47  		name: "leave os alone",
    48  		pkg:  "os",
    49  		in: `package main
    50  
    51  import (
    52  	"os"
    53  )
    54  `,
    55  		out: `package main
    56  
    57  import (
    58  	"os"
    59  )
    60  `,
    61  	},
    62  	{
    63  		name: "import.1",
    64  		pkg:  "os",
    65  		in: `package main
    66  `,
    67  		out: `package main
    68  
    69  import "os"
    70  `,
    71  	},
    72  	{
    73  		name: "import.2",
    74  		pkg:  "os",
    75  		in: `package main
    76  
    77  // Comment
    78  import "C"
    79  `,
    80  		out: `package main
    81  
    82  // Comment
    83  import "C"
    84  import "os"
    85  `,
    86  	},
    87  	{
    88  		name: "import.3",
    89  		pkg:  "os",
    90  		in: `package main
    91  
    92  // Comment
    93  import "C"
    94  
    95  import (
    96  	"io"
    97  	"utf8"
    98  )
    99  `,
   100  		out: `package main
   101  
   102  // Comment
   103  import "C"
   104  
   105  import (
   106  	"io"
   107  	"os"
   108  	"utf8"
   109  )
   110  `,
   111  	},
   112  	{
   113  		name: "import.17",
   114  		pkg:  "x/y/z",
   115  		in: `package main
   116  
   117  // Comment
   118  import "C"
   119  
   120  import (
   121  	"a"
   122  	"b"
   123  
   124  	"x/w"
   125  
   126  	"d/f"
   127  )
   128  `,
   129  		out: `package main
   130  
   131  // Comment
   132  import "C"
   133  
   134  import (
   135  	"a"
   136  	"b"
   137  
   138  	"x/w"
   139  	"x/y/z"
   140  
   141  	"d/f"
   142  )
   143  `,
   144  	},
   145  	{
   146  		name: "import into singular block",
   147  		pkg:  "bytes",
   148  		in: `package main
   149  
   150  import "os"
   151  
   152  `,
   153  		out: `package main
   154  
   155  import (
   156  	"bytes"
   157  	"os"
   158  )
   159  `,
   160  	},
   161  	{
   162  		name:       "",
   163  		renamedPkg: "fmtpkg",
   164  		pkg:        "fmt",
   165  		in: `package main
   166  
   167  import "os"
   168  
   169  `,
   170  		out: `package main
   171  
   172  import (
   173  	fmtpkg "fmt"
   174  	"os"
   175  )
   176  `,
   177  	},
   178  	{
   179  		name: "struct comment",
   180  		pkg:  "time",
   181  		in: `package main
   182  
   183  // This is a comment before a struct.
   184  type T struct {
   185  	t  time.Time
   186  }
   187  `,
   188  		out: `package main
   189  
   190  import "time"
   191  
   192  // This is a comment before a struct.
   193  type T struct {
   194  	t time.Time
   195  }
   196  `,
   197  	},
   198  }
   199  
   200  func TestAddImport(t *testing.T) {
   201  	for _, test := range addTests {
   202  		file := parse(t, test.name, test.in)
   203  		var before bytes.Buffer
   204  		ast.Fprint(&before, fset, file, nil)
   205  		AddNamedImport(fset, file, test.renamedPkg, test.pkg)
   206  		if got := print(t, test.name, file); got != test.out {
   207  			if test.broken {
   208  				t.Logf("%s is known broken:\ngot: %s\nwant: %s", test.name, got, test.out)
   209  			} else {
   210  				t.Errorf("%s:\ngot: %s\nwant: %s", test.name, got, test.out)
   211  			}
   212  			var after bytes.Buffer
   213  			ast.Fprint(&after, fset, file, nil)
   214  
   215  			t.Logf("AST before:\n%s\nAST after:\n%s\n", before.String(), after.String())
   216  		}
   217  	}
   218  }
   219  
   220  func TestDoubleAddImport(t *testing.T) {
   221  	file := parse(t, "doubleimport", "package main\n")
   222  	AddImport(fset, file, "os")
   223  	AddImport(fset, file, "bytes")
   224  	want := `package main
   225  
   226  import (
   227  	"bytes"
   228  	"os"
   229  )
   230  `
   231  	if got := print(t, "doubleimport", file); got != want {
   232  		t.Errorf("got: %s\nwant: %s", got, want)
   233  	}
   234  }
   235  
// deleteTests are the cases run by TestDeleteImport. Each case removes the
// import path pkg from in and expects the gofmt'd result to equal out.
var deleteTests = []test{
	// Removing the only import drops the whole import declaration.
	{
		name: "import.4",
		pkg:  "os",
		in: `package main

import (
	"os"
)
`,
		out: `package main
`,
	},
	// The commented cgo import "C" survives when a neighbour is removed.
	{
		name: "import.5",
		pkg:  "os",
		in: `package main

// Comment
import "C"
import "os"
`,
		out: `package main

// Comment
import "C"
`,
	},
	{
		name: "import.6",
		pkg:  "os",
		in: `package main

// Comment
import "C"

import (
	"io"
	"os"
	"utf8"
)
`,
		out: `package main

// Comment
import "C"

import (
	"io"
	"utf8"
)
`,
	},
	// import.7 through import.9: the trailing comment of the removed spec is
	// kept on a line of its own, whether the spec is first, middle or last.
	{
		name: "import.7",
		pkg:  "io",
		in: `package main

import (
	"io"   // a
	"os"   // b
	"utf8" // c
)
`,
		out: `package main

import (
	// a
	"os"   // b
	"utf8" // c
)
`,
	},
	{
		name: "import.8",
		pkg:  "os",
		in: `package main

import (
	"io"   // a
	"os"   // b
	"utf8" // c
)
`,
		out: `package main

import (
	"io" // a
	// b
	"utf8" // c
)
`,
	},
	{
		name: "import.9",
		pkg:  "utf8",
		in: `package main

import (
	"io"   // a
	"os"   // b
	"utf8" // c
)
`,
		out: `package main

import (
	"io" // a
	"os" // b
	// c
)
`,
	},
	// import.10 through import.12: removing the first, middle or last spec of
	// an uncommented block.
	{
		name: "import.10",
		pkg:  "io",
		in: `package main

import (
	"io"
	"os"
	"utf8"
)
`,
		out: `package main

import (
	"os"
	"utf8"
)
`,
	},
	{
		name: "import.11",
		pkg:  "os",
		in: `package main

import (
	"io"
	"os"
	"utf8"
)
`,
		out: `package main

import (
	"io"
	"utf8"
)
`,
	},
	{
		name: "import.12",
		pkg:  "utf8",
		in: `package main

import (
	"io"
	"os"
	"utf8"
)
`,
		out: `package main

import (
	"io"
	"os"
)
`,
	},
	// A backquoted (raw string) import path is matched and removed too.
	{
		name: "handle.raw.quote.imports",
		pkg:  "os",
		in:   "package main\n\nimport `os`",
		out: `package main
`,
	},
	// The blank-line-separated groups inside a block are preserved.
	{
		name: "import.13",
		pkg:  "io",
		in: `package main

import (
	"fmt"

	"io"
	"os"
	"utf8"

	"go/format"
)
`,
		out: `package main

import (
	"fmt"

	"os"
	"utf8"

	"go/format"
)
`,
	},
	// Same as import.13, with trailing comments on every spec.
	{
		name: "import.14",
		pkg:  "io",
		in: `package main

import (
	"fmt" // a

	"io"   // b
	"os"   // c
	"utf8" // d

	"go/format" // e
)
`,
		out: `package main

import (
	"fmt" // a

	// b
	"os"   // c
	"utf8" // d

	"go/format" // e
)
`,
	},
	// import.15 through import.17: every duplicate import of pkg is removed,
	// including across separate import declarations.
	{
		name: "import.15",
		pkg:  "double",
		in: `package main

import (
	"double"
	"double"
)
`,
		out: `package main
`,
	},
	{
		name: "import.16",
		pkg:  "bubble",
		in: `package main

import (
	"toil"
	"bubble"
	"bubble"
	"trouble"
)
`,
		out: `package main

import (
	"toil"
	"trouble"
)
`,
	},
	{
		name: "import.17",
		pkg:  "quad",
		in: `package main

import (
	"quad"
	"quad"
)

import (
	"quad"
	"quad"
)
`,
		out: `package main
`,
	},
}
   520  
   521  func TestDeleteImport(t *testing.T) {
   522  	for _, test := range deleteTests {
   523  		file := parse(t, test.name, test.in)
   524  		DeleteImport(fset, file, test.pkg)
   525  		if got := print(t, test.name, file); got != test.out {
   526  			t.Errorf("%s:\ngot: %s\nwant: %s", test.name, got, test.out)
   527  		}
   528  	}
   529  }
   530  
   531  type rewriteTest struct {
   532  	name   string
   533  	srcPkg string
   534  	dstPkg string
   535  	in     string
   536  	out    string
   537  }
   538  
// rewriteTests are the cases run by TestRewriteImport. In each case the
// expected output shows the rewritten path moved to its sorted position in
// the import block.
var rewriteTests = []rewriteTest{
	// The trailing comment travels with the rewritten spec.
	{
		name:   "import.13",
		srcPkg: "utf8",
		dstPkg: "encoding/utf8",
		in: `package main

import (
	"io"
	"os"
	"utf8" // thanks ken
)
`,
		out: `package main

import (
	"encoding/utf8" // thanks ken
	"io"
	"os"
)
`,
	},
	// Other specs, including a blank (_) import, are left untouched.
	{
		name:   "import.14",
		srcPkg: "asn1",
		dstPkg: "encoding/asn1",
		in: `package main

import (
	"asn1"
	"crypto"
	"crypto/rsa"
	_ "crypto/sha1"
	"crypto/x509"
	"crypto/x509/pkix"
	"time"
)

var x = 1
`,
		out: `package main

import (
	"crypto"
	"crypto/rsa"
	_ "crypto/sha1"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/asn1"
	"time"
)

var x = 1
`,
	},
	// A line comment on the declaration after the block stays on that
	// declaration.
	{
		name:   "import.15",
		srcPkg: "url",
		dstPkg: "net/url",
		in: `package main

import (
	"bufio"
	"net"
	"path"
	"url"
)

var x = 1 // comment on x, not on url
`,
		out: `package main

import (
	"bufio"
	"net"
	"net/url"
	"path"
)

var x = 1 // comment on x, not on url
`,
	},
	{
		name:   "import.16",
		srcPkg: "http",
		dstPkg: "net/http",
		in: `package main

import (
	"flag"
	"http"
	"log"
	"text/template"
)

var addr = flag.String("addr", ":1718", "http service address") // Q=17, R=18
`,
		out: `package main

import (
	"flag"
	"log"
	"net/http"
	"text/template"
)

var addr = flag.String("addr", ":1718", "http service address") // Q=17, R=18
`,
	},
}
   649  
   650  func TestRewriteImport(t *testing.T) {
   651  	for _, test := range rewriteTests {
   652  		file := parse(t, test.name, test.in)
   653  		RewriteImport(fset, file, test.srcPkg, test.dstPkg)
   654  		if got := print(t, test.name, file); got != test.out {
   655  			t.Errorf("%s:\ngot: %s\nwant: %s", test.name, got, test.out)
   656  		}
   657  	}
   658  }
   659  
// importsTests are the cases for TestImports. want holds one entry per import
// group that Imports reports, listing the unquoted paths in that group (nil
// when the file has no imports). Judging by the cases below, groups are split
// both at blank lines inside a block and between separate import
// declarations.
var importsTests = []struct {
	name string
	in   string
	want [][]string
}{
	{
		name: "no packages",
		in: `package foo
`,
		want: nil,
	},
	{
		name: "one group",
		in: `package foo

import (
	"fmt"
	"testing"
)
`,
		want: [][]string{{"fmt", "testing"}},
	},
	// A standalone import "C" forms its own group ahead of the block's groups.
	{
		name: "four groups",
		in: `package foo

import "C"
import (
	"fmt"
	"testing"

	"appengine"

	"myproject/mylib1"
	"myproject/mylib2"
)
`,
		want: [][]string{
			{"C"},
			{"fmt", "testing"},
			{"appengine"},
			{"myproject/mylib1", "myproject/mylib2"},
		},
	},
	// Groups from consecutive parenthesized declarations are reported in
	// source order.
	{
		name: "multiple factored groups",
		in: `package foo

import (
	"fmt"
	"testing"

	"appengine"
)
import (
	"reflect"

	"bytes"
)
`,
		want: [][]string{
			{"fmt", "testing"},
			{"appengine"},
			{"reflect"},
			{"bytes"},
		},
	},
}
   728  
   729  func unquote(s string) string {
   730  	res, err := strconv.Unquote(s)
   731  	if err != nil {
   732  		return "could_not_unquote"
   733  	}
   734  	return res
   735  }
   736  
   737  func TestImports(t *testing.T) {
   738  	fset := token.NewFileSet()
   739  	for _, test := range importsTests {
   740  		f, err := parser.ParseFile(fset, "test.go", test.in, 0)
   741  		if err != nil {
   742  			t.Errorf("%s: %v", test.name, err)
   743  			continue
   744  		}
   745  		var got [][]string
   746  		for _, block := range Imports(fset, f) {
   747  			var b []string
   748  			for _, spec := range block {
   749  				b = append(b, unquote(spec.Path.Value))
   750  			}
   751  			got = append(got, b)
   752  		}
   753  		if !reflect.DeepEqual(got, test.want) {
   754  			t.Errorf("Imports(%s)=%v, want %v", test.name, got, test.want)
   755  		}
   756  	}
   757  }