github.com/lixvbnet/courtney@v0.0.0-20221025031132-0dcb02231211/scanner/scanner_test.go (about)

     1  package scanner_test
     2  
     3  import (
     4  	"regexp"
     5  	"strconv"
     6  	"strings"
     7  	"testing"
     8  
     9  	"path/filepath"
    10  
    11  	"github.com/lixvbnet/courtney/scanner"
    12  	"github.com/lixvbnet/courtney/shared"
    13  	"github.com/dave/patsy"
    14  	"github.com/dave/patsy/builder"
    15  	"github.com/dave/patsy/vos"
    16  )
    17  
    18  func TestSingle(t *testing.T) {
    19  	tests := map[string]string{
    20  		"single": `package a
    21  
    22  			func wrap(error) error
    23  			
    24  			func a() error {
    25  				var a bool
    26  				var err error
    27  				if err != nil {
    28  					if a { // this line will not be excluded!
    29  						return wrap(err) // *
    30  					}
    31  					return wrap(err) // *
    32  				}
    33  				return nil
    34  			}
    35  		`,
    36  	}
    37  	test(t, tests)
    38  }
    39  
    40  func TestSwitchCase(t *testing.T) {
    41  	tests := map[string]string{
    42  		"simple switch": `package a
    43  			
    44  			func a() error {
    45  				var err error
    46  				switch {
    47  				case err != nil:
    48  					return err // *
    49  				}
    50  				return nil
    51  			}
    52  		`,
    53  		"switch multi": `package a
    54  			
    55  			func a() error {
    56  				var a bool
    57  				var err error
    58  				switch {
    59  				case err == nil, a:
    60  					return err
    61  				default: 
    62  					return err // *
    63  				}
    64  				return nil
    65  			}
    66  		`,
    67  		"simple switch ignored": `package a
    68  			
    69  			func a() error {
    70  				var a bool
    71  				var err error
    72  				switch a {
    73  				case err != nil:
    74  					return err
    75  				}
    76  				return nil
    77  			}
    78  		`,
    79  		"complex switch": `package a
    80  		
    81  			func foo() error {
    82  				var err error
    83  				var b, c bool
    84  				var d int
    85  				switch {
    86  				case err == nil && (b && d > 0) || c:
    87  					return err
    88  				case d <= 0 || c:
    89  					return err
    90  				case b:
    91  					return err // *
    92  				}
    93  				return err
    94  			}
    95  		`,
    96  	}
    97  	test(t, tests)
    98  }
    99  
   100  func TestNamedParameters(t *testing.T) {
   101  	tests := map[string]string{
   102  		"named parameters simple": `package a
   103  			
   104  			func a() (err error) {
   105  				if err != nil {
   106  					return // *
   107  				}
   108  				return
   109  			}
   110  		`,
   111  		"named parameters ignored": `package a
   112  			
   113  			func a() {
   114  				var err error
   115  				if err != nil {
   116  					return
   117  				}
   118  				return
   119  			}
   120  		`,
   121  		"named parameters 2": `package a
   122  			
   123  			func a() (i int, err error) {
   124  				i = 1
   125  				if err != nil {
   126  					return // *
   127  				}
   128  				return
   129  			}
   130  		`,
   131  		"named parameters must be last": `package a
   132  			
   133  			func a() (err error, i int) {
   134  				i = 1
   135  				if err != nil {
   136  					return
   137  				}
   138  				return
   139  			}
   140  		`,
   141  		"named parameters must be not nil": `package a
   142  			
   143  			func a() (err error) {
   144  				return
   145  			}
   146  		`,
   147  		"named parameters func lit": `package a
   148  			
   149  			func a() {
   150  				func () (err error) {
   151  					if err != nil {
   152  						return // *
   153  					}
   154  					return
   155  				}()
   156  			}
   157  		`,
   158  	}
   159  	test(t, tests)
   160  }
   161  
   162  func TestBool(t *testing.T) {
   163  	tests := map[string]string{
   164  		"wrap1": `package a
   165  			
   166  			func a() error {
   167  				var wrap func(error) error
   168  				var err error
   169  				if err != nil {
   170  					return wrap(err) // *
   171  				}
   172  				return nil
   173  			}
   174  			`,
   175  		"wrap ignored": `package a
   176  			
   177  			func a() int {
   178  				var wrap func(error) int
   179  				var err error
   180  				if err != nil {
   181  					return wrap(err)
   182  				}
   183  				return 0
   184  			}
   185  			`,
   186  		"wrap2": `package a
   187  			
   188  			func a() error {
   189  				var wrap func(error) error
   190  				var err error
   191  				if err != nil {
   192  					w := wrap(err)
   193  					return w // *
   194  				}
   195  				return nil
   196  			}
   197  			`,
   198  		"wrap3": `package a
   199  			
   200  			func a() error {
   201  				var wrap func(error) error
   202  				var err error
   203  				var w error
   204  				if err != nil {
   205  					w = wrap(err)
   206  					return w // *
   207  				}
   208  				return nil
   209  			}
   210  			`,
   211  		"wrap4": `package a
   212  			
   213  			func a() error {
   214  				var wrap func(error) error
   215  				var err error
   216  				if err != nil {
   217  					var w = wrap(err)
   218  					return w // *
   219  				}
   220  				return nil
   221  			}
   222  			`,
   223  		"wrap5": `package a
   224  			
   225  			func a() error {
   226  				var wrap func(error) error
   227  				var err error
   228  				if err != nil {
   229  					var w error = wrap(err)
   230  					return w // *
   231  				}
   232  				return nil
   233  			}
   234  			`,
   235  		"wrap no tuple": `package a
   236  			
   237  			func a() (int, error) {
   238  				var wrap func(error) (int, error)
   239  				var err error
   240  				if err != nil {
   241  					return wrap(err)
   242  				}
   243  				return 0, nil
   244  			}
   245  		`,
   246  		"logical and first": `package a
   247  			
   248  			import "fmt"
   249  			
   250  			func a() error {
   251  				_, err := fmt.Println()
   252  				if err != nil && 1 == 1 {
   253  					return err // *
   254  				}
   255  				return nil
   256  			}
   257  			`,
   258  		"logical and second": `package a
   259  			
   260  			import "fmt"
   261  			
   262  			func a() error {
   263  				_, err := fmt.Println()
   264  				if 1 == 1 && err != nil {
   265  					return err // *
   266  				}
   267  				return nil
   268  			}
   269  			`,
   270  		"logical and third": `package a
   271  			
   272  			import "fmt"
   273  			
   274  			func a() error {
   275  				_, err := fmt.Println()
   276  				if 1 == 1 && 2 == 2 && err != nil {
   277  					return err // *
   278  				}
   279  				return nil
   280  			}
   281  			`,
   282  		"logical and brackets": `package a
   283  			
   284  			import "fmt"
   285  			
   286  			func a() error {
   287  				_, err := fmt.Println()
   288  				if 1 == 1 && (2 == 2 && err != nil) {
   289  					return err // *
   290  				}
   291  				return nil
   292  			}
   293  			`,
   294  		"logical or first": `package a
   295  			
   296  			import "fmt"
   297  			
   298  			func a() error {
   299  				_, err := fmt.Println()
   300  				if err == nil || 1 == 1 {
   301  					return err
   302  				} else {
   303  					return err // *
   304  				}
   305  				return nil
   306  			}
   307  			`,
   308  		"logical or second": `package a
   309  			
   310  			import "fmt"
   311  			
   312  			func a() error {
   313  				_, err := fmt.Println()
   314  				if 1 == 1 || err == nil {
   315  					return err
   316  				} else {
   317  					return err // *
   318  				}
   319  				return nil
   320  			}
   321  			`,
   322  		"logical or third": `package a
   323  			
   324  			import "fmt"
   325  			
   326  			func a() error {
   327  				_, err := fmt.Println()
   328  				if 1 == 1 || 2 == 2 || err == nil {
   329  					return err
   330  				} else {
   331  					return err // *
   332  				}
   333  				return nil
   334  			}
   335  			`,
   336  		"logical or brackets": `package a
   337  			
   338  			import "fmt"
   339  			
   340  			func a() error {
   341  				_, err := fmt.Println()
   342  				if 1 == 1 || (2 == 2 || err == nil) {
   343  					return err
   344  				} else {
   345  					return err // *
   346  				}
   347  				return nil
   348  			}
   349  			`,
   350  		"complex": `package a
   351  		
   352  			func foo() error {
   353  				var err error
   354  				var b, c bool
   355  				var d int
   356  				if err == nil && (b && d > 0) || c {
   357  					return err
   358  				} else if d <= 0 || c {
   359  					return err
   360  				} else if b {
   361  					return err // *
   362  				}
   363  				return err
   364  			}
   365  		`,
   366  	}
   367  	test(t, tests)
   368  }
   369  
   370  func TestGeneral(t *testing.T) {
   371  	tests := map[string]string{
   372  		"simple": `package a
   373  			
   374  			import "fmt"
   375  			
   376  			func a() error {
   377  				_, err := fmt.Println()
   378  				if err != nil {
   379  					return err // *
   380  				}
   381  				return nil
   382  			}
   383  			`,
   384  		"wrong way round": `package a
   385  			
   386  			import "fmt"
   387  			
   388  			func a() error {
   389  				_, err := fmt.Println()
   390  				if nil != err {
   391  					return err // *
   392  				}
   393  				return nil
   394  			}
   395  			`,
   396  		"not else block": `package a
   397  			
   398  			import "fmt"
   399  			
   400  			func a() error {
   401  				_, err := fmt.Println()
   402  				if err != nil {
   403  					return err // *
   404  				} else {
   405  					return err
   406  				}
   407  				return nil
   408  			}
   409  			`,
   410  		"any name": `package a
   411  			
   412  			import "fmt"
   413  			
   414  			func a() error {
   415  				_, foo := fmt.Println()
   416  				if foo != nil {
   417  					return foo // *
   418  				}
   419  				return nil
   420  			}
   421  			`,
   422  		"don't mark if ==": `package a
   423  			
   424  			import "fmt"
   425  			
   426  			func a() error {
   427  				_, err := fmt.Println()
   428  				if err == nil {
   429  					return err
   430  				}
   431  				return nil
   432  			}
   433  			`,
   434  		"use else block if err == nil": `package a
   435  			
   436  			import "fmt"
   437  			
   438  			func a() error {
   439  				_, err := fmt.Println()
   440  				if err == nil {
   441  					return err
   442  				} else {
   443  					return err // *
   444  				}
   445  				return nil
   446  			}
   447  			`,
   448  		"support if with init form": `package a
   449  			
   450  			import "fmt"
   451  			
   452  			func a() error {
   453  				if _, err := fmt.Println(); err != nil {
   454  					return err // *
   455  				}
   456  				return nil
   457  			}
   458  			`,
   459  		"only in if block": `package foo
   460  			
   461  			import "fmt"
   462  			
   463  			func Baz() error {
   464  				return fmt.Errorf("foo")
   465  			}
   466  			`,
   467  	}
   468  	test(t, tests)
   469  }
   470  
   471  func TestZeroValues(t *testing.T) {
   472  	tests := map[string]string{
   473  		"only return if all other return vars are zero": `package a
   474  			
   475  			import "fmt"
   476  			
   477  			type iface interface{}
   478  			
   479  			type strct struct {
   480  				a int
   481  				b string
   482  			}
   483  			
   484  			func Foo() (iface, bool, int, string, float32, strct, strct, error) {
   485  				if _, err := fmt.Println(); err != nil {
   486  					return 1, false, 0, "", 0.0, strct{0, ""}, strct{a: 0, b: ""}, err
   487  				}
   488  				if _, err := fmt.Println(); err != nil {
   489  					return nil, true, 0, "", 0.0, strct{0, ""}, strct{a: 0, b: ""}, err
   490  				}
   491  				if _, err := fmt.Println(); err != nil {
   492  					return nil, false, 1, "", 0.0, strct{0, ""}, strct{a: 0, b: ""}, err
   493  				}
   494  				if _, err := fmt.Println(); err != nil {
   495  					return nil, false, 0, "a", 0.0, strct{0, ""}, strct{a: 0, b: ""}, err
   496  				}
   497  				if _, err := fmt.Println(); err != nil {
   498  					return nil, false, 0, "", 1.0, strct{0, ""}, strct{a: 0, b: ""}, err
   499  				}
   500  				if _, err := fmt.Println(); err != nil {
   501  					return nil, false, 0, "", 0.0, strct{1, ""}, strct{a: 0, b: ""}, err
   502  				}
   503  				if _, err := fmt.Println(); err != nil {
   504  					return nil, false, 0, "", 0.0, strct{0, "a"}, strct{a: 0, b: ""}, err
   505  				}
   506  				if _, err := fmt.Println(); err != nil {
   507  					return nil, false, 0, "", 0.0, strct{0, ""}, strct{a: 1, b: ""}, err
   508  				}
   509  				if _, err := fmt.Println(); err != nil {
   510  					return nil, false, 0, "", 0.0, strct{0, ""}, strct{a: 0, b: "a"}, err
   511  				}
   512  				if _, err := fmt.Println(); err != nil {
   513  					return nil, false, 0, "", 0.0, strct{0, ""}, strct{a: 0, b: ""}, err // *
   514  				}
   515  				return nil, false, 0, "", 0.0, strct{0, ""}, strct{a: 0, b: ""}, nil
   516  			}
   517  			`,
   518  	}
   519  	test(t, tests)
   520  }
   521  
   522  func TestSelectorExpressions(t *testing.T) {
   523  	tests := map[string]string{
   524  		"selector expression": `package foo
   525  			
   526  			func Baz() error { 
   527  				type T struct {
   528  					Err error
   529  				}
   530  				var b T
   531  				if b.Err != nil {   
   532  					return b.Err // *
   533  				}
   534  				return nil
   535  			}
   536  			`,
   537  	}
   538  	test(t, tests)
   539  }
   540  
   541  func TestFunctionExpressions(t *testing.T) {
   542  	tests := map[string]string{
   543  		"function expression": `package foo
   544  			
   545  			func Baz() error { 
   546  				var f func(int) error
   547  				if f(5) != nil {   
   548  					return f(5) // *
   549  				}
   550  				return nil
   551  			}
   552  			`,
   553  		"function expression params": `package foo
   554  			
   555  			func Baz() error { 
   556  				var f func(int) error
   557  				if f(4) != nil {   
   558  					return f(5)
   559  				}
   560  				return nil
   561  			}
   562  			`,
   563  		"function expression params 2": `package foo
   564  			
   565  			func Baz() error { 
   566  				var f func(...int) error
   567  				if f(4) != nil {   
   568  					return f(4, 4)
   569  				}
   570  				return nil
   571  			}
   572  			`,
   573  		"function expression elipsis": `package foo
   574  			
   575  			func Baz() error { 
   576  				var f func(...interface{}) error
   577  				var a []interface{}
   578  				if f(a) != nil {   
   579  					return f(a...)
   580  				}
   581  				return nil
   582  			}
   583  			`,
   584  		"function expression elipsis 2": `package foo
   585  			
   586  			func Baz() error { 
   587  				var f func(...interface{}) error
   588  				var a []interface{}
   589  				if f(a) != nil {   
   590  					return f(a) // *
   591  				}
   592  				return nil
   593  			}
   594  			`,
   595  	}
   596  	test(t, tests)
   597  }
   598  
   599  func TestPanic(t *testing.T) {
   600  	tests := map[string]string{
   601  		"panic": `package foo
   602  			
   603  			func Baz() error {
   604  				panic("") // *
   605  			}
   606  			`,
   607  	}
   608  	test(t, tests)
   609  }
   610  
   611  func TestComments(t *testing.T) {
   612  	tests := map[string]string{
   613  		"scope": `package foo
   614  			
   615  			func Baz() int { 
   616  				i := 1       
   617  				if i > 1 {   
   618  					return i 
   619  				}            
   620  				             
   621  				//notest
   622  				             // *
   623  				if i > 2 {   // *
   624  					return i // *
   625  				}            // *
   626  				return 0     // *
   627  			}
   628  			`,
   629  		"scope if": `package foo
   630  			
   631  			func Baz(i int) int { 
   632  				if i > 2 {
   633  					//notest
   634  					return i // *
   635  				}
   636  				return 0
   637  			}
   638  			`,
   639  		"scope file": `package foo
   640  			
   641  			//notest
   642  			                      // *
   643  			func Baz(i int) int { // *
   644  				if i > 2 {        // *
   645  					return i      // *
   646  				}                 // *
   647  				return 0          // *
   648  			}                     // *
   649  			                      // *
   650  			func Foo(i int) int { // *
   651  				return 0          // *
   652  			}
   653  			`,
   654  		"complex comments": `package foo
   655  			
   656  			type Logger struct {
   657  				Enabled bool
   658  			}
   659  			func (l Logger) Print(i ...interface{}) {}
   660  			
   661  			func Foo() {
   662  				var logger Logger
   663  				var tokens []interface{}
   664  				if logger.Enabled {
   665  					// notest
   666  					for i, token := range tokens {        // *
   667  						logger.Print("[", i, "] ", token) // *
   668  					}                                     // *
   669  				}
   670  			}
   671  			`,
   672  		"case block": `package foo
   673  			
   674  			func Foo() bool {
   675  				switch {
   676  				case true:
   677  					// notest
   678  					if true {       // *
   679  						return true // *
   680  					}               // *
   681  					return false    // *
   682  				}
   683  				return false
   684  			}
   685  			`,
   686  		"case block with explanation comment": `package foo
   687  			
   688  			func Foo() bool {
   689  				switch {
   690  				case true:
   691  					// notest // this condition is always true
   692  					if true {       // *
   693  						return true // *
   694  					}               // *
   695  					return false    // *
   696  				}
   697  				return false
   698  			}
   699  			`,
   700  	}
   701  	test(t, tests)
   702  }
   703  
   704  func test(t *testing.T, tests map[string]string) {
   705  	for name, source := range tests {
   706  		env := vos.Mock()
   707  		b, err := builder.New(env, "ns", true)
   708  		if err != nil {
   709  			t.Fatalf("Error creating builder in %s: %+v", name, err)
   710  		}
   711  		defer b.Cleanup()
   712  
   713  		ppath, pdir, err := b.Package("a", map[string]string{
   714  			"a.go": source,
   715  		})
   716  		if err != nil {
   717  			t.Fatalf("Error creating package in %s: %+v", name, err)
   718  		}
   719  
   720  		paths := patsy.NewCache(env)
   721  		setup := &shared.Setup{
   722  			Env:   env,
   723  			Paths: paths,
   724  		}
   725  		if err := setup.Parse([]string{ppath}); err != nil {
   726  			t.Fatalf("Error parsing args in %s: %+v", name, err)
   727  		}
   728  
   729  		cm := scanner.New(setup)
   730  
   731  		if err := cm.LoadProgram(); err != nil {
   732  			t.Fatalf("Error loading program in %s: %+v", name, err)
   733  		}
   734  
   735  		if err := cm.ScanPackages(); err != nil {
   736  			t.Fatalf("Error scanning packages in %s: %+v", name, err)
   737  		}
   738  
   739  		result := cm.Excludes[filepath.Join(pdir, "a.go")]
   740  
   741  		// matches strings like:
   742  		//   - //notest$
   743  		//   - // notest$
   744  		//   - //notest // because this is glue code$
   745  		//   - // notest // because this is glue code$
   746  		notest := regexp.MustCompile("//\\s?notest(\\s//\\s?.*)?$")
   747  
   748  		for i, line := range strings.Split(source, "\n") {
   749  			expected := strings.HasSuffix(line, "// *") || notest.MatchString(line)
   750  			if result[i+1] != expected {
   751  				t.Fatalf("Unexpected state in %s, line %d: %s\n", name, i, strconv.Quote(strings.Trim(line, "\t")))
   752  			}
   753  		}
   754  	}
   755  }