github.com/jhump/protoreflect@v1.16.0/dynamic/msgregistry/fetchers_test.go (about)

     1  package msgregistry
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"io/ioutil"
     9  	"net/http"
    10  	"strings"
    11  	"sync"
    12  	"sync/atomic"
    13  	"testing"
    14  	"time"
    15  
    16  	"github.com/golang/protobuf/proto"
    17  	"google.golang.org/protobuf/types/known/sourcecontextpb"
    18  	"google.golang.org/protobuf/types/known/typepb"
    19  
    20  	"github.com/jhump/protoreflect/internal/testutil"
    21  )
    22  
    23  func TestCachingTypeFetcher(t *testing.T) {
    24  	counts := map[string]int{}
    25  	uncached := func(url string, enum bool) (proto.Message, error) {
    26  		counts[url] = counts[url] + 1
    27  		return testFetcher(url, enum)
    28  	}
    29  
    30  	// observe the underlying type fetcher get invoked 10x
    31  	for i := 0; i < 10; i++ {
    32  		pm, err := uncached("blah.blah.blah/fee.fi.fo.Fum", false)
    33  		testutil.Ok(t, err)
    34  		typ := pm.(*typepb.Type)
    35  		testutil.Eq(t, "fee.fi.fo.Fum", typ.Name)
    36  	}
    37  	for i := 0; i < 10; i++ {
    38  		pm, err := uncached("blah.blah.blah/fee.fi.fo.Foo", true)
    39  		testutil.Ok(t, err)
    40  		en := pm.(*typepb.Enum)
    41  		testutil.Eq(t, "fee.fi.fo.Foo", en.Name)
    42  	}
    43  
    44  	testutil.Eq(t, 10, counts["blah.blah.blah/fee.fi.fo.Fum"])
    45  	testutil.Eq(t, 10, counts["blah.blah.blah/fee.fi.fo.Foo"])
    46  
    47  	// now we'll see the underlying fetcher invoked just one more time,
    48  	// after which the result is cached
    49  	cached := CachingTypeFetcher(uncached)
    50  
    51  	for i := 0; i < 10; i++ {
    52  		pm, err := cached("blah.blah.blah/fee.fi.fo.Fum", false)
    53  		testutil.Ok(t, err)
    54  		typ := pm.(*typepb.Type)
    55  		testutil.Eq(t, "fee.fi.fo.Fum", typ.Name)
    56  	}
    57  
    58  	for i := 0; i < 10; i++ {
    59  		pm, err := cached("blah.blah.blah/fee.fi.fo.Foo", true)
    60  		testutil.Ok(t, err)
    61  		en := pm.(*typepb.Enum)
    62  		testutil.Eq(t, "fee.fi.fo.Foo", en.Name)
    63  	}
    64  
    65  	testutil.Eq(t, 11, counts["blah.blah.blah/fee.fi.fo.Fum"])
    66  	testutil.Eq(t, 11, counts["blah.blah.blah/fee.fi.fo.Foo"])
    67  }
    68  
    69  func TestCachingTypeFetcher_MismatchType(t *testing.T) {
    70  	fetcher := CachingTypeFetcher(testFetcher)
    71  	// get a message type
    72  	pm, err := fetcher("blah.blah.blah/fee.fi.fo.Fum", false)
    73  	testutil.Ok(t, err)
    74  	typ := pm.(*typepb.Type)
    75  	testutil.Eq(t, "fee.fi.fo.Fum", typ.Name)
    76  	// and an enum type
    77  	pm, err = fetcher("blah.blah.blah/fee.fi.fo.Foo", true)
    78  	testutil.Ok(t, err)
    79  	en := pm.(*typepb.Enum)
    80  	testutil.Eq(t, "fee.fi.fo.Foo", en.Name)
    81  
    82  	// now ask for same URL, but swapped types
    83  	_, err = fetcher("blah.blah.blah/fee.fi.fo.Fum", true)
    84  	testutil.Require(t, err != nil && strings.Contains(err.Error(), "wanted enum, got message"))
    85  	_, err = fetcher("blah.blah.blah/fee.fi.fo.Foo", false)
    86  	testutil.Require(t, err != nil && strings.Contains(err.Error(), "wanted message, got enum"))
    87  }
    88  
    89  func TestCachingTypeFetcher_Concurrency(t *testing.T) {
    90  	// make sure we are thread safe
    91  	var mu sync.Mutex
    92  	counts := map[string]int{}
    93  	tf := CachingTypeFetcher(func(url string, enum bool) (proto.Message, error) {
    94  		mu.Lock()
    95  		counts[url] = counts[url] + 1
    96  		mu.Unlock()
    97  		return testFetcher(url, enum)
    98  	})
    99  
   100  	ctx, cancel := context.WithCancel(context.Background())
   101  	names := []string{"Fee", "Fi", "Fo", "Fum", "I", "Smell", "Blood", "Of", "Englishman"}
   102  	var queryCount int32
   103  	var wg sync.WaitGroup
   104  	for i := 0; i < 10; i++ {
   105  		wg.Add(1)
   106  		go func() {
   107  			defer wg.Done()
   108  			for i := 0; ctx.Err() == nil; i = (i + 1) % len(names) {
   109  				n := "fee.fi.fo." + names[i]
   110  				// message
   111  				pm, err := tf("blah.blah.blah/"+n, false)
   112  				testutil.Ok(t, err)
   113  				typ := pm.(*typepb.Type)
   114  				testutil.Eq(t, n, typ.Name)
   115  				atomic.AddInt32(&queryCount, 1)
   116  				// enum
   117  				pm, err = tf("blah.blah.blah.en/"+n, true)
   118  				testutil.Ok(t, err)
   119  				en := pm.(*typepb.Enum)
   120  				testutil.Eq(t, n, en.Name)
   121  				atomic.AddInt32(&queryCount, 1)
   122  			}
   123  		}()
   124  	}
   125  
   126  	time.Sleep(2 * time.Second)
   127  	cancel()
   128  	wg.Wait()
   129  
   130  	// underlying fetcher invoked just once per URL
   131  	for _, v := range counts {
   132  		testutil.Eq(t, 1, v)
   133  	}
   134  
   135  	testutil.Require(t, atomic.LoadInt32(&queryCount) > int32(len(counts)))
   136  }
   137  
   138  func TestHttpTypeFetcher(t *testing.T) {
   139  	trt := &testRoundTripper{counts: map[string]int{}}
   140  	fetcher := HttpTypeFetcher(trt, 65536, 10)
   141  
   142  	for i := 0; i < 10; i++ {
   143  		pm, err := fetcher("blah.blah.blah/fee.fi.fo.Message", false)
   144  		testutil.Ok(t, err)
   145  		typ := pm.(*typepb.Type)
   146  		testutil.Eq(t, "fee.fi.fo.Message", typ.Name)
   147  	}
   148  
   149  	for i := 0; i < 10; i++ {
   150  		// name must have Enum for test fetcher to return an enum type
   151  		pm, err := fetcher("blah.blah.blah/fee.fi.fo.Enum", true)
   152  		testutil.Ok(t, err)
   153  		en := pm.(*typepb.Enum)
   154  		testutil.Eq(t, "fee.fi.fo.Enum", en.Name)
   155  	}
   156  
   157  	// HttpTypeFetcher caches results
   158  	testutil.Eq(t, 1, trt.counts["https://blah.blah.blah/fee.fi.fo.Message"])
   159  	testutil.Eq(t, 1, trt.counts["https://blah.blah.blah/fee.fi.fo.Enum"])
   160  }
   161  
   162  func TestHttpTypeFetcher_ParallelDownloads(t *testing.T) {
   163  	trt := &testRoundTripper{counts: map[string]int{}, delay: 100 * time.Millisecond}
   164  	fetcher := HttpTypeFetcher(trt, 65536, 10)
   165  	// We spin up 100 fetches in parallel, but only 10 can go at a time and each
   166  	// one takes 100millis. So it should take about 1 second.
   167  	start := time.Now()
   168  	var wg sync.WaitGroup
   169  	for i := 0; i < 100; i++ {
   170  		wg.Add(1)
   171  		index := i // don't capture loop variable
   172  		go func() {
   173  			defer wg.Done()
   174  			name := fmt.Sprintf("fee.fi.fo.Fum%d", index)
   175  			pm, err := fetcher("blah.blah.blah/"+name, false)
   176  			testutil.Ok(t, err)
   177  			typ := pm.(*typepb.Type)
   178  			testutil.Eq(t, name, typ.Name)
   179  		}()
   180  	}
   181  	wg.Wait()
   182  	elapsed := time.Since(start)
   183  
   184  	// we should have observed exactly the maximum number of parallel downloads
   185  	testutil.Eq(t, 10, trt.max)
   186  	// should have taken about a second
   187  	testutil.Require(t, elapsed >= time.Second)
   188  }
   189  
   190  func TestHttpTypeFetcher_SizeLimits(t *testing.T) {
   191  	trt := &testRoundTripper{counts: map[string]int{}}
   192  	// small size that will always get tripped
   193  	fetcher := HttpTypeFetcher(trt, 32, 10)
   194  
   195  	// name with "Size" causes content-length to be reported in header
   196  	_, err := fetcher("blah.blah.blah/fee.fi.fo.FumSize", false)
   197  	testutil.Require(t, err != nil && strings.Contains(err.Error(), "is larger than limit of 32"))
   198  
   199  	// without size in the name, no content-length (e.g. streaming response)
   200  	_, err = fetcher("blah.blah.blah/fee.fi.fo.Fum", false)
   201  	testutil.Require(t, err != nil && strings.Contains(err.Error(), "is larger than limit of 32"))
   202  }
   203  
   204  type testRoundTripper struct {
   205  	// artificial delay that each fake HTTP request will take
   206  	delay time.Duration
   207  	mu    sync.Mutex
   208  	// counts by requested URL
   209  	counts map[string]int
   210  	// total active downloads
   211  	active int
   212  	// max observed active downloads
   213  	max int
   214  }
   215  
   216  func (t *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
   217  	url := req.URL.String()
   218  
   219  	t.mu.Lock()
   220  	t.counts[url] = t.counts[url] + 1
   221  	t.active++
   222  	if t.active > t.max {
   223  		t.max = t.active
   224  	}
   225  	t.mu.Unlock()
   226  
   227  	defer func() {
   228  		t.mu.Lock()
   229  		t.active--
   230  		t.mu.Unlock()
   231  	}()
   232  
   233  	time.Sleep(t.delay)
   234  
   235  	name := url[strings.LastIndex(req.URL.Path, "/")+1:]
   236  	includeContentLength := strings.Contains(name, "Size")
   237  	pm, err := testFetcher(url, strings.Contains(name, "Enum"))
   238  	if err != nil {
   239  		return nil, err
   240  	}
   241  	b, err := proto.Marshal(pm)
   242  	if err != nil {
   243  		return nil, err
   244  	}
   245  	contentLength := int64(-1)
   246  	if includeContentLength {
   247  		contentLength = int64(len(b))
   248  	}
   249  	return &http.Response{
   250  		StatusCode:    200,
   251  		Status:        "200 OK",
   252  		ContentLength: contentLength,
   253  		Body:          ioutil.NopCloser(bytes.NewReader(b)),
   254  	}, nil
   255  }
   256  
   257  func testFetcher(url string, enum bool) (proto.Message, error) {
   258  	name := url[strings.LastIndex(url, "/")+1:]
   259  	if strings.Contains(name, "Error") {
   260  		return nil, errors.New(name)
   261  	} else if enum {
   262  		return &typepb.Enum{
   263  			Name:          name,
   264  			SourceContext: &sourcecontextpb.SourceContext{FileName: "test.proto"},
   265  			Syntax:        typepb.Syntax_SYNTAX_PROTO3,
   266  			Enumvalue: []*typepb.EnumValue{
   267  				{Name: "A", Number: 0},
   268  				{Name: "B", Number: 1},
   269  				{Name: "C", Number: 2},
   270  			},
   271  		}, nil
   272  	} else {
   273  		return &typepb.Type{
   274  			Name:          name,
   275  			SourceContext: &sourcecontextpb.SourceContext{FileName: "test.proto"},
   276  			Syntax:        typepb.Syntax_SYNTAX_PROTO3,
   277  			Fields: []*typepb.Field{
   278  				{Name: "a", Number: 1, Cardinality: typepb.Field_CARDINALITY_OPTIONAL, Kind: typepb.Field_TYPE_INT64},
   279  				{Name: "b", Number: 2, Cardinality: typepb.Field_CARDINALITY_OPTIONAL, Kind: typepb.Field_TYPE_STRING},
   280  				{Name: "c1", Number: 3, OneofIndex: 1, Cardinality: typepb.Field_CARDINALITY_OPTIONAL, Kind: typepb.Field_TYPE_STRING},
   281  				{Name: "c2", Number: 4, OneofIndex: 1, Cardinality: typepb.Field_CARDINALITY_OPTIONAL, Kind: typepb.Field_TYPE_BOOL},
   282  				{Name: "c3", Number: 5, OneofIndex: 1, Cardinality: typepb.Field_CARDINALITY_OPTIONAL, Kind: typepb.Field_TYPE_DOUBLE},
   283  				{Name: "d", Number: 6, Cardinality: typepb.Field_CARDINALITY_REPEATED, Kind: typepb.Field_TYPE_MESSAGE, TypeUrl: "type.googleapis.com/foo.bar.Baz"},
   284  				{Name: "e", Number: 7, Cardinality: typepb.Field_CARDINALITY_OPTIONAL, Kind: typepb.Field_TYPE_ENUM, TypeUrl: "type.googleapis.com/foo.bar.Blah"},
   285  			},
   286  			Oneofs: []string{"union"},
   287  		}, nil
   288  	}
   289  }