github.com/llvm-mirror/llgo@v0.0.0-20190322182713-bf6f0a60fce1/third_party/gotools/go/ast/astutil/imports_test.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package astutil
     6  
     7  import (
     8  	"bytes"
     9  	"go/ast"
    10  	"go/format"
    11  	"go/parser"
    12  	"go/token"
    13  	"reflect"
    14  	"strconv"
    15  	"testing"
    16  )
    17  
    18  var fset = token.NewFileSet()
    19  
    20  func parse(t *testing.T, name, in string) *ast.File {
    21  	file, err := parser.ParseFile(fset, name, in, parser.ParseComments)
    22  	if err != nil {
    23  		t.Fatalf("%s parse: %v", name, err)
    24  	}
    25  	return file
    26  }
    27  
    28  func print(t *testing.T, name string, f *ast.File) string {
    29  	var buf bytes.Buffer
    30  	if err := format.Node(&buf, fset, f); err != nil {
    31  		t.Fatalf("%s gofmt: %v", name, err)
    32  	}
    33  	return string(buf.Bytes())
    34  }
    35  
    36  type test struct {
    37  	name       string
    38  	renamedPkg string
    39  	pkg        string
    40  	in         string
    41  	out        string
    42  	broken     bool // known broken
    43  }
    44  
    45  var addTests = []test{
    46  	{
    47  		name: "leave os alone",
    48  		pkg:  "os",
    49  		in: `package main
    50  
    51  import (
    52  	"os"
    53  )
    54  `,
    55  		out: `package main
    56  
    57  import (
    58  	"os"
    59  )
    60  `,
    61  	},
    62  	{
    63  		name: "import.1",
    64  		pkg:  "os",
    65  		in: `package main
    66  `,
    67  		out: `package main
    68  
    69  import "os"
    70  `,
    71  	},
    72  	{
    73  		name: "import.2",
    74  		pkg:  "os",
    75  		in: `package main
    76  
    77  // Comment
    78  import "C"
    79  `,
    80  		out: `package main
    81  
    82  // Comment
    83  import "C"
    84  import "os"
    85  `,
    86  	},
    87  	{
    88  		name: "import.3",
    89  		pkg:  "os",
    90  		in: `package main
    91  
    92  // Comment
    93  import "C"
    94  
    95  import (
    96  	"io"
    97  	"utf8"
    98  )
    99  `,
   100  		out: `package main
   101  
   102  // Comment
   103  import "C"
   104  
   105  import (
   106  	"io"
   107  	"os"
   108  	"utf8"
   109  )
   110  `,
   111  	},
   112  	{
   113  		name: "import.17",
   114  		pkg:  "x/y/z",
   115  		in: `package main
   116  
   117  // Comment
   118  import "C"
   119  
   120  import (
   121  	"a"
   122  	"b"
   123  
   124  	"x/w"
   125  
   126  	"d/f"
   127  )
   128  `,
   129  		out: `package main
   130  
   131  // Comment
   132  import "C"
   133  
   134  import (
   135  	"a"
   136  	"b"
   137  
   138  	"x/w"
   139  	"x/y/z"
   140  
   141  	"d/f"
   142  )
   143  `,
   144  	},
   145  	{
   146  		name: "import into singular group",
   147  		pkg:  "bytes",
   148  		in: `package main
   149  
   150  import "os"
   151  
   152  `,
   153  		out: `package main
   154  
   155  import (
   156  	"bytes"
   157  	"os"
   158  )
   159  `,
   160  	},
   161  	{
   162  		name: "import into singular group with comment",
   163  		pkg:  "bytes",
   164  		in: `package main
   165  
   166  import /* why */ /* comment here? */ "os"
   167  
   168  `,
   169  		out: `package main
   170  
   171  import /* why */ /* comment here? */ (
   172  	"bytes"
   173  	"os"
   174  )
   175  `,
   176  	},
   177  	{
   178  		name: "import into group with leading comment",
   179  		pkg:  "strings",
   180  		in: `package main
   181  
   182  import (
   183  	// comment before bytes
   184  	"bytes"
   185  	"os"
   186  )
   187  
   188  `,
   189  		out: `package main
   190  
   191  import (
   192  	// comment before bytes
   193  	"bytes"
   194  	"os"
   195  	"strings"
   196  )
   197  `,
   198  	},
   199  	{
   200  		name:       "",
   201  		renamedPkg: "fmtpkg",
   202  		pkg:        "fmt",
   203  		in: `package main
   204  
   205  import "os"
   206  
   207  `,
   208  		out: `package main
   209  
   210  import (
   211  	fmtpkg "fmt"
   212  	"os"
   213  )
   214  `,
   215  	},
   216  	{
   217  		name: "struct comment",
   218  		pkg:  "time",
   219  		in: `package main
   220  
   221  // This is a comment before a struct.
   222  type T struct {
   223  	t  time.Time
   224  }
   225  `,
   226  		out: `package main
   227  
   228  import "time"
   229  
   230  // This is a comment before a struct.
   231  type T struct {
   232  	t time.Time
   233  }
   234  `,
   235  	},
   236  	{
   237  		name: "issue 8729 import C",
   238  		pkg:  "time",
   239  		in: `package main
   240  
   241  import "C"
   242  
   243  // comment
   244  type T time.Time
   245  `,
   246  		out: `package main
   247  
   248  import "C"
   249  import "time"
   250  
   251  // comment
   252  type T time.Time
   253  `,
   254  	},
   255  	{
   256  		name: "issue 8729 empty import",
   257  		pkg:  "time",
   258  		in: `package main
   259  
   260  import ()
   261  
   262  // comment
   263  type T time.Time
   264  `,
   265  		out: `package main
   266  
   267  import "time"
   268  
   269  // comment
   270  type T time.Time
   271  `,
   272  	},
   273  	{
   274  		name: "issue 8729 comment on package line",
   275  		pkg:  "time",
   276  		in: `package main // comment
   277  
   278  type T time.Time
   279  `,
   280  		out: `package main // comment
   281  import "time"
   282  
   283  type T time.Time
   284  `,
   285  	},
   286  	{
   287  		name: "issue 8729 comment after package",
   288  		pkg:  "time",
   289  		in: `package main
   290  // comment
   291  
   292  type T time.Time
   293  `,
   294  		out: `package main
   295  
   296  import "time"
   297  
   298  // comment
   299  
   300  type T time.Time
   301  `,
   302  	},
   303  	{
   304  		name: "issue 8729 comment before and on package line",
   305  		pkg:  "time",
   306  		in: `// comment before
   307  package main // comment on
   308  
   309  type T time.Time
   310  `,
   311  		out: `// comment before
   312  package main // comment on
   313  import "time"
   314  
   315  type T time.Time
   316  `,
   317  	},
   318  
   319  	// Issue 9961: Match prefixes using path segments rather than bytes
   320  	{
   321  		name: "issue 9961",
   322  		pkg:  "regexp",
   323  		in: `package main
   324  
   325  import (
   326  	"flag"
   327  	"testing"
   328  
   329  	"rsc.io/p"
   330  )
   331  `,
   332  		out: `package main
   333  
   334  import (
   335  	"flag"
   336  	"regexp"
   337  	"testing"
   338  
   339  	"rsc.io/p"
   340  )
   341  `,
   342  	},
   343  }
   344  
   345  func TestAddImport(t *testing.T) {
   346  	for _, test := range addTests {
   347  		file := parse(t, test.name, test.in)
   348  		var before bytes.Buffer
   349  		ast.Fprint(&before, fset, file, nil)
   350  		AddNamedImport(fset, file, test.renamedPkg, test.pkg)
   351  		if got := print(t, test.name, file); got != test.out {
   352  			if test.broken {
   353  				t.Logf("%s is known broken:\ngot: %s\nwant: %s", test.name, got, test.out)
   354  			} else {
   355  				t.Errorf("%s:\ngot: %s\nwant: %s", test.name, got, test.out)
   356  			}
   357  			var after bytes.Buffer
   358  			ast.Fprint(&after, fset, file, nil)
   359  
   360  			t.Logf("AST before:\n%s\nAST after:\n%s\n", before.String(), after.String())
   361  		}
   362  	}
   363  }
   364  
   365  func TestDoubleAddImport(t *testing.T) {
   366  	file := parse(t, "doubleimport", "package main\n")
   367  	AddImport(fset, file, "os")
   368  	AddImport(fset, file, "bytes")
   369  	want := `package main
   370  
   371  import (
   372  	"bytes"
   373  	"os"
   374  )
   375  `
   376  	if got := print(t, "doubleimport", file); got != want {
   377  		t.Errorf("got: %s\nwant: %s", got, want)
   378  	}
   379  }
   380  
   381  // Part of issue 8729.
   382  func TestDoubleAddImportWithDeclComment(t *testing.T) {
   383  	file := parse(t, "doubleimport", `package main
   384  
   385  import (
   386  )
   387  
   388  // comment
   389  type I int
   390  `)
   391  	// The AddImport order here matters.
   392  	AddImport(fset, file, "llvm.org/llgo/third_party/gotools/go/ast/astutil")
   393  	AddImport(fset, file, "os")
   394  	want := `package main
   395  
   396  import (
   397  	"llvm.org/llgo/third_party/gotools/go/ast/astutil"
   398  	"os"
   399  )
   400  
   401  // comment
   402  type I int
   403  `
   404  	if got := print(t, "doubleimport_with_decl_comment", file); got != want {
   405  		t.Errorf("got: %s\nwant: %s", got, want)
   406  	}
   407  }
   408  
// deleteTests lists the cases for TestDeleteImport. Each case parses in, calls
// DeleteImport for pkg, and expects the formatted file to equal out. When the
// last import is removed, the whole import declaration disappears.
var deleteTests = []test{
	{
		name: "import.4",
		pkg:  "os",
		in: `package main

import (
	"os"
)
`,
		out: `package main
`,
	},
	{
		name: "import.5",
		pkg:  "os",
		in: `package main

// Comment
import "C"
import "os"
`,
		out: `package main

// Comment
import "C"
`,
	},
	{
		name: "import.6",
		pkg:  "os",
		in: `package main

// Comment
import "C"

import (
	"io"
	"os"
	"utf8"
)
`,
		out: `package main

// Comment
import "C"

import (
	"io"
	"utf8"
)
`,
	},
	// import.7 through import.9: the trailing comment of a deleted import is
	// kept on a line of its own.
	{
		name: "import.7",
		pkg:  "io",
		in: `package main

import (
	"io"   // a
	"os"   // b
	"utf8" // c
)
`,
		out: `package main

import (
	// a
	"os"   // b
	"utf8" // c
)
`,
	},
	{
		name: "import.8",
		pkg:  "os",
		in: `package main

import (
	"io"   // a
	"os"   // b
	"utf8" // c
)
`,
		out: `package main

import (
	"io" // a
	// b
	"utf8" // c
)
`,
	},
	{
		name: "import.9",
		pkg:  "utf8",
		in: `package main

import (
	"io"   // a
	"os"   // b
	"utf8" // c
)
`,
		out: `package main

import (
	"io" // a
	"os" // b
	// c
)
`,
	},
	{
		name: "import.10",
		pkg:  "io",
		in: `package main

import (
	"io"
	"os"
	"utf8"
)
`,
		out: `package main

import (
	"os"
	"utf8"
)
`,
	},
	{
		name: "import.11",
		pkg:  "os",
		in: `package main

import (
	"io"
	"os"
	"utf8"
)
`,
		out: `package main

import (
	"io"
	"utf8"
)
`,
	},
	{
		name: "import.12",
		pkg:  "utf8",
		in: `package main

import (
	"io"
	"os"
	"utf8"
)
`,
		out: `package main

import (
	"io"
	"os"
)
`,
	},
	// The import path here is a back-quoted (raw) string literal.
	{
		name: "handle.raw.quote.imports",
		pkg:  "os",
		in:   "package main\n\nimport `os`",
		out: `package main
`,
	},
	{
		name: "import.13",
		pkg:  "io",
		in: `package main

import (
	"fmt"

	"io"
	"os"
	"utf8"

	"go/format"
)
`,
		out: `package main

import (
	"fmt"

	"os"
	"utf8"

	"go/format"
)
`,
	},
	{
		name: "import.14",
		pkg:  "io",
		in: `package main

import (
	"fmt" // a

	"io"   // b
	"os"   // c
	"utf8" // d

	"go/format" // e
)
`,
		out: `package main

import (
	"fmt" // a

	// b
	"os"   // c
	"utf8" // d

	"go/format" // e
)
`,
	},
	// import.15 through import.17: every duplicate of pkg is removed, across
	// all import declarations.
	{
		name: "import.15",
		pkg:  "double",
		in: `package main

import (
	"double"
	"double"
)
`,
		out: `package main
`,
	},
	{
		name: "import.16",
		pkg:  "bubble",
		in: `package main

import (
	"toil"
	"bubble"
	"bubble"
	"trouble"
)
`,
		out: `package main

import (
	"toil"
	"trouble"
)
`,
	},
	{
		name: "import.17",
		pkg:  "quad",
		in: `package main

import (
	"quad"
	"quad"
)

import (
	"quad"
	"quad"
)
`,
		out: `package main
`,
	},
}
   693  
   694  func TestDeleteImport(t *testing.T) {
   695  	for _, test := range deleteTests {
   696  		file := parse(t, test.name, test.in)
   697  		DeleteImport(fset, file, test.pkg)
   698  		if got := print(t, test.name, file); got != test.out {
   699  			t.Errorf("%s:\ngot: %s\nwant: %s", test.name, got, test.out)
   700  		}
   701  	}
   702  }
   703  
   704  type rewriteTest struct {
   705  	name   string
   706  	srcPkg string
   707  	dstPkg string
   708  	in     string
   709  	out    string
   710  }
   711  
// rewriteTests lists the cases for TestRewriteImport. Each case parses in,
// rewrites imports of srcPkg to dstPkg, and expects the formatted file to
// equal out. In the expected outputs the rewritten path is re-sorted within
// its group, and its trailing comment moves with it.
var rewriteTests = []rewriteTest{
	{
		name:   "import.13",
		srcPkg: "utf8",
		dstPkg: "encoding/utf8",
		in: `package main

import (
	"io"
	"os"
	"utf8" // thanks ken
)
`,
		out: `package main

import (
	"encoding/utf8" // thanks ken
	"io"
	"os"
)
`,
	},
	{
		name:   "import.14",
		srcPkg: "asn1",
		dstPkg: "encoding/asn1",
		in: `package main

import (
	"asn1"
	"crypto"
	"crypto/rsa"
	_ "crypto/sha1"
	"crypto/x509"
	"crypto/x509/pkix"
	"time"
)

var x = 1
`,
		out: `package main

import (
	"crypto"
	"crypto/rsa"
	_ "crypto/sha1"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/asn1"
	"time"
)

var x = 1
`,
	},
	// A comment on the declaration after the import block must stay there.
	{
		name:   "import.15",
		srcPkg: "url",
		dstPkg: "net/url",
		in: `package main

import (
	"bufio"
	"net"
	"path"
	"url"
)

var x = 1 // comment on x, not on url
`,
		out: `package main

import (
	"bufio"
	"net"
	"net/url"
	"path"
)

var x = 1 // comment on x, not on url
`,
	},
	{
		name:   "import.16",
		srcPkg: "http",
		dstPkg: "net/http",
		in: `package main

import (
	"flag"
	"http"
	"log"
	"text/template"
)

var addr = flag.String("addr", ":1718", "http service address") // Q=17, R=18
`,
		out: `package main

import (
	"flag"
	"log"
	"net/http"
	"text/template"
)

var addr = flag.String("addr", ":1718", "http service address") // Q=17, R=18
`,
	},
}
   822  
   823  func TestRewriteImport(t *testing.T) {
   824  	for _, test := range rewriteTests {
   825  		file := parse(t, test.name, test.in)
   826  		RewriteImport(fset, file, test.srcPkg, test.dstPkg)
   827  		if got := print(t, test.name, file); got != test.out {
   828  			t.Errorf("%s:\ngot: %s\nwant: %s", test.name, got, test.out)
   829  		}
   830  	}
   831  }
   832  
// importsTests lists the cases for TestImports. want holds, for each group
// returned by Imports, the unquoted import paths in order. In these cases both
// a blank line inside a declaration and a new import declaration start a new
// group. A file with no imports yields nil.
var importsTests = []struct {
	name string
	in   string
	want [][]string
}{
	{
		name: "no packages",
		in: `package foo
`,
		want: nil,
	},
	{
		name: "one group",
		in: `package foo

import (
	"fmt"
	"testing"
)
`,
		want: [][]string{{"fmt", "testing"}},
	},
	{
		name: "four groups",
		in: `package foo

import "C"
import (
	"fmt"
	"testing"

	"appengine"

	"myproject/mylib1"
	"myproject/mylib2"
)
`,
		want: [][]string{
			{"C"},
			{"fmt", "testing"},
			{"appengine"},
			{"myproject/mylib1", "myproject/mylib2"},
		},
	},
	{
		name: "multiple factored groups",
		in: `package foo

import (
	"fmt"
	"testing"

	"appengine"
)
import (
	"reflect"

	"bytes"
)
`,
		want: [][]string{
			{"fmt", "testing"},
			{"appengine"},
			{"reflect"},
			{"bytes"},
		},
	},
}
   901  
   902  func unquote(s string) string {
   903  	res, err := strconv.Unquote(s)
   904  	if err != nil {
   905  		return "could_not_unquote"
   906  	}
   907  	return res
   908  }
   909  
   910  func TestImports(t *testing.T) {
   911  	fset := token.NewFileSet()
   912  	for _, test := range importsTests {
   913  		f, err := parser.ParseFile(fset, "test.go", test.in, 0)
   914  		if err != nil {
   915  			t.Errorf("%s: %v", test.name, err)
   916  			continue
   917  		}
   918  		var got [][]string
   919  		for _, group := range Imports(fset, f) {
   920  			var b []string
   921  			for _, spec := range group {
   922  				b = append(b, unquote(spec.Path.Value))
   923  			}
   924  			got = append(got, b)
   925  		}
   926  		if !reflect.DeepEqual(got, test.want) {
   927  			t.Errorf("Imports(%s)=%v, want %v", test.name, got, test.want)
   928  		}
   929  	}
   930  }