github.com/bgentry/go@v0.0.0-20150121062915-6cf5a733d54d/src/cmd/fix/import_test.go (about)

     1  // Copyright 2011 The Go Authors.  All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import "go/ast"
     8  
        // init passes importTests to addTestCases (defined elsewhere in the
        // package) with nil as the second argument. Presumably this registers
        // the cases with the package's shared test runner; confirm against
        // addTestCases.
     9  func init() {
    10  	addTestCases(importTests, nil)
    11  }
    12  
        // importTests is the table of cases exercising the import helpers
        // (addImport, deleteImport, rewriteImport) through the *ImportFn
        // constructors defined at the bottom of this file. Each entry sets a
        // Name, a Fn to apply to the parsed file, the In source text, and the Out
        // source text; Out is presumably what the runner expects after applying
        // Fn (testCase and the runner are defined elsewhere).
    13  var importTests = []testCase{
        	// import.0 through import.3: addImportFn("os") against files with
        	// differing existing import layouts.
    14  	{
    15  		Name: "import.0",
    16  		Fn:   addImportFn("os"),
    17  		In: `package main
    18  
    19  import (
    20  	"os"
    21  )
    22  `,
    23  		Out: `package main
    24  
    25  import (
    26  	"os"
    27  )
    28  `,
    29  	},
    30  	{
    31  		Name: "import.1",
    32  		Fn:   addImportFn("os"),
    33  		In: `package main
    34  `,
    35  		Out: `package main
    36  
    37  import "os"
    38  `,
    39  	},
    40  	{
    41  		Name: "import.2",
    42  		Fn:   addImportFn("os"),
    43  		In: `package main
    44  
    45  // Comment
    46  import "C"
    47  `,
    48  		Out: `package main
    49  
    50  // Comment
    51  import "C"
    52  import "os"
    53  `,
    54  	},
    55  	{
    56  		Name: "import.3",
    57  		Fn:   addImportFn("os"),
    58  		In: `package main
    59  
    60  // Comment
    61  import "C"
    62  
    63  import (
    64  	"io"
    65  	"utf8"
    66  )
    67  `,
    68  		Out: `package main
    69  
    70  // Comment
    71  import "C"
    72  
    73  import (
    74  	"io"
    75  	"os"
    76  	"utf8"
    77  )
    78  `,
    79  	},
        	// import.4 through import.12: deleteImportFn removing a single path.
        	// import.7 through import.9 carry trailing comments on every import
        	// line; their Out text keeps the comment of the removed import.
    80  	{
    81  		Name: "import.4",
    82  		Fn:   deleteImportFn("os"),
    83  		In: `package main
    84  
    85  import (
    86  	"os"
    87  )
    88  `,
    89  		Out: `package main
    90  `,
    91  	},
    92  	{
    93  		Name: "import.5",
    94  		Fn:   deleteImportFn("os"),
    95  		In: `package main
    96  
    97  // Comment
    98  import "C"
    99  import "os"
   100  `,
   101  		Out: `package main
   102  
   103  // Comment
   104  import "C"
   105  `,
   106  	},
   107  	{
   108  		Name: "import.6",
   109  		Fn:   deleteImportFn("os"),
   110  		In: `package main
   111  
   112  // Comment
   113  import "C"
   114  
   115  import (
   116  	"io"
   117  	"os"
   118  	"utf8"
   119  )
   120  `,
   121  		Out: `package main
   122  
   123  // Comment
   124  import "C"
   125  
   126  import (
   127  	"io"
   128  	"utf8"
   129  )
   130  `,
   131  	},
   132  	{
   133  		Name: "import.7",
   134  		Fn:   deleteImportFn("io"),
   135  		In: `package main
   136  
   137  import (
   138  	"io"   // a
   139  	"os"   // b
   140  	"utf8" // c
   141  )
   142  `,
   143  		Out: `package main
   144  
   145  import (
   146  	// a
   147  	"os"   // b
   148  	"utf8" // c
   149  )
   150  `,
   151  	},
   152  	{
   153  		Name: "import.8",
   154  		Fn:   deleteImportFn("os"),
   155  		In: `package main
   156  
   157  import (
   158  	"io"   // a
   159  	"os"   // b
   160  	"utf8" // c
   161  )
   162  `,
   163  		Out: `package main
   164  
   165  import (
   166  	"io" // a
   167  	// b
   168  	"utf8" // c
   169  )
   170  `,
   171  	},
   172  	{
   173  		Name: "import.9",
   174  		Fn:   deleteImportFn("utf8"),
   175  		In: `package main
   176  
   177  import (
   178  	"io"   // a
   179  	"os"   // b
   180  	"utf8" // c
   181  )
   182  `,
   183  		Out: `package main
   184  
   185  import (
   186  	"io" // a
   187  	"os" // b
   188  	// c
   189  )
   190  `,
   191  	},
   192  	{
   193  		Name: "import.10",
   194  		Fn:   deleteImportFn("io"),
   195  		In: `package main
   196  
   197  import (
   198  	"io"
   199  	"os"
   200  	"utf8"
   201  )
   202  `,
   203  		Out: `package main
   204  
   205  import (
   206  	"os"
   207  	"utf8"
   208  )
   209  `,
   210  	},
   211  	{
   212  		Name: "import.11",
   213  		Fn:   deleteImportFn("os"),
   214  		In: `package main
   215  
   216  import (
   217  	"io"
   218  	"os"
   219  	"utf8"
   220  )
   221  `,
   222  		Out: `package main
   223  
   224  import (
   225  	"io"
   226  	"utf8"
   227  )
   228  `,
   229  	},
   230  	{
   231  		Name: "import.12",
   232  		Fn:   deleteImportFn("utf8"),
   233  		In: `package main
   234  
   235  import (
   236  	"io"
   237  	"os"
   238  	"utf8"
   239  )
   240  `,
   241  		Out: `package main
   242  
   243  import (
   244  	"io"
   245  	"os"
   246  )
   247  `,
   248  	},
        	// import.13 through import.16: rewriteImportFn replacing old paths
        	// with new ones. In each Out text the rewritten path appears in
        	// sorted position within the import block, and comments outside the
        	// import block stay attached to their original declarations.
   249  	{
   250  		Name: "import.13",
   251  		Fn:   rewriteImportFn("utf8", "encoding/utf8"),
   252  		In: `package main
   253  
   254  import (
   255  	"io"
   256  	"os"
   257  	"utf8" // thanks ken
   258  )
   259  `,
   260  		Out: `package main
   261  
   262  import (
   263  	"encoding/utf8" // thanks ken
   264  	"io"
   265  	"os"
   266  )
   267  `,
   268  	},
   269  	{
   270  		Name: "import.14",
   271  		Fn:   rewriteImportFn("asn1", "encoding/asn1"),
   272  		In: `package main
   273  
   274  import (
   275  	"asn1"
   276  	"crypto"
   277  	"crypto/rsa"
   278  	_ "crypto/sha1"
   279  	"crypto/x509"
   280  	"crypto/x509/pkix"
   281  	"time"
   282  )
   283  
   284  var x = 1
   285  `,
   286  		Out: `package main
   287  
   288  import (
   289  	"crypto"
   290  	"crypto/rsa"
   291  	_ "crypto/sha1"
   292  	"crypto/x509"
   293  	"crypto/x509/pkix"
   294  	"encoding/asn1"
   295  	"time"
   296  )
   297  
   298  var x = 1
   299  `,
   300  	},
   301  	{
   302  		Name: "import.15",
   303  		Fn:   rewriteImportFn("url", "net/url"),
   304  		In: `package main
   305  
   306  import (
   307  	"bufio"
   308  	"net"
   309  	"path"
   310  	"url"
   311  )
   312  
   313  var x = 1 // comment on x, not on url
   314  `,
   315  		Out: `package main
   316  
   317  import (
   318  	"bufio"
   319  	"net"
   320  	"net/url"
   321  	"path"
   322  )
   323  
   324  var x = 1 // comment on x, not on url
   325  `,
   326  	},
   327  	{
   328  		Name: "import.16",
   329  		Fn:   rewriteImportFn("http", "net/http", "template", "text/template"),
   330  		In: `package main
   331  
   332  import (
   333  	"flag"
   334  	"http"
   335  	"log"
   336  	"template"
   337  )
   338  
   339  var addr = flag.String("addr", ":1718", "http service address") // Q=17, R=18
   340  `,
   341  		Out: `package main
   342  
   343  import (
   344  	"flag"
   345  	"log"
   346  	"net/http"
   347  	"text/template"
   348  )
   349  
   350  var addr = flag.String("addr", ":1718", "http service address") // Q=17, R=18
   351  `,
   352  	},
        	// import.17: addImportFn with two paths; the Out text places both in
        	// the blank-line-separated group that already holds "x/w".
   353  	{
   354  		Name: "import.17",
   355  		Fn:   addImportFn("x/y/z", "x/a/c"),
   356  		In: `package main
   357  
   358  // Comment
   359  import "C"
   360  
   361  import (
   362  	"a"
   363  	"b"
   364  
   365  	"x/w"
   366  
   367  	"d/f"
   368  )
   369  `,
   370  		Out: `package main
   371  
   372  // Comment
   373  import "C"
   374  
   375  import (
   376  	"a"
   377  	"b"
   378  
   379  	"x/a/c"
   380  	"x/w"
   381  	"x/y/z"
   382  
   383  	"d/f"
   384  )
   385  `,
   386  	},
        	// import.18: addDelImportFn adds "e" and deletes "o" in one fix.
   387  	{
   388  		Name: "import.18",
   389  		Fn:   addDelImportFn("e", "o"),
   390  		In: `package main
   391  
   392  import (
   393  	"f"
   394  	"o"
   395  	"z"
   396  )
   397  `,
   398  		Out: `package main
   399  
   400  import (
   401  	"e"
   402  	"f"
   403  	"z"
   404  )
   405  `,
   406  	},
   407  }
   408  
   409  func addImportFn(path ...string) func(*ast.File) bool {
   410  	return func(f *ast.File) bool {
   411  		fixed := false
   412  		for _, p := range path {
   413  			if !imports(f, p) {
   414  				addImport(f, p)
   415  				fixed = true
   416  			}
   417  		}
   418  		return fixed
   419  	}
   420  }
   421  
   422  func deleteImportFn(path string) func(*ast.File) bool {
   423  	return func(f *ast.File) bool {
   424  		if imports(f, path) {
   425  			deleteImport(f, path)
   426  			return true
   427  		}
   428  		return false
   429  	}
   430  }
   431  
   432  func addDelImportFn(p1 string, p2 string) func(*ast.File) bool {
   433  	return func(f *ast.File) bool {
   434  		fixed := false
   435  		if !imports(f, p1) {
   436  			addImport(f, p1)
   437  			fixed = true
   438  		}
   439  		if imports(f, p2) {
   440  			deleteImport(f, p2)
   441  			fixed = true
   442  		}
   443  		return fixed
   444  	}
   445  }
   446  
   447  func rewriteImportFn(oldnew ...string) func(*ast.File) bool {
   448  	return func(f *ast.File) bool {
   449  		fixed := false
   450  		for i := 0; i < len(oldnew); i += 2 {
   451  			if imports(f, oldnew[i]) {
   452  				rewriteImport(f, oldnew[i], oldnew[i+1])
   453  				fixed = true
   454  			}
   455  		}
   456  		return fixed
   457  	}
   458  }