github.com/jhump/protoreflect@v1.16.0/dynamic/msgregistry/fetchers.go (about) 1 package msgregistry 2 3 import ( 4 "bytes" 5 "fmt" 6 "io" 7 "net/http" 8 "net/url" 9 "sync" 10 11 "github.com/golang/protobuf/proto" 12 "google.golang.org/protobuf/types/known/typepb" 13 ) 14 15 // TypeFetcher is a simple operation that retrieves a type definition for a given type URL. 16 // The returned proto message will be either a *ptype.Enum or a *ptype.Type, depending on 17 // whether the enum flag is true or not. 18 type TypeFetcher func(url string, enum bool) (proto.Message, error) 19 20 // CachingTypeFetcher adds a caching layer to the given type fetcher. Queries for 21 // types that have already been fetched will not result in another call to the 22 // underlying fetcher and instead are retrieved from the cache. 23 func CachingTypeFetcher(fetcher TypeFetcher) TypeFetcher { 24 c := protoCache{entries: map[string]*protoCacheEntry{}} 25 return func(typeUrl string, enum bool) (proto.Message, error) { 26 m, err := c.getOrLoad(typeUrl, func() (proto.Message, error) { 27 return fetcher(typeUrl, enum) 28 }) 29 if err != nil { 30 return nil, err 31 } 32 if _, isEnum := m.(*typepb.Enum); enum != isEnum { 33 var want, got string 34 if enum { 35 want = "enum" 36 got = "message" 37 } else { 38 want = "message" 39 got = "enum" 40 } 41 return nil, fmt.Errorf("type for URL %v is the wrong type: wanted %s, got %s", typeUrl, want, got) 42 } 43 return m.(proto.Message), nil 44 } 45 } 46 47 type protoCache struct { 48 mu sync.RWMutex 49 entries map[string]*protoCacheEntry 50 } 51 52 type protoCacheEntry struct { 53 msg proto.Message 54 err error 55 wg sync.WaitGroup 56 } 57 58 func (c *protoCache) getOrLoad(key string, loader func() (proto.Message, error)) (m proto.Message, err error) { 59 // see if it's cached 60 c.mu.RLock() 61 cached, ok := c.entries[key] 62 c.mu.RUnlock() 63 if ok { 64 cached.wg.Wait() 65 return cached.msg, cached.err 66 } 67 68 // must delegate and cache the result 69 c.mu.Lock() 70 // double-check, in case it was added concurrently while we were upgrading lock 71 cached, ok = c.entries[key] 72 if ok { 73 c.mu.Unlock() 74 cached.wg.Wait() 75 return cached.msg, cached.err 76 } 77 e := &protoCacheEntry{} 78 e.wg.Add(1) 79 c.entries[key] = e 80 c.mu.Unlock() 81 defer func() { 82 if err != nil { 83 // don't leave broken entry in the cache 84 c.mu.Lock() 85 delete(c.entries, key) 86 c.mu.Unlock() 87 } 88 e.msg, e.err = m, err 89 e.wg.Done() 90 }() 91 92 return loader() 93 } 94 95 // HttpTypeFetcher returns a TypeFetcher that uses the given HTTP transport to query and 96 // download type definitions. The given szLimit is the maximum response size accepted. If 97 // used from multiple goroutines (like when a type's dependency graph is resolved in 98 // parallel), this resolver limits the number of parallel queries/downloads to the given 99 // parLimit. 100 func HttpTypeFetcher(transport http.RoundTripper, szLimit, parLimit int) TypeFetcher { 101 sem := semaphore{count: parLimit, permits: parLimit} 102 return CachingTypeFetcher(func(typeUrl string, enum bool) (proto.Message, error) { 103 sem.Acquire() 104 defer sem.Release() 105 106 // build URL 107 u, err := url.Parse(ensureScheme(typeUrl)) 108 if err != nil { 109 return nil, err 110 } 111 112 resp, err := transport.RoundTrip(&http.Request{URL: u}) 113 if err != nil { 114 return nil, err 115 } 116 defer resp.Body.Close() 117 118 if resp.StatusCode != 200 { 119 return nil, fmt.Errorf("HTTP request returned non-200 status code: %s", resp.Status) 120 } 121 122 if resp.ContentLength > int64(szLimit) { 123 return nil, fmt.Errorf("type definition size %d is larger than limit of %d", resp.ContentLength, szLimit) 124 } 125 126 // download the response, up to the given size limit, into a buffer 127 bufptr := bufferPool.Get().(*[]byte) 128 defer bufferPool.Put(bufptr) 129 buf := *bufptr 130 var b bytes.Buffer 131 for { 132 n, err := resp.Body.Read(buf) 133 if err != nil && err != io.EOF { 134 return nil, err 135 } 136 if n > 0 { 137 if b.Len()+n > szLimit { 138 return nil, fmt.Errorf("type definition size %d+ is larger than limit of %d", b.Len()+n, szLimit) 139 } 140 b.Write(buf[:n]) 141 } 142 if err == io.EOF { 143 break 144 } 145 } 146 147 // now we can de-serialize the type definition 148 if enum { 149 var ret typepb.Enum 150 if err = proto.Unmarshal(b.Bytes(), &ret); err != nil { 151 return nil, err 152 } 153 return &ret, nil 154 } else { 155 var ret typepb.Type 156 if err = proto.Unmarshal(b.Bytes(), &ret); err != nil { 157 return nil, err 158 } 159 return &ret, nil 160 } 161 }) 162 } 163 164 var bufferPool = sync.Pool{New: func() interface{} { 165 buf := make([]byte, 8192) 166 return &buf 167 }} 168 169 type semaphore struct { 170 lock sync.Mutex 171 count int 172 permits int 173 cond sync.Cond 174 } 175 176 func (s *semaphore) Acquire() { 177 s.lock.Lock() 178 defer s.lock.Unlock() 179 180 if s.cond.L == nil { 181 s.cond.L = &s.lock 182 } 183 184 for s.count == 0 { 185 s.cond.Wait() 186 } 187 s.count-- 188 } 189 190 func (s *semaphore) Release() { 191 s.lock.Lock() 192 defer s.lock.Unlock() 193 194 if s.cond.L == nil { 195 s.cond.L = &s.lock 196 } 197 198 if s.count == s.permits { 199 panic("call to Release() without corresponding call to Acquire()") 200 } 201 s.count++ 202 s.cond.Signal() 203 }