github.com/jhump/protoreflect@v1.16.0/dynamic/msgregistry/fetchers_test.go (about) 1 package msgregistry 2 3 import ( 4 "bytes" 5 "context" 6 "errors" 7 "fmt" 8 "io/ioutil" 9 "net/http" 10 "strings" 11 "sync" 12 "sync/atomic" 13 "testing" 14 "time" 15 16 "github.com/golang/protobuf/proto" 17 "google.golang.org/protobuf/types/known/sourcecontextpb" 18 "google.golang.org/protobuf/types/known/typepb" 19 20 "github.com/jhump/protoreflect/internal/testutil" 21 ) 22 23 func TestCachingTypeFetcher(t *testing.T) { 24 counts := map[string]int{} 25 uncached := func(url string, enum bool) (proto.Message, error) { 26 counts[url] = counts[url] + 1 27 return testFetcher(url, enum) 28 } 29 30 // observe the underlying type fetcher get invoked 10x 31 for i := 0; i < 10; i++ { 32 pm, err := uncached("blah.blah.blah/fee.fi.fo.Fum", false) 33 testutil.Ok(t, err) 34 typ := pm.(*typepb.Type) 35 testutil.Eq(t, "fee.fi.fo.Fum", typ.Name) 36 } 37 for i := 0; i < 10; i++ { 38 pm, err := uncached("blah.blah.blah/fee.fi.fo.Foo", true) 39 testutil.Ok(t, err) 40 en := pm.(*typepb.Enum) 41 testutil.Eq(t, "fee.fi.fo.Foo", en.Name) 42 } 43 44 testutil.Eq(t, 10, counts["blah.blah.blah/fee.fi.fo.Fum"]) 45 testutil.Eq(t, 10, counts["blah.blah.blah/fee.fi.fo.Foo"]) 46 47 // now we'll see the underlying fetcher invoked just one more time, 48 // after which the result is cached 49 cached := CachingTypeFetcher(uncached) 50 51 for i := 0; i < 10; i++ { 52 pm, err := cached("blah.blah.blah/fee.fi.fo.Fum", false) 53 testutil.Ok(t, err) 54 typ := pm.(*typepb.Type) 55 testutil.Eq(t, "fee.fi.fo.Fum", typ.Name) 56 } 57 58 for i := 0; i < 10; i++ { 59 pm, err := cached("blah.blah.blah/fee.fi.fo.Foo", true) 60 testutil.Ok(t, err) 61 en := pm.(*typepb.Enum) 62 testutil.Eq(t, "fee.fi.fo.Foo", en.Name) 63 } 64 65 testutil.Eq(t, 11, counts["blah.blah.blah/fee.fi.fo.Fum"]) 66 testutil.Eq(t, 11, counts["blah.blah.blah/fee.fi.fo.Foo"]) 67 } 68 69 func TestCachingTypeFetcher_MismatchType(t *testing.T) { 70 fetcher := CachingTypeFetcher(testFetcher) 71 // get a message type 72 pm, err := fetcher("blah.blah.blah/fee.fi.fo.Fum", false) 73 testutil.Ok(t, err) 74 typ := pm.(*typepb.Type) 75 testutil.Eq(t, "fee.fi.fo.Fum", typ.Name) 76 // and an enum type 77 pm, err = fetcher("blah.blah.blah/fee.fi.fo.Foo", true) 78 testutil.Ok(t, err) 79 en := pm.(*typepb.Enum) 80 testutil.Eq(t, "fee.fi.fo.Foo", en.Name) 81 82 // now ask for same URL, but swapped types 83 _, err = fetcher("blah.blah.blah/fee.fi.fo.Fum", true) 84 testutil.Require(t, err != nil && strings.Contains(err.Error(), "wanted enum, got message")) 85 _, err = fetcher("blah.blah.blah/fee.fi.fo.Foo", false) 86 testutil.Require(t, err != nil && strings.Contains(err.Error(), "wanted message, got enum")) 87 } 88 89 func TestCachingTypeFetcher_Concurrency(t *testing.T) { 90 // make sure we are thread safe 91 var mu sync.Mutex 92 counts := map[string]int{} 93 tf := CachingTypeFetcher(func(url string, enum bool) (proto.Message, error) { 94 mu.Lock() 95 counts[url] = counts[url] + 1 96 mu.Unlock() 97 return testFetcher(url, enum) 98 }) 99 100 ctx, cancel := context.WithCancel(context.Background()) 101 names := []string{"Fee", "Fi", "Fo", "Fum", "I", "Smell", "Blood", "Of", "Englishman"} 102 var queryCount int32 103 var wg sync.WaitGroup 104 for i := 0; i < 10; i++ { 105 wg.Add(1) 106 go func() { 107 defer wg.Done() 108 for i := 0; ctx.Err() == nil; i = (i + 1) % len(names) { 109 n := "fee.fi.fo." + names[i] 110 // message 111 pm, err := tf("blah.blah.blah/"+n, false) 112 testutil.Ok(t, err) 113 typ := pm.(*typepb.Type) 114 testutil.Eq(t, n, typ.Name) 115 atomic.AddInt32(&queryCount, 1) 116 // enum 117 pm, err = tf("blah.blah.blah.en/"+n, true) 118 testutil.Ok(t, err) 119 en := pm.(*typepb.Enum) 120 testutil.Eq(t, n, en.Name) 121 atomic.AddInt32(&queryCount, 1) 122 } 123 }() 124 } 125 126 time.Sleep(2 * time.Second) 127 cancel() 128 wg.Wait() 129 130 // underlying fetcher invoked just once per URL 131 for _, v := range counts { 132 testutil.Eq(t, 1, v) 133 } 134 135 testutil.Require(t, atomic.LoadInt32(&queryCount) > int32(len(counts))) 136 } 137 138 func TestHttpTypeFetcher(t *testing.T) { 139 trt := &testRoundTripper{counts: map[string]int{}} 140 fetcher := HttpTypeFetcher(trt, 65536, 10) 141 142 for i := 0; i < 10; i++ { 143 pm, err := fetcher("blah.blah.blah/fee.fi.fo.Message", false) 144 testutil.Ok(t, err) 145 typ := pm.(*typepb.Type) 146 testutil.Eq(t, "fee.fi.fo.Message", typ.Name) 147 } 148 149 for i := 0; i < 10; i++ { 150 // name must have Enum for test fetcher to return an enum type 151 pm, err := fetcher("blah.blah.blah/fee.fi.fo.Enum", true) 152 testutil.Ok(t, err) 153 en := pm.(*typepb.Enum) 154 testutil.Eq(t, "fee.fi.fo.Enum", en.Name) 155 } 156 157 // HttpTypeFetcher caches results 158 testutil.Eq(t, 1, trt.counts["https://blah.blah.blah/fee.fi.fo.Message"]) 159 testutil.Eq(t, 1, trt.counts["https://blah.blah.blah/fee.fi.fo.Enum"]) 160 } 161 162 func TestHttpTypeFetcher_ParallelDownloads(t *testing.T) { 163 trt := &testRoundTripper{counts: map[string]int{}, delay: 100 * time.Millisecond} 164 fetcher := HttpTypeFetcher(trt, 65536, 10) 165 // We spin up 100 fetches in parallel, but only 10 can go at a time and each 166 // one takes 100millis. So it should take about 1 second. 167 start := time.Now() 168 var wg sync.WaitGroup 169 for i := 0; i < 100; i++ { 170 wg.Add(1) 171 index := i // don't capture loop variable 172 go func() { 173 defer wg.Done() 174 name := fmt.Sprintf("fee.fi.fo.Fum%d", index) 175 pm, err := fetcher("blah.blah.blah/"+name, false) 176 testutil.Ok(t, err) 177 typ := pm.(*typepb.Type) 178 testutil.Eq(t, name, typ.Name) 179 }() 180 } 181 wg.Wait() 182 elapsed := time.Since(start) 183 184 // we should have observed exactly the maximum number of parallel downloads 185 testutil.Eq(t, 10, trt.max) 186 // should have taken about a second 187 testutil.Require(t, elapsed >= time.Second) 188 } 189 190 func TestHttpTypeFetcher_SizeLimits(t *testing.T) { 191 trt := &testRoundTripper{counts: map[string]int{}} 192 // small size that will always get tripped 193 fetcher := HttpTypeFetcher(trt, 32, 10) 194 195 // name with "Size" causes content-length to be reported in header 196 _, err := fetcher("blah.blah.blah/fee.fi.fo.FumSize", false) 197 testutil.Require(t, err != nil && strings.Contains(err.Error(), "is larger than limit of 32")) 198 199 // without size in the name, no content-length (e.g. streaming response) 200 _, err = fetcher("blah.blah.blah/fee.fi.fo.Fum", false) 201 testutil.Require(t, err != nil && strings.Contains(err.Error(), "is larger than limit of 32")) 202 } 203 204 type testRoundTripper struct { 205 // artificial delay that each fake HTTP request will take 206 delay time.Duration 207 mu sync.Mutex 208 // counts by requested URL 209 counts map[string]int 210 // total active downloads 211 active int 212 // max observed active downloads 213 max int 214 } 215 216 func (t *testRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { 217 url := req.URL.String() 218 219 t.mu.Lock() 220 t.counts[url] = t.counts[url] + 1 221 t.active++ 222 if t.active > t.max { 223 t.max = t.active 224 } 225 t.mu.Unlock() 226 227 defer func() { 228 t.mu.Lock() 229 t.active-- 230 t.mu.Unlock() 231 }() 232 233 time.Sleep(t.delay) 234 235 name := url[strings.LastIndex(req.URL.Path, "/")+1:] 236 includeContentLength := strings.Contains(name, "Size") 237 pm, err := testFetcher(url, strings.Contains(name, "Enum")) 238 if err != nil { 239 return nil, err 240 } 241 b, err := proto.Marshal(pm) 242 if err != nil { 243 return nil, err 244 } 245 contentLength := int64(-1) 246 if includeContentLength { 247 contentLength = int64(len(b)) 248 } 249 return &http.Response{ 250 StatusCode: 200, 251 Status: "200 OK", 252 ContentLength: contentLength, 253 Body: ioutil.NopCloser(bytes.NewReader(b)), 254 }, nil 255 } 256 257 func testFetcher(url string, enum bool) (proto.Message, error) { 258 name := url[strings.LastIndex(url, "/")+1:] 259 if strings.Contains(name, "Error") { 260 return nil, errors.New(name) 261 } else if enum { 262 return &typepb.Enum{ 263 Name: name, 264 SourceContext: &sourcecontextpb.SourceContext{FileName: "test.proto"}, 265 Syntax: typepb.Syntax_SYNTAX_PROTO3, 266 Enumvalue: []*typepb.EnumValue{ 267 {Name: "A", Number: 0}, 268 {Name: "B", Number: 1}, 269 {Name: "C", Number: 2}, 270 }, 271 }, nil 272 } else { 273 return &typepb.Type{ 274 Name: name, 275 SourceContext: &sourcecontextpb.SourceContext{FileName: "test.proto"}, 276 Syntax: typepb.Syntax_SYNTAX_PROTO3, 277 Fields: []*typepb.Field{ 278 {Name: "a", Number: 1, Cardinality: typepb.Field_CARDINALITY_OPTIONAL, Kind: typepb.Field_TYPE_INT64}, 279 {Name: "b", Number: 2, Cardinality: typepb.Field_CARDINALITY_OPTIONAL, Kind: typepb.Field_TYPE_STRING}, 280 {Name: "c1", Number: 3, OneofIndex: 1, Cardinality: typepb.Field_CARDINALITY_OPTIONAL, Kind: typepb.Field_TYPE_STRING}, 281 {Name: "c2", Number: 4, OneofIndex: 1, Cardinality: typepb.Field_CARDINALITY_OPTIONAL, Kind: typepb.Field_TYPE_BOOL}, 282 {Name: "c3", Number: 5, OneofIndex: 1, Cardinality: typepb.Field_CARDINALITY_OPTIONAL, Kind: typepb.Field_TYPE_DOUBLE}, 283 {Name: "d", Number: 6, Cardinality: typepb.Field_CARDINALITY_REPEATED, Kind: typepb.Field_TYPE_MESSAGE, TypeUrl: "type.googleapis.com/foo.bar.Baz"}, 284 {Name: "e", Number: 7, Cardinality: typepb.Field_CARDINALITY_OPTIONAL, Kind: typepb.Field_TYPE_ENUM, TypeUrl: "type.googleapis.com/foo.bar.Blah"}, 285 }, 286 Oneofs: []string{"union"}, 287 }, nil 288 } 289 }