github.com/bakjos/protoreflect@v1.9.2/desc/imports_test.go (about)

     1  package desc_test
     2  
     3  import (
     4  	"testing"
     5  
     6  	"github.com/golang/protobuf/protoc-gen-go/descriptor"
     7  
     8  	"github.com/bakjos/protoreflect/desc"
     9  	"github.com/bakjos/protoreflect/desc/protoparse"
    10  	_ "github.com/bakjos/protoreflect/internal/testprotos"
    11  	"github.com/bakjos/protoreflect/internal/testutil"
    12  )
    13  
    14  func TestResolveImport(t *testing.T) {
    15  	desc.RegisterImportPath("desc_test1.proto", "foobar/desc_test1.proto")
    16  	testutil.Eq(t, "desc_test1.proto", desc.ResolveImport("foobar/desc_test1.proto"))
    17  	testutil.Eq(t, "foobar/snafu.proto", desc.ResolveImport("foobar/snafu.proto"))
    18  
    19  	expectPanic(t, func() {
    20  		desc.RegisterImportPath("", "foobar/desc_test1.proto")
    21  	})
    22  	expectPanic(t, func() {
    23  		desc.RegisterImportPath("desc_test1.proto", "")
    24  	})
    25  	expectPanic(t, func() {
    26  		// not a real registered path
    27  		desc.RegisterImportPath("github.com/jhump/x/y/z/foobar.proto", "x/y/z/foobar.proto")
    28  	})
    29  }
    30  
    31  func TestImportResolver(t *testing.T) {
    32  	var r desc.ImportResolver
    33  
    34  	expectPanic(t, func() {
    35  		r.RegisterImportPath("", "a/b/c/d.proto")
    36  	})
    37  	expectPanic(t, func() {
    38  		r.RegisterImportPath("d.proto", "")
    39  	})
    40  
    41  	// no source constraints
    42  	r.RegisterImportPath("foo/bar.proto", "bar.proto")
    43  	testutil.Eq(t, "foo/bar.proto", r.ResolveImport("test.proto", "bar.proto"))
    44  	testutil.Eq(t, "foo/bar.proto", r.ResolveImport("some/other/source.proto", "bar.proto"))
    45  
    46  	// with specific source file
    47  	r.RegisterImportPathFrom("fubar/baz.proto", "baz.proto", "test/test.proto")
    48  	// match
    49  	testutil.Eq(t, "fubar/baz.proto", r.ResolveImport("test/test.proto", "baz.proto"))
    50  	// no match
    51  	testutil.Eq(t, "baz.proto", r.ResolveImport("test.proto", "baz.proto"))
    52  	testutil.Eq(t, "baz.proto", r.ResolveImport("test/test2.proto", "baz.proto"))
    53  	testutil.Eq(t, "baz.proto", r.ResolveImport("some/other/source.proto", "baz.proto"))
    54  
    55  	// with specific source file with long path
    56  	r.RegisterImportPathFrom("fubar/frobnitz/baz.proto", "baz.proto", "a/b/c/d/e/f/g/test/test.proto")
    57  	// match
    58  	testutil.Eq(t, "fubar/frobnitz/baz.proto", r.ResolveImport("a/b/c/d/e/f/g/test/test.proto", "baz.proto"))
    59  	// no match
    60  	testutil.Eq(t, "baz.proto", r.ResolveImport("test.proto", "baz.proto"))
    61  	testutil.Eq(t, "baz.proto", r.ResolveImport("test/test2.proto", "baz.proto"))
    62  	testutil.Eq(t, "baz.proto", r.ResolveImport("some/other/source.proto", "baz.proto"))
    63  
    64  	// with source path
    65  	r.RegisterImportPathFrom("fubar/frobnitz/snafu.proto", "frobnitz/snafu.proto", "a/b/c/d/e/f/g/h")
    66  	// match
    67  	testutil.Eq(t, "fubar/frobnitz/snafu.proto", r.ResolveImport("a/b/c/d/e/f/g/h/test/test.proto", "frobnitz/snafu.proto"))
    68  	testutil.Eq(t, "fubar/frobnitz/snafu.proto", r.ResolveImport("a/b/c/d/e/f/g/h/abc.proto", "frobnitz/snafu.proto"))
    69  	// no match
    70  	testutil.Eq(t, "frobnitz/snafu.proto", r.ResolveImport("a/b/c/d/e/f/g/test/test.proto", "frobnitz/snafu.proto"))
    71  	testutil.Eq(t, "frobnitz/snafu.proto", r.ResolveImport("test.proto", "frobnitz/snafu.proto"))
    72  	testutil.Eq(t, "frobnitz/snafu.proto", r.ResolveImport("test/test2.proto", "frobnitz/snafu.proto"))
    73  	testutil.Eq(t, "frobnitz/snafu.proto", r.ResolveImport("some/other/source.proto", "frobnitz/snafu.proto"))
    74  
    75  	// falls back to global registered paths
    76  	desc.RegisterImportPath("desc_test1.proto", "x/y/z/desc_test1.proto")
    77  	testutil.Eq(t, "desc_test1.proto", r.ResolveImport("a/b/c/d/e/f/g/h/test/test.proto", "x/y/z/desc_test1.proto"))
    78  }
    79  
    80  func TestImportResolver_CreateFileDescriptors(t *testing.T) {
    81  	p := protoparse.Parser{
    82  		Accessor: protoparse.FileContentsFromMap(map[string]string{
    83  			"foo/bar.proto": `
    84  				syntax = "proto3";
    85  				package foo;
    86  				message Bar {
    87  					string name = 1;
    88  					uint64 id = 2;
    89  				}
    90  				`,
    91  			// imports above file as just "bar.proto", so we need an
    92  			// import resolver to properly load and link
    93  			"fu/baz.proto": `
    94  				syntax = "proto3";
    95  				package fu;
    96  				import "bar.proto";
    97  				message Baz {
    98  					repeated foo.Bar foobar = 1;
    99  				}
   100  				`,
   101  		}),
   102  		ImportPaths: []string{"foo"},
   103  	}
   104  	fds, err := p.ParseFilesButDoNotLink("foo/bar.proto", "fu/baz.proto")
   105  	testutil.Ok(t, err)
   106  
   107  	// Since we didn't link, fu.Baz.foobar field in second file has no type
   108  	// (it can't know whether it's a message or enum until linking is done).
   109  	// So go ahead and fill in the correct type:
   110  	fds[1].MessageType[0].Field[0].Type = descriptor.FieldDescriptorProto_TYPE_MESSAGE.Enum()
   111  
   112  	// sanity check: make sure linking fails without an import resolver
   113  	_, err = desc.CreateFileDescriptors(fds)
   114  	testutil.Require(t, err != nil)
   115  	testutil.Eq(t, `no such file: "bar.proto"`, err.Error())
   116  
   117  	// now try again with resolver
   118  	var r desc.ImportResolver
   119  	r.RegisterImportPath("foo/bar.proto", "bar.proto")
   120  	linkedFiles, err := r.CreateFileDescriptors(fds)
   121  	// success!
   122  	testutil.Ok(t, err)
   123  
   124  	// quick check of the resulting files
   125  	fd := linkedFiles["foo/bar.proto"]
   126  	testutil.Require(t, fd != nil)
   127  	md := fd.FindMessage("foo.Bar")
   128  	testutil.Require(t, md != nil)
   129  
   130  	fd2 := linkedFiles["fu/baz.proto"]
   131  	testutil.Require(t, fd2 != nil)
   132  	md2 := fd2.FindMessage("fu.Baz")
   133  	testutil.Require(t, md2 != nil)
   134  	fld := md2.FindFieldByNumber(1)
   135  	testutil.Require(t, fld != nil)
   136  	testutil.Eq(t, md, fld.GetMessageType())
   137  	testutil.Eq(t, fd, fd2.GetDependencies()[0])
   138  }
   139  
   140  func expectPanic(t *testing.T, fn func()) {
   141  	defer func() {
   142  		p := recover()
   143  		testutil.Require(t, p != nil, "expecting panic")
   144  	}()
   145  
   146  	fn()
   147  }