github.com/stackb/rules_proto@v0.0.0-20240221195024-5428336c51f1/pkg/rule/rules_scala/scala_library_test.go (about) 1 package rules_scala 2 3 import ( 4 "strings" 5 "testing" 6 7 "github.com/bazelbuild/bazel-gazelle/config" 8 "github.com/bazelbuild/bazel-gazelle/label" 9 "github.com/bazelbuild/bazel-gazelle/resolve" 10 "github.com/bazelbuild/bazel-gazelle/rule" 11 "github.com/google/go-cmp/cmp" 12 "github.com/stackb/rules_proto/pkg/protoc" 13 ) 14 15 // TestGetJavaPackageOption shows that an import named in (scalapb.options) works as expected. 16 func TestGetJavaPackageOption(t *testing.T) { 17 for name, tc := range map[string]struct { 18 in string 19 want string 20 }{ 21 "degenerate case": {}, 22 "with go_package": { 23 in: `syntax="proto3"; option go_package="com.foo";`, 24 }, 25 "with java_package": { 26 in: `syntax="proto3"; option java_package="com.foo";`, 27 want: "com.foo", 28 }, 29 } { 30 t.Run(name, func(t *testing.T) { 31 file := protoc.NewFile("", "test.proto") 32 if err := file.ParseReader(strings.NewReader(tc.in)); err != nil { 33 t.Fatal("parse file:", err) 34 } 35 got, ok := javaPackageOption(file.Options()) 36 if ok { 37 if diff := cmp.Diff(tc.want, got); diff != "" { 38 t.Errorf("TestGetScalaImports() mismatch (-want +got):\n%s", diff) 39 } 40 } else { 41 if tc.want != "" { 42 t.Errorf("TestGetScalaImports() unexpected miss: %v", tc) 43 } 44 } 45 }) 46 } 47 } 48 49 // TestParseScalaImportNamedLiteral asserts the ability to parse 50 // a subset of scala import expressions. 51 func TestParseScalaImportNamedLiteral(t *testing.T) { 52 for name, tc := range map[string]struct { 53 imp string 54 want []string 55 }{ 56 "degenerate": { 57 want: []string{""}, 58 }, 59 "single import": { 60 imp: "a.b.c.Foo", 61 want: []string{"a.b.c.Foo"}, 62 }, 63 "multiple import": { 64 imp: "a.b.c.{Foo,Bar}", 65 want: []string{"a.b.c.Foo", "a.b.c.Bar"}, 66 }, 67 "multiple import +ws": { 68 imp: "a.b.c.{ Foo , Bar } ", 69 want: []string{"a.b.c.Foo", "a.b.c.Bar"}, 70 }, 71 "alias import": { 72 imp: "a.b.c.{ Foo => Fog , Bar => Baz }", 73 want: []string{"a.b.c.Foo", "a.b.c.Bar"}, 74 }, 75 } { 76 t.Run(name, func(t *testing.T) { 77 got := parseScalaImportNamedLiteral(tc.imp) 78 if diff := cmp.Diff(tc.want, got); diff != "" { 79 t.Errorf("(-want +got):\n%s", diff) 80 } 81 }) 82 } 83 } 84 85 // TestGetScalapbImports shows that an import named in (scalapb.options) works as expected. 86 func TestGetScalapbImports(t *testing.T) { 87 for name, tc := range map[string]struct { 88 // in is a mapping of source filename to content 89 in map[string]string 90 want []string 91 }{ 92 "degenerate case": { 93 want: []string{}, 94 }, 95 "without imports": { 96 in: map[string]string{ 97 "foo.proto": `syntax = "proto3"; 98 message Thing {}`, 99 }, 100 want: []string{}, 101 }, 102 "with scalapb import": { 103 in: map[string]string{ 104 "foo.proto": `syntax = "proto3"; 105 import "scalapb/scalapb.proto"; 106 107 option (scalapb.options) = { 108 import: "corp.common.utils.WithORM" 109 };`, 110 }, 111 want: []string{"corp.common.utils.WithORM"}, 112 }, 113 "with scalapb import (aliased)": { 114 in: map[string]string{ 115 "foo.proto": `syntax = "proto3"; 116 import "scalapb/scalapb.proto"; 117 118 option (scalapb.options) = { 119 import: "corp.common.utils.{WithORM => WithORMAlias}" 120 };`, 121 }, 122 want: []string{"corp.common.utils.WithORM"}, 123 }, 124 125 "with field type": { 126 in: map[string]string{ 127 "foo.proto": ` 128 syntax = "proto2"; 129 130 import "thirdparty/protobuf/scalapb/scalapb.proto"; 131 132 message TraderId { 133 required int32 trader_id = 1 [(scalapb.field).type = "corp.common.utils.TraderId"]; 134 } 135 136 message TeamId { 137 required int32 team_id = 1 [(scalapb.field).type = "corp.common.utils.TeamId"]; 138 } 139 `, 140 }, 141 want: []string{"corp.common.utils.TeamId", "corp.common.utils.TraderId"}, 142 }, 143 } { 144 t.Run(name, func(t *testing.T) { 145 files := make([]*protoc.File, len(tc.in)) 146 i := 0 147 for name, content := range tc.in { 148 file := protoc.NewFile("", name) 149 if err := file.ParseReader(strings.NewReader(content)); err != nil { 150 t.Fatal("parse file:", name, err) 151 } 152 files[i] = file 153 i++ 154 } 155 got := getScalapbImports(files) 156 if diff := cmp.Diff(tc.want, got); diff != "" { 157 t.Errorf("TestGetScalaImports() mismatch (-want +got):\n%s", diff) 158 } 159 }) 160 } 161 } 162 163 // TestProvideScalaImports shows the imports provided. 164 func TestProvideScalaImports(t *testing.T) { 165 for name, tc := range map[string]struct { 166 // in is a mapping of source filename to content 167 in map[string]string 168 // options is a mapping of protoc options 169 options map[string]bool 170 want []resolve.ImportSpec 171 }{ 172 "degenerate case": {}, 173 "message": { 174 in: map[string]string{ 175 "foo.proto": `syntax = "proto3"; 176 message Thing {}`, 177 }, 178 want: []resolve.ImportSpec{ 179 {Lang: "message", Imp: "Thing"}, 180 {Lang: "message", Imp: "ThingProto"}, 181 }, 182 }, 183 "service": { 184 in: map[string]string{ 185 "foo.proto": `syntax = "proto3"; 186 service Thinger {}`, 187 }, 188 want: []resolve.ImportSpec{ 189 {Lang: "service", Imp: "Thinger"}, 190 {Lang: "service", Imp: "ThingerGrpc"}, 191 {Lang: "service", Imp: "ThingerProto"}, 192 {Lang: "service", Imp: "ThingerClient"}, 193 {Lang: "service", Imp: "ThingerHandler"}, 194 {Lang: "service", Imp: "ThingerServer"}, 195 {Lang: "service", Imp: "ThingerPowerApi"}, 196 {Lang: "service", Imp: "ThingerPowerApiHandler"}, 197 {Lang: "service", Imp: "ThingerClientPowerApi"}, 198 }, 199 }, 200 "enum": { 201 in: map[string]string{ 202 "foo.proto": `syntax = "proto3"; 203 enum Things {}`, 204 }, 205 want: []resolve.ImportSpec{ 206 {Lang: "enum", Imp: "Things"}, 207 }, 208 }, 209 } { 210 t.Run(name, func(t *testing.T) { 211 files := make([]*protoc.File, len(tc.in)) 212 i := 0 213 for name, content := range tc.in { 214 file := protoc.NewFile("", name) 215 if err := file.ParseReader(strings.NewReader(content)); err != nil { 216 t.Fatal("parse file:", name, err) 217 } 218 files[i] = file 219 i++ 220 } 221 resolver := &fakeImportResolver{} 222 from := label.New("repo", "dir", "name") 223 224 provideScalaImports(files, resolver, from, tc.options) 225 if diff := cmp.Diff(tc.want, resolver.got); diff != "" { 226 t.Errorf("TestGetScalaImports() mismatch (-want +got):\n%s", diff) 227 } 228 }) 229 } 230 } 231 232 type fakeImportResolver struct { 233 got []resolve.ImportSpec 234 } 235 236 func (r *fakeImportResolver) Imports(lang, impLang string, visitor func(imp string, location []label.Label) bool) { 237 panic("not implemented") 238 } 239 240 func (r *fakeImportResolver) Resolve(lang, impLang, imp string) []resolve.FindResult { 241 panic("not implemented") 242 } 243 244 func (r *fakeImportResolver) Provide(lang, impLang, imp string, from label.Label) { 245 r.got = append(r.got, resolve.ImportSpec{Imp: imp, Lang: impLang}) 246 } 247 248 func TestScalaLibraryOptionsNoResolve(t *testing.T) { 249 for name, tc := range map[string]struct { 250 args []string 251 imports []string 252 want []string 253 }{ 254 "degenerate case": {}, 255 "prototypical": { 256 args: []string{"--noresolve=scalapb/scalapb.proto"}, 257 imports: []string{"scalapb/scalapb.proto", "google/protobuf/any.proto"}, 258 want: []string{"google/protobuf/any.proto"}, 259 }, 260 "csv": { 261 args: []string{"--noresolve=a.proto,b.proto"}, 262 imports: []string{"a.proto", "b.proto"}, 263 want: nil, 264 }, 265 } { 266 t.Run(name, func(t *testing.T) { 267 options := parseScalaLibraryOptions("proto_scala_library", tc.args) 268 got := options.filterImports(tc.imports) 269 270 if diff := cmp.Diff(tc.want, got); diff != "" { 271 t.Errorf("(-want +got):\n%s", diff) 272 } 273 }) 274 } 275 } 276 277 func TestScalaLibraryOptionsNoOutput(t *testing.T) { 278 for name, tc := range map[string]struct { 279 args []string 280 outputs []string 281 want []string 282 }{ 283 "degenerate case": {}, 284 "prototypical": { 285 args: []string{"--exclude=package_scala.srcjar"}, 286 outputs: []string{"package_scala.srcjar"}, 287 want: nil, 288 }, 289 "csv": { 290 args: []string{"--exclude=a.srcjar,b.srcjar"}, 291 outputs: []string{"a.srcjar", "b.srcjar"}, 292 want: nil, 293 }, 294 "pattern": { 295 args: []string{"--exclude=**/*.srcjar"}, 296 outputs: []string{"a.srcjar", "lib/b.srcjar", "lib/c.jar"}, 297 want: []string{"lib/c.jar"}, 298 }, 299 } { 300 t.Run(name, func(t *testing.T) { 301 options := parseScalaLibraryOptions("proto_scala_library", tc.args) 302 got := options.filterOutputs(tc.outputs) 303 304 if diff := cmp.Diff(tc.want, got); diff != "" { 305 t.Errorf("(-want +got):\n%s", diff) 306 } 307 }) 308 } 309 } 310 311 func TestResolveScalaDeps(t *testing.T) { 312 for name, tc := range map[string]struct { 313 overrideFn findRuleWithOverride 314 byImportFn findRulesByImportWithConfig 315 r *rule.Rule 316 from label.Label 317 unresolvedDeps map[string]error 318 wantUnresolvedDeps map[string]error 319 wantDeps []string 320 }{ 321 "degenerate case": { 322 overrideFn: func(c *config.Config, imp resolve.ImportSpec, lang string) (label.Label, bool) { 323 return label.NoLabel, false 324 }, 325 byImportFn: func(c *config.Config, imp resolve.ImportSpec, lang string) []resolve.FindResult { 326 return nil 327 }, 328 wantUnresolvedDeps: map[string]error{}, 329 }, 330 "resolve from cross-resolver": { 331 from: label.New("", "proto", "foo_proto_scala_library"), 332 overrideFn: func(c *config.Config, imp resolve.ImportSpec, lang string) (label.Label, bool) { 333 return label.NoLabel, false 334 }, 335 byImportFn: func(c *config.Config, imp resolve.ImportSpec, lang string) []resolve.FindResult { 336 if lang == "scala" && imp.Imp == "foo.bar.baz.mapper" { 337 return []resolve.FindResult{{Label: label.New("", "mapper", "scala_lib")}} 338 } 339 return nil 340 }, 341 unresolvedDeps: map[string]error{ 342 "foo.bar.baz.mapper": protoc.ErrNoLabel, 343 }, 344 wantUnresolvedDeps: map[string]error{}, 345 wantDeps: []string{"//mapper:scala_lib"}, 346 }, 347 "resolve from overrideFn": { 348 from: label.New("", "proto", "foo_proto_scala_library"), 349 overrideFn: func(c *config.Config, imp resolve.ImportSpec, lang string) (label.Label, bool) { 350 if imp.Lang == "scala" && imp.Imp == "foo.bar.baz.mapper" { 351 return label.New("", "mapper", "scala_lib"), true 352 } 353 return label.NoLabel, false 354 }, 355 byImportFn: func(c *config.Config, imp resolve.ImportSpec, lang string) []resolve.FindResult { 356 return nil 357 }, 358 unresolvedDeps: map[string]error{ 359 "foo.bar.baz.mapper": protoc.ErrNoLabel, 360 }, 361 wantUnresolvedDeps: map[string]error{}, 362 wantDeps: []string{"//mapper:scala_lib"}, 363 }, 364 "does not resolve self-label": { 365 from: label.New("", "proto", "foo_proto_scala_library"), 366 overrideFn: func(c *config.Config, imp resolve.ImportSpec, lang string) (label.Label, bool) { 367 if imp.Lang == "scala" && imp.Imp == "foo.bar.baz.mapper" { 368 return label.New("", "proto", "foo_proto_scala_library"), true 369 } 370 return label.NoLabel, false 371 }, 372 byImportFn: func(c *config.Config, imp resolve.ImportSpec, lang string) []resolve.FindResult { 373 return nil 374 }, 375 unresolvedDeps: map[string]error{ 376 "foo.bar.baz.mapper": protoc.ErrNoLabel, 377 }, 378 wantUnresolvedDeps: map[string]error{}, 379 wantDeps: nil, 380 }, 381 } { 382 t.Run(name, func(t *testing.T) { 383 c := &config.Config{} 384 r := rule.NewRule("proto_scala_library", "bar_proto_scala_library") 385 386 gotUnresolvedDeps := make(map[string]error) 387 for k, v := range tc.unresolvedDeps { 388 gotUnresolvedDeps[k] = v 389 } 390 resolveScalaDeps(tc.overrideFn, tc.byImportFn, c, r, gotUnresolvedDeps, tc.from) 391 392 gotDeps := r.AttrStrings("deps") 393 394 if diff := cmp.Diff(tc.wantDeps, gotDeps); diff != "" { 395 t.Errorf("deps (-want +got):\n%s", diff) 396 } 397 if diff := cmp.Diff(tc.wantUnresolvedDeps, gotUnresolvedDeps); diff != "" { 398 t.Errorf("unresolved deps (-want +got):\n%s", diff) 399 } 400 }) 401 } 402 } 403 404 type fakeCrossResolver struct { 405 result []resolve.FindResult 406 } 407 408 func (cr *fakeCrossResolver) CrossResolve(c *config.Config, ix *resolve.RuleIndex, imp resolve.ImportSpec, lang string) []resolve.FindResult { 409 return cr.result 410 }