github.com/0xKiwi/rules_go@v0.24.3/tests/core/nogo/custom/custom_test.go (about)

     1  // Copyright 2019 The Bazel Authors. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //    http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package custom_test
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"io/ioutil"
    21  	"regexp"
    22  	"testing"
    23  
    24  	"github.com/bazelbuild/rules_go/go/tools/bazel_testing"
    25  )
    26  
    27  const origConfig = `# config = "",`
    28  
// TestMain sets up the Bazel workspace shared by the tests in this package
// and hands control to bazel_testing.TestMain.
//
// Main is a multi-file archive: each "-- name --" header starts a new file.
// NOTE(review): presumably bazel_testing materializes these files as a
// scratch workspace before running m; confirm against bazel_testing.
//
// The workspace contains:
//   - BUILD.bazel: a nogo rule built from three custom analyzers, plus the
//     go_library targets has_errors, no_errors and dep for it to check. The
//     nogo rule's config attribute is commented out; Test enables it per case
//     by replacing origConfig.
//   - foofuncname.go, importfmt.go, visibility.go: the analyzers. foofuncname
//     and importfmt deliberately share the Go package name importfmt, to check
//     that same-named analyzer packages do not conflict.
//   - config.json: a nogo config that restricts importfmt to has_errors.go and
//     excludes has_*.go files from the visibility analyzer.
//   - has_errors.go, no_errors.go, dep.go: sources that do / do not trigger
//     the analyzers. dep.D carries a "visibility:noerrors" annotation, so only
//     package noerrors is allowed to call it.
//
// The Main literal is fixture content, so explanatory comments are kept
// outside it.
func TestMain(m *testing.M) {
	bazel_testing.TestMain(m, bazel_testing.Args{
		// Label of the nogo rule named "nogo" in Main's BUILD.bazel.
		// NOTE(review): presumably registered as the workspace's nogo binary.
		Nogo: "@//:nogo",
		Main: `
-- BUILD.bazel --
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_tool_library", "nogo")

nogo(
    name = "nogo",
    deps = [
        ":foofuncname",
        ":importfmt",
        ":visibility",
    ],
    # config = "",
    visibility = ["//visibility:public"],
)

go_tool_library(
    name = "importfmt",
    srcs = ["importfmt.go"],
    importpath = "importfmtanalyzer",
    deps = ["@org_golang_x_tools//go/analysis:go_tool_library"],
    visibility = ["//visibility:public"],
)

go_tool_library(
    name = "foofuncname",
    srcs = ["foofuncname.go"],
    importpath = "foofuncanalyzer",
    deps = ["@org_golang_x_tools//go/analysis:go_tool_library"],
    visibility = ["//visibility:public"],
)

go_tool_library(
    name = "visibility",
    srcs = ["visibility.go"],
    importpath = "visibilityanalyzer",
    deps = [
        "@org_golang_x_tools//go/analysis:go_tool_library",
        "@org_golang_x_tools//go/ast/inspector:go_tool_library",
    ],
    visibility = ["//visibility:public"],
)

go_library(
    name = "has_errors",
    srcs = ["has_errors.go"],
    importpath = "haserrors",
    deps = [":dep"],
)

go_library(
    name = "no_errors",
    srcs = ["no_errors.go"],
    importpath = "noerrors",
    deps = [":dep"],
)

go_library(
    name = "dep",
    srcs = ["dep.go"],
    importpath = "dep",
)

-- foofuncname.go --
// importfmt checks for functions named "Foo".
// It has the same package name as another check to test the checks with
// the same package name do not conflict.
package importfmt

import (
	"go/ast"

	"golang.org/x/tools/go/analysis"
)

const doc = "report calls of functions named \"Foo\"\n\nThe foofuncname analyzer reports calls to functions that are\nnamed \"Foo\"."

var Analyzer = &analysis.Analyzer{
	Name: "foofuncname",
	Run:  run,
	Doc:  doc,
}

func run(pass *analysis.Pass) (interface{}, error) {
	for _, f := range pass.Files {
		// TODO(samueltan): use package inspector once the latest golang.org/x/tools
		// changes are pulled into this branch  (see #1755).
		ast.Inspect(f, func(n ast.Node) bool {
			switch n := n.(type) {
			case *ast.FuncDecl:
				if n.Name.Name == "Foo" {
					pass.Reportf(n.Pos(), "function must not be named Foo")
				}
				return true
			}
			return true
		})
	}
	return nil, nil
}

-- importfmt.go --
// importfmt checks for the import of package fmt.
package importfmt

import (
	"go/ast"
	"strconv"

	"golang.org/x/tools/go/analysis"
)

const doc = "report imports of package fmt\n\nThe importfmt analyzer reports imports of package fmt."

var Analyzer = &analysis.Analyzer{
	Name: "importfmt",
	Run:  run,
	Doc:  doc,
}

func run(pass *analysis.Pass) (interface{}, error) {
	for _, f := range pass.Files {
		// TODO(samueltan): use package inspector once the latest golang.org/x/tools
		// changes are pulled into this branch (see #1755).
		ast.Inspect(f, func(n ast.Node) bool {
			switch n := n.(type) {
			case *ast.ImportSpec:
				if path, _ := strconv.Unquote(n.Path.Value); path == "fmt" {
					pass.Reportf(n.Pos(), "package fmt must not be imported")
				}
				return true
			}
			return true
		})
	}
	return nil, nil
}

-- visibility.go --
// visibility looks for visibility annotations on functions and
// checks they are only called from packages allowed to call them.
package visibility

import (
	"encoding/gob"
	"go/ast"
	"regexp"

	"golang.org/x/tools/go/analysis"
	"golang.org/x/tools/go/ast/inspector"
)

var Analyzer = &analysis.Analyzer{
	Name: "visibility",
	Run:  run,
	Doc: "enforce visibility requirements for functions\n\nThe visibility analyzer reads visibility annotations on functions and\nchecks that packages that call those functions are allowed to do so.",
	FactTypes: []analysis.Fact{(*VisibilityFact)(nil)},
}

type VisibilityFact struct {
	Paths []string
}

func (_ *VisibilityFact) AFact() {} // dummy method to satisfy interface

func init() { gob.Register((*VisibilityFact)(nil)) }

var visibilityRegexp = regexp.MustCompile("visibility:([^\\s]+)")

func run(pass *analysis.Pass) (interface{}, error) {
	in := inspector.New(pass.Files)

	// Find visibility annotations on function declarations.
	in.Nodes([]ast.Node{(*ast.FuncDecl)(nil)}, func(n ast.Node, push bool) (prune bool) {
		if !push {
			return false
		}

		fn := n.(*ast.FuncDecl)

		if fn.Doc == nil {
			return true
		}
		obj := pass.TypesInfo.ObjectOf(fn.Name)
		if obj == nil {
			return true
		}
		doc := fn.Doc.Text()

		if matches := visibilityRegexp.FindAllStringSubmatch(doc, -1); matches != nil {
			fact := &VisibilityFact{Paths: make([]string, len(matches))}
			for i, m := range matches {
				fact.Paths[i] = m[1]
			}
			pass.ExportObjectFact(obj, fact)
		}

		return true
	})

	// Find calls that may be affected by visibility declarations.
	in.Nodes([]ast.Node{(*ast.CallExpr)(nil)}, func(n ast.Node, push bool) (prune bool) {
		if !push {
			return false
		}

		callee, ok := n.(*ast.CallExpr).Fun.(*ast.SelectorExpr)
		if !ok {
			return false
		}
		obj := pass.TypesInfo.ObjectOf(callee.Sel)
		if obj == nil {
			return false
		}
		var fact VisibilityFact
		if ok := pass.ImportObjectFact(obj, &fact); !ok {
			return false
		}
		visible := false
		for _, path := range fact.Paths {
			if path == pass.Pkg.Path() {
				visible = true
				break
			}
		}
		if !visible {
			pass.Reportf(callee.Pos(), "function %s is not visible in this package", callee.Sel.Name)
		}

		return false
	})

	return nil, nil
}

-- config.json --
{
  "importfmt": {
    "only_files": {
      "has_errors\\.go": ""
    }
  },
  "foofuncname": {
    "description": "no exemptions since we know this check is 100% accurate"
  },
  "visibility": {
    "exclude_files": {
      "has_.*\\.go": "special exception to visibility rules"
    }
  }
}

-- has_errors.go --
package haserrors

import (
	_ "fmt" // This should fail importfmt

	"dep"
)

func Foo() bool { // This should fail foofuncname
	dep.D()     // This should fail visibility
	return true // This should fail boolreturn
}

-- no_errors.go --
// package noerrors contains no analyzer errors.
package noerrors

import "dep"

func Baz() int {
	dep.D()
	return 1
}

-- dep.go --
package dep

// visibility:noerrors
func D() {
}

`,
	})
}
   318  
   319  func Test(t *testing.T) {
   320  	for _, test := range []struct {
   321  		desc, config, target string
   322  		wantSuccess          bool
   323  		includes, excludes   []string
   324  	}{
   325  		{
   326  			desc:        "default_config",
   327  			target:      "//:has_errors",
   328  			wantSuccess: false,
   329  			includes: []string{
   330  				"has_errors.go:.*package fmt must not be imported",
   331  				"has_errors.go:.*function must not be named Foo",
   332  				"has_errors.go:.*function D is not visible in this package",
   333  			},
   334  		}, {
   335  			desc:        "custom_config",
   336  			target:      "//:has_errors",
   337  			wantSuccess: false,
   338  			includes: []string{
   339  				"has_errors.go:.*package fmt must not be imported",
   340  				"has_errors.go:.*function must not be named Foo",
   341  			},
   342  			excludes: []string{
   343  				"custom/has_errors.go:.*function D is not visible in this package",
   344  			},
   345  		}, {
   346  			desc:        "no_errors",
   347  			target:      "//:no_errors",
   348  			wantSuccess: true,
   349  			excludes:    []string{"no_errors.go"},
   350  		},
   351  	} {
   352  		t.Run(test.desc, func(t *testing.T) {
   353  			if test.config != "" {
   354  				customConfig := fmt.Sprintf("config = %q,", test.config)
   355  				if err := replaceInFile("BUILD.bazel", origConfig, customConfig); err != nil {
   356  					t.Fatal(err)
   357  				}
   358  				defer replaceInFile("BUILD.bazel", customConfig, origConfig)
   359  			}
   360  
   361  			cmd := bazel_testing.BazelCmd("build", test.target)
   362  			stderr := &bytes.Buffer{}
   363  			cmd.Stderr = stderr
   364  			if err := cmd.Run(); err == nil && !test.wantSuccess {
   365  				t.Fatal("unexpected success")
   366  			} else if err != nil && test.wantSuccess {
   367  				t.Fatalf("unexpected error: %v", err)
   368  			}
   369  
   370  			for _, pattern := range test.includes {
   371  				if matched, err := regexp.Match(pattern, stderr.Bytes()); err != nil {
   372  					t.Fatal(err)
   373  				} else if !matched {
   374  					t.Errorf("output did not contain pattern: %s", pattern)
   375  				}
   376  			}
   377  			for _, pattern := range test.excludes {
   378  				if matched, err := regexp.Match(pattern, stderr.Bytes()); err != nil {
   379  					t.Fatal(err)
   380  				} else if matched {
   381  					t.Errorf("output contained pattern: %s", pattern)
   382  				}
   383  			}
   384  		})
   385  	}
   386  }
   387  
   388  func replaceInFile(path, old, new string) error {
   389  	data, err := ioutil.ReadFile(path)
   390  	if err != nil {
   391  		return err
   392  	}
   393  	data = bytes.ReplaceAll(data, []byte(old), []byte(new))
   394  	return ioutil.WriteFile(path, data, 0666)
   395  }