github.com/stackb/rules_proto@v0.0.0-20240221195024-5428336c51f1/pkg/rule/rules_scala/scala_library_test.go (about)

     1  package rules_scala
     2  
     3  import (
     4  	"strings"
     5  	"testing"
     6  
     7  	"github.com/bazelbuild/bazel-gazelle/config"
     8  	"github.com/bazelbuild/bazel-gazelle/label"
     9  	"github.com/bazelbuild/bazel-gazelle/resolve"
    10  	"github.com/bazelbuild/bazel-gazelle/rule"
    11  	"github.com/google/go-cmp/cmp"
    12  	"github.com/stackb/rules_proto/pkg/protoc"
    13  )
    14  
    15  // TestGetJavaPackageOption shows that an import named in (scalapb.options) works as expected.
    16  func TestGetJavaPackageOption(t *testing.T) {
    17  	for name, tc := range map[string]struct {
    18  		in   string
    19  		want string
    20  	}{
    21  		"degenerate case": {},
    22  		"with go_package": {
    23  			in: `syntax="proto3"; option go_package="com.foo";`,
    24  		},
    25  		"with java_package": {
    26  			in:   `syntax="proto3"; option java_package="com.foo";`,
    27  			want: "com.foo",
    28  		},
    29  	} {
    30  		t.Run(name, func(t *testing.T) {
    31  			file := protoc.NewFile("", "test.proto")
    32  			if err := file.ParseReader(strings.NewReader(tc.in)); err != nil {
    33  				t.Fatal("parse file:", err)
    34  			}
    35  			got, ok := javaPackageOption(file.Options())
    36  			if ok {
    37  				if diff := cmp.Diff(tc.want, got); diff != "" {
    38  					t.Errorf("TestGetScalaImports() mismatch (-want +got):\n%s", diff)
    39  				}
    40  			} else {
    41  				if tc.want != "" {
    42  					t.Errorf("TestGetScalaImports() unexpected miss: %v", tc)
    43  				}
    44  			}
    45  		})
    46  	}
    47  }
    48  
    49  // TestParseScalaImportNamedLiteral asserts the ability to parse
    50  // a subset of scala import expressions.
    51  func TestParseScalaImportNamedLiteral(t *testing.T) {
    52  	for name, tc := range map[string]struct {
    53  		imp  string
    54  		want []string
    55  	}{
    56  		"degenerate": {
    57  			want: []string{""},
    58  		},
    59  		"single import": {
    60  			imp:  "a.b.c.Foo",
    61  			want: []string{"a.b.c.Foo"},
    62  		},
    63  		"multiple import": {
    64  			imp:  "a.b.c.{Foo,Bar}",
    65  			want: []string{"a.b.c.Foo", "a.b.c.Bar"},
    66  		},
    67  		"multiple import +ws": {
    68  			imp:  "a.b.c.{ Foo  , Bar  }  ",
    69  			want: []string{"a.b.c.Foo", "a.b.c.Bar"},
    70  		},
    71  		"alias import": {
    72  			imp:  "a.b.c.{ Foo => Fog , Bar => Baz }",
    73  			want: []string{"a.b.c.Foo", "a.b.c.Bar"},
    74  		},
    75  	} {
    76  		t.Run(name, func(t *testing.T) {
    77  			got := parseScalaImportNamedLiteral(tc.imp)
    78  			if diff := cmp.Diff(tc.want, got); diff != "" {
    79  				t.Errorf("(-want +got):\n%s", diff)
    80  			}
    81  		})
    82  	}
    83  }
    84  
    85  // TestGetScalapbImports shows that an import named in (scalapb.options) works as expected.
    86  func TestGetScalapbImports(t *testing.T) {
    87  	for name, tc := range map[string]struct {
    88  		// in is a mapping of source filename to content
    89  		in   map[string]string
    90  		want []string
    91  	}{
    92  		"degenerate case": {
    93  			want: []string{},
    94  		},
    95  		"without imports": {
    96  			in: map[string]string{
    97  				"foo.proto": `syntax = "proto3";
    98  message Thing {}`,
    99  			},
   100  			want: []string{},
   101  		},
   102  		"with scalapb import": {
   103  			in: map[string]string{
   104  				"foo.proto": `syntax = "proto3";
   105  import "scalapb/scalapb.proto";
   106  
   107  option (scalapb.options) = {
   108  	import: "corp.common.utils.WithORM"
   109  };`,
   110  			},
   111  			want: []string{"corp.common.utils.WithORM"},
   112  		},
   113  		"with scalapb import (aliased)": {
   114  			in: map[string]string{
   115  				"foo.proto": `syntax = "proto3";
   116  import "scalapb/scalapb.proto";
   117  
   118  option (scalapb.options) = {
   119  	import: "corp.common.utils.{WithORM => WithORMAlias}"
   120  };`,
   121  			},
   122  			want: []string{"corp.common.utils.WithORM"},
   123  		},
   124  
   125  		"with field type": {
   126  			in: map[string]string{
   127  				"foo.proto": `
   128  syntax = "proto2";
   129  
   130  import "thirdparty/protobuf/scalapb/scalapb.proto";
   131  
   132  message TraderId {
   133  	required int32 trader_id = 1 [(scalapb.field).type = "corp.common.utils.TraderId"];
   134  }
   135  
   136  message TeamId {
   137  	required int32 team_id = 1 [(scalapb.field).type = "corp.common.utils.TeamId"];
   138  }				
   139  `,
   140  			},
   141  			want: []string{"corp.common.utils.TeamId", "corp.common.utils.TraderId"},
   142  		},
   143  	} {
   144  		t.Run(name, func(t *testing.T) {
   145  			files := make([]*protoc.File, len(tc.in))
   146  			i := 0
   147  			for name, content := range tc.in {
   148  				file := protoc.NewFile("", name)
   149  				if err := file.ParseReader(strings.NewReader(content)); err != nil {
   150  					t.Fatal("parse file:", name, err)
   151  				}
   152  				files[i] = file
   153  				i++
   154  			}
   155  			got := getScalapbImports(files)
   156  			if diff := cmp.Diff(tc.want, got); diff != "" {
   157  				t.Errorf("TestGetScalaImports() mismatch (-want +got):\n%s", diff)
   158  			}
   159  		})
   160  	}
   161  }
   162  
   163  // TestProvideScalaImports shows the imports provided.
   164  func TestProvideScalaImports(t *testing.T) {
   165  	for name, tc := range map[string]struct {
   166  		// in is a mapping of source filename to content
   167  		in map[string]string
   168  		// options is a mapping of protoc options
   169  		options map[string]bool
   170  		want    []resolve.ImportSpec
   171  	}{
   172  		"degenerate case": {},
   173  		"message": {
   174  			in: map[string]string{
   175  				"foo.proto": `syntax = "proto3";
   176  message Thing {}`,
   177  			},
   178  			want: []resolve.ImportSpec{
   179  				{Lang: "message", Imp: "Thing"},
   180  				{Lang: "message", Imp: "ThingProto"},
   181  			},
   182  		},
   183  		"service": {
   184  			in: map[string]string{
   185  				"foo.proto": `syntax = "proto3";
   186  service Thinger {}`,
   187  			},
   188  			want: []resolve.ImportSpec{
   189  				{Lang: "service", Imp: "Thinger"},
   190  				{Lang: "service", Imp: "ThingerGrpc"},
   191  				{Lang: "service", Imp: "ThingerProto"},
   192  				{Lang: "service", Imp: "ThingerClient"},
   193  				{Lang: "service", Imp: "ThingerHandler"},
   194  				{Lang: "service", Imp: "ThingerServer"},
   195  				{Lang: "service", Imp: "ThingerPowerApi"},
   196  				{Lang: "service", Imp: "ThingerPowerApiHandler"},
   197  				{Lang: "service", Imp: "ThingerClientPowerApi"},
   198  			},
   199  		},
   200  		"enum": {
   201  			in: map[string]string{
   202  				"foo.proto": `syntax = "proto3";
   203  enum Things {}`,
   204  			},
   205  			want: []resolve.ImportSpec{
   206  				{Lang: "enum", Imp: "Things"},
   207  			},
   208  		},
   209  	} {
   210  		t.Run(name, func(t *testing.T) {
   211  			files := make([]*protoc.File, len(tc.in))
   212  			i := 0
   213  			for name, content := range tc.in {
   214  				file := protoc.NewFile("", name)
   215  				if err := file.ParseReader(strings.NewReader(content)); err != nil {
   216  					t.Fatal("parse file:", name, err)
   217  				}
   218  				files[i] = file
   219  				i++
   220  			}
   221  			resolver := &fakeImportResolver{}
   222  			from := label.New("repo", "dir", "name")
   223  
   224  			provideScalaImports(files, resolver, from, tc.options)
   225  			if diff := cmp.Diff(tc.want, resolver.got); diff != "" {
   226  				t.Errorf("TestGetScalaImports() mismatch (-want +got):\n%s", diff)
   227  			}
   228  		})
   229  	}
   230  }
   231  
   232  type fakeImportResolver struct {
   233  	got []resolve.ImportSpec
   234  }
   235  
   236  func (r *fakeImportResolver) Imports(lang, impLang string, visitor func(imp string, location []label.Label) bool) {
   237  	panic("not implemented")
   238  }
   239  
   240  func (r *fakeImportResolver) Resolve(lang, impLang, imp string) []resolve.FindResult {
   241  	panic("not implemented")
   242  }
   243  
   244  func (r *fakeImportResolver) Provide(lang, impLang, imp string, from label.Label) {
   245  	r.got = append(r.got, resolve.ImportSpec{Imp: imp, Lang: impLang})
   246  }
   247  
   248  func TestScalaLibraryOptionsNoResolve(t *testing.T) {
   249  	for name, tc := range map[string]struct {
   250  		args    []string
   251  		imports []string
   252  		want    []string
   253  	}{
   254  		"degenerate case": {},
   255  		"prototypical": {
   256  			args:    []string{"--noresolve=scalapb/scalapb.proto"},
   257  			imports: []string{"scalapb/scalapb.proto", "google/protobuf/any.proto"},
   258  			want:    []string{"google/protobuf/any.proto"},
   259  		},
   260  		"csv": {
   261  			args:    []string{"--noresolve=a.proto,b.proto"},
   262  			imports: []string{"a.proto", "b.proto"},
   263  			want:    nil,
   264  		},
   265  	} {
   266  		t.Run(name, func(t *testing.T) {
   267  			options := parseScalaLibraryOptions("proto_scala_library", tc.args)
   268  			got := options.filterImports(tc.imports)
   269  
   270  			if diff := cmp.Diff(tc.want, got); diff != "" {
   271  				t.Errorf("(-want +got):\n%s", diff)
   272  			}
   273  		})
   274  	}
   275  }
   276  
   277  func TestScalaLibraryOptionsNoOutput(t *testing.T) {
   278  	for name, tc := range map[string]struct {
   279  		args    []string
   280  		outputs []string
   281  		want    []string
   282  	}{
   283  		"degenerate case": {},
   284  		"prototypical": {
   285  			args:    []string{"--exclude=package_scala.srcjar"},
   286  			outputs: []string{"package_scala.srcjar"},
   287  			want:    nil,
   288  		},
   289  		"csv": {
   290  			args:    []string{"--exclude=a.srcjar,b.srcjar"},
   291  			outputs: []string{"a.srcjar", "b.srcjar"},
   292  			want:    nil,
   293  		},
   294  		"pattern": {
   295  			args:    []string{"--exclude=**/*.srcjar"},
   296  			outputs: []string{"a.srcjar", "lib/b.srcjar", "lib/c.jar"},
   297  			want:    []string{"lib/c.jar"},
   298  		},
   299  	} {
   300  		t.Run(name, func(t *testing.T) {
   301  			options := parseScalaLibraryOptions("proto_scala_library", tc.args)
   302  			got := options.filterOutputs(tc.outputs)
   303  
   304  			if diff := cmp.Diff(tc.want, got); diff != "" {
   305  				t.Errorf("(-want +got):\n%s", diff)
   306  			}
   307  		})
   308  	}
   309  }
   310  
   311  func TestResolveScalaDeps(t *testing.T) {
   312  	for name, tc := range map[string]struct {
   313  		overrideFn         findRuleWithOverride
   314  		byImportFn         findRulesByImportWithConfig
   315  		r                  *rule.Rule
   316  		from               label.Label
   317  		unresolvedDeps     map[string]error
   318  		wantUnresolvedDeps map[string]error
   319  		wantDeps           []string
   320  	}{
   321  		"degenerate case": {
   322  			overrideFn: func(c *config.Config, imp resolve.ImportSpec, lang string) (label.Label, bool) {
   323  				return label.NoLabel, false
   324  			},
   325  			byImportFn: func(c *config.Config, imp resolve.ImportSpec, lang string) []resolve.FindResult {
   326  				return nil
   327  			},
   328  			wantUnresolvedDeps: map[string]error{},
   329  		},
   330  		"resolve from cross-resolver": {
   331  			from: label.New("", "proto", "foo_proto_scala_library"),
   332  			overrideFn: func(c *config.Config, imp resolve.ImportSpec, lang string) (label.Label, bool) {
   333  				return label.NoLabel, false
   334  			},
   335  			byImportFn: func(c *config.Config, imp resolve.ImportSpec, lang string) []resolve.FindResult {
   336  				if lang == "scala" && imp.Imp == "foo.bar.baz.mapper" {
   337  					return []resolve.FindResult{{Label: label.New("", "mapper", "scala_lib")}}
   338  				}
   339  				return nil
   340  			},
   341  			unresolvedDeps: map[string]error{
   342  				"foo.bar.baz.mapper": protoc.ErrNoLabel,
   343  			},
   344  			wantUnresolvedDeps: map[string]error{},
   345  			wantDeps:           []string{"//mapper:scala_lib"},
   346  		},
   347  		"resolve from overrideFn": {
   348  			from: label.New("", "proto", "foo_proto_scala_library"),
   349  			overrideFn: func(c *config.Config, imp resolve.ImportSpec, lang string) (label.Label, bool) {
   350  				if imp.Lang == "scala" && imp.Imp == "foo.bar.baz.mapper" {
   351  					return label.New("", "mapper", "scala_lib"), true
   352  				}
   353  				return label.NoLabel, false
   354  			},
   355  			byImportFn: func(c *config.Config, imp resolve.ImportSpec, lang string) []resolve.FindResult {
   356  				return nil
   357  			},
   358  			unresolvedDeps: map[string]error{
   359  				"foo.bar.baz.mapper": protoc.ErrNoLabel,
   360  			},
   361  			wantUnresolvedDeps: map[string]error{},
   362  			wantDeps:           []string{"//mapper:scala_lib"},
   363  		},
   364  		"does not resolve self-label": {
   365  			from: label.New("", "proto", "foo_proto_scala_library"),
   366  			overrideFn: func(c *config.Config, imp resolve.ImportSpec, lang string) (label.Label, bool) {
   367  				if imp.Lang == "scala" && imp.Imp == "foo.bar.baz.mapper" {
   368  					return label.New("", "proto", "foo_proto_scala_library"), true
   369  				}
   370  				return label.NoLabel, false
   371  			},
   372  			byImportFn: func(c *config.Config, imp resolve.ImportSpec, lang string) []resolve.FindResult {
   373  				return nil
   374  			},
   375  			unresolvedDeps: map[string]error{
   376  				"foo.bar.baz.mapper": protoc.ErrNoLabel,
   377  			},
   378  			wantUnresolvedDeps: map[string]error{},
   379  			wantDeps:           nil,
   380  		},
   381  	} {
   382  		t.Run(name, func(t *testing.T) {
   383  			c := &config.Config{}
   384  			r := rule.NewRule("proto_scala_library", "bar_proto_scala_library")
   385  
   386  			gotUnresolvedDeps := make(map[string]error)
   387  			for k, v := range tc.unresolvedDeps {
   388  				gotUnresolvedDeps[k] = v
   389  			}
   390  			resolveScalaDeps(tc.overrideFn, tc.byImportFn, c, r, gotUnresolvedDeps, tc.from)
   391  
   392  			gotDeps := r.AttrStrings("deps")
   393  
   394  			if diff := cmp.Diff(tc.wantDeps, gotDeps); diff != "" {
   395  				t.Errorf("deps (-want +got):\n%s", diff)
   396  			}
   397  			if diff := cmp.Diff(tc.wantUnresolvedDeps, gotUnresolvedDeps); diff != "" {
   398  				t.Errorf("unresolved deps (-want +got):\n%s", diff)
   399  			}
   400  		})
   401  	}
   402  }
   403  
   404  type fakeCrossResolver struct {
   405  	result []resolve.FindResult
   406  }
   407  
   408  func (cr *fakeCrossResolver) CrossResolve(c *config.Config, ix *resolve.RuleIndex, imp resolve.ImportSpec, lang string) []resolve.FindResult {
   409  	return cr.result
   410  }