github.com/jhump/protoreflect@v1.16.0/dynamic/msgregistry/fetchers.go (about)

     1  package msgregistry
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"net/http"
     8  	"net/url"
     9  	"sync"
    10  
    11  	"github.com/golang/protobuf/proto"
    12  	"google.golang.org/protobuf/types/known/typepb"
    13  )
    14  
    15  // TypeFetcher is a simple operation that retrieves a type definition for a given type URL.
    16  // The returned proto message will be either a *ptype.Enum or a *ptype.Type, depending on
    17  // whether the enum flag is true or not.
    18  type TypeFetcher func(url string, enum bool) (proto.Message, error)
    19  
    20  // CachingTypeFetcher adds a caching layer to the given type fetcher. Queries for
    21  // types that have already been fetched will not result in another call to the
    22  // underlying fetcher and instead are retrieved from the cache.
    23  func CachingTypeFetcher(fetcher TypeFetcher) TypeFetcher {
    24  	c := protoCache{entries: map[string]*protoCacheEntry{}}
    25  	return func(typeUrl string, enum bool) (proto.Message, error) {
    26  		m, err := c.getOrLoad(typeUrl, func() (proto.Message, error) {
    27  			return fetcher(typeUrl, enum)
    28  		})
    29  		if err != nil {
    30  			return nil, err
    31  		}
    32  		if _, isEnum := m.(*typepb.Enum); enum != isEnum {
    33  			var want, got string
    34  			if enum {
    35  				want = "enum"
    36  				got = "message"
    37  			} else {
    38  				want = "message"
    39  				got = "enum"
    40  			}
    41  			return nil, fmt.Errorf("type for URL %v is the wrong type: wanted %s, got %s", typeUrl, want, got)
    42  		}
    43  		return m.(proto.Message), nil
    44  	}
    45  }
    46  
    47  type protoCache struct {
    48  	mu      sync.RWMutex
    49  	entries map[string]*protoCacheEntry
    50  }
    51  
    52  type protoCacheEntry struct {
    53  	msg proto.Message
    54  	err error
    55  	wg  sync.WaitGroup
    56  }
    57  
    58  func (c *protoCache) getOrLoad(key string, loader func() (proto.Message, error)) (m proto.Message, err error) {
    59  	// see if it's cached
    60  	c.mu.RLock()
    61  	cached, ok := c.entries[key]
    62  	c.mu.RUnlock()
    63  	if ok {
    64  		cached.wg.Wait()
    65  		return cached.msg, cached.err
    66  	}
    67  
    68  	// must delegate and cache the result
    69  	c.mu.Lock()
    70  	// double-check, in case it was added concurrently while we were upgrading lock
    71  	cached, ok = c.entries[key]
    72  	if ok {
    73  		c.mu.Unlock()
    74  		cached.wg.Wait()
    75  		return cached.msg, cached.err
    76  	}
    77  	e := &protoCacheEntry{}
    78  	e.wg.Add(1)
    79  	c.entries[key] = e
    80  	c.mu.Unlock()
    81  	defer func() {
    82  		if err != nil {
    83  			// don't leave broken entry in the cache
    84  			c.mu.Lock()
    85  			delete(c.entries, key)
    86  			c.mu.Unlock()
    87  		}
    88  		e.msg, e.err = m, err
    89  		e.wg.Done()
    90  	}()
    91  
    92  	return loader()
    93  }
    94  
    95  // HttpTypeFetcher returns a TypeFetcher that uses the given HTTP transport to query and
    96  // download type definitions. The given szLimit is the maximum response size accepted. If
    97  // used from multiple goroutines (like when a type's dependency graph is resolved in
    98  // parallel), this resolver limits the number of parallel queries/downloads to the given
    99  // parLimit.
   100  func HttpTypeFetcher(transport http.RoundTripper, szLimit, parLimit int) TypeFetcher {
   101  	sem := semaphore{count: parLimit, permits: parLimit}
   102  	return CachingTypeFetcher(func(typeUrl string, enum bool) (proto.Message, error) {
   103  		sem.Acquire()
   104  		defer sem.Release()
   105  
   106  		// build URL
   107  		u, err := url.Parse(ensureScheme(typeUrl))
   108  		if err != nil {
   109  			return nil, err
   110  		}
   111  
   112  		resp, err := transport.RoundTrip(&http.Request{URL: u})
   113  		if err != nil {
   114  			return nil, err
   115  		}
   116  		defer resp.Body.Close()
   117  
   118  		if resp.StatusCode != 200 {
   119  			return nil, fmt.Errorf("HTTP request returned non-200 status code: %s", resp.Status)
   120  		}
   121  
   122  		if resp.ContentLength > int64(szLimit) {
   123  			return nil, fmt.Errorf("type definition size %d is larger than limit of %d", resp.ContentLength, szLimit)
   124  		}
   125  
   126  		// download the response, up to the given size limit, into a buffer
   127  		bufptr := bufferPool.Get().(*[]byte)
   128  		defer bufferPool.Put(bufptr)
   129  		buf := *bufptr
   130  		var b bytes.Buffer
   131  		for {
   132  			n, err := resp.Body.Read(buf)
   133  			if err != nil && err != io.EOF {
   134  				return nil, err
   135  			}
   136  			if n > 0 {
   137  				if b.Len()+n > szLimit {
   138  					return nil, fmt.Errorf("type definition size %d+ is larger than limit of %d", b.Len()+n, szLimit)
   139  				}
   140  				b.Write(buf[:n])
   141  			}
   142  			if err == io.EOF {
   143  				break
   144  			}
   145  		}
   146  
   147  		// now we can de-serialize the type definition
   148  		if enum {
   149  			var ret typepb.Enum
   150  			if err = proto.Unmarshal(b.Bytes(), &ret); err != nil {
   151  				return nil, err
   152  			}
   153  			return &ret, nil
   154  		} else {
   155  			var ret typepb.Type
   156  			if err = proto.Unmarshal(b.Bytes(), &ret); err != nil {
   157  				return nil, err
   158  			}
   159  			return &ret, nil
   160  		}
   161  	})
   162  }
   163  
   164  var bufferPool = sync.Pool{New: func() interface{} {
   165  	buf := make([]byte, 8192)
   166  	return &buf
   167  }}
   168  
   169  type semaphore struct {
   170  	lock    sync.Mutex
   171  	count   int
   172  	permits int
   173  	cond    sync.Cond
   174  }
   175  
   176  func (s *semaphore) Acquire() {
   177  	s.lock.Lock()
   178  	defer s.lock.Unlock()
   179  
   180  	if s.cond.L == nil {
   181  		s.cond.L = &s.lock
   182  	}
   183  
   184  	for s.count == 0 {
   185  		s.cond.Wait()
   186  	}
   187  	s.count--
   188  }
   189  
   190  func (s *semaphore) Release() {
   191  	s.lock.Lock()
   192  	defer s.lock.Unlock()
   193  
   194  	if s.cond.L == nil {
   195  		s.cond.L = &s.lock
   196  	}
   197  
   198  	if s.count == s.permits {
   199  		panic("call to Release() without corresponding call to Acquire()")
   200  	}
   201  	s.count++
   202  	s.cond.Signal()
   203  }