github.com/pachyderm/pachyderm@v1.13.4/src/client/pkg/grpcutil/stream.go (about)

     1  package grpcutil
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"io"
     7  
     8  	units "github.com/docker/go-units"
     9  	"github.com/gogo/protobuf/types"
    10  
    11  	"github.com/pachyderm/pachyderm/src/client/pkg/errors"
    12  )
    13  
var (
	// MaxMsgSize is used to define the GRPC frame size.
	MaxMsgSize = 20 * units.MiB
	// MaxMsgPayloadSize is the max message payload size.
	// This is slightly less than MaxMsgSize to account
	// for the GRPC message wrapping the payload.
	// Chunk and streamingBytesWriter.Write split data at this boundary.
	MaxMsgPayloadSize = MaxMsgSize - units.MiB
)
    22  
    23  // Chunk splits a piece of data up, this is useful for splitting up data that's
    24  // bigger than MaxMsgPayloadSize.
    25  func Chunk(data []byte) [][]byte {
    26  	chunkSize := MaxMsgPayloadSize
    27  	var result [][]byte
    28  	for i := 0; i < len(data); i += chunkSize {
    29  		end := i + chunkSize
    30  		if end > len(data) {
    31  			end = len(data)
    32  		}
    33  		result = append(result, data[i:end])
    34  	}
    35  	return result
    36  }
    37  
    38  // ChunkReader splits a reader into reasonably sized chunks for the purpose
    39  // of transmitting the chunks over gRPC. For each chunk, it calls the given
    40  // function.
    41  func ChunkReader(r io.Reader, f func([]byte) error) (int, error) {
    42  	var total int
    43  	buf := GetBuffer()
    44  	defer PutBuffer(buf)
    45  	for {
    46  		n, err := r.Read(buf)
    47  		if n == 0 && err != nil {
    48  			if errors.Is(err, io.EOF) {
    49  				return total, nil
    50  			}
    51  			return total, err
    52  		}
    53  		if err := f(buf[:n]); err != nil {
    54  			return total, err
    55  		}
    56  		total += n
    57  	}
    58  }
    59  
    60  // ChunkWriteCloser is a utility for buffering writes into buffers obtained from a buffer pool.
    61  // The ChunkWriteCloser will buffer up to the capacity of a buffer obtained from a buffer pool,
    62  // then execute a callback that will receive the buffered data. The ChunkWriteCloser will get
    63  // a new buffer from the pool for subsequent writes, so it is expected that the callback will
    64  // return the buffer to the pool.
type ChunkWriteCloser struct {
	bufPool *BufPool           // pool that write buffers are drawn from
	buf     []byte             // current, partially filled buffer (len < cap between writes)
	f       func([]byte) error // callback invoked with each filled buffer; expected to return it to the pool
}
    70  
    71  // NewChunkWriteCloser creates a new ChunkWriteCloser.
    72  func NewChunkWriteCloser(bufPool *BufPool, f func(chunk []byte) error) *ChunkWriteCloser {
    73  	return &ChunkWriteCloser{
    74  		bufPool: bufPool,
    75  		buf:     bufPool.GetBuffer()[:0],
    76  		f:       f,
    77  	}
    78  }
    79  
    80  // Write performs a write.
    81  func (w *ChunkWriteCloser) Write(data []byte) (int, error) {
    82  	var written int
    83  	for len(w.buf)+len(data) > cap(w.buf) {
    84  		// Write the bytes that fit into w.buf, then
    85  		// remove those bytes from data.
    86  		i := cap(w.buf) - len(w.buf)
    87  		w.buf = append(w.buf, data[:i]...)
    88  		if err := w.f(w.buf); err != nil {
    89  			return 0, err
    90  		}
    91  		w.buf = bufPool.GetBuffer()[:0]
    92  		written += i
    93  		data = data[i:]
    94  	}
    95  	w.buf = append(w.buf, data...)
    96  	written += len(data)
    97  	return written, nil
    98  }
    99  
   100  // Close closes the writer.
   101  func (w *ChunkWriteCloser) Close() error {
   102  	if len(w.buf) == 0 {
   103  		return nil
   104  	}
   105  	return w.f(w.buf)
   106  }
   107  
   108  // StreamingBytesServer represents a server for an rpc method of the form:
   109  //   rpc Foo(Bar) returns (stream google.protobuf.BytesValue) {}
type StreamingBytesServer interface {
	// Send transmits one BytesValue message on the stream.
	Send(bytesValue *types.BytesValue) error
}
   113  
   114  // StreamingBytesClient represents a client for an rpc method of the form:
   115  //   rpc Foo(Bar) returns (stream google.protobuf.BytesValue) {}
type StreamingBytesClient interface {
	// Recv receives the next BytesValue message from the stream;
	// it returns io.EOF when the stream ends.
	Recv() (*types.BytesValue, error)
}
   119  
   120  // NewStreamingBytesReader returns an io.Reader for a StreamingBytesClient.
   121  func NewStreamingBytesReader(streamingBytesClient StreamingBytesClient, cancel context.CancelFunc) io.ReadCloser {
   122  	return &streamingBytesReader{streamingBytesClient: streamingBytesClient, cancel: cancel}
   123  }
   124  
type streamingBytesReader struct {
	streamingBytesClient StreamingBytesClient // source stream of BytesValue messages
	buffer               bytes.Buffer         // received bytes not yet consumed by Read
	cancel               context.CancelFunc   // optional; invoked by Close to tear down the stream
}
   130  
   131  func (s *streamingBytesReader) Read(p []byte) (int, error) {
   132  	// TODO this is doing an unneeded copy (unless go is smarter than I think it is)
   133  	if s.buffer.Len() == 0 {
   134  		value, err := s.streamingBytesClient.Recv()
   135  		if err != nil {
   136  			return 0, err
   137  		}
   138  		s.buffer.Reset()
   139  		if _, err := s.buffer.Write(value.Value); err != nil {
   140  			return 0, err
   141  		}
   142  	}
   143  	return s.buffer.Read(p)
   144  }
   145  
   146  func (s *streamingBytesReader) Close() error {
   147  	if s.cancel != nil {
   148  		s.cancel()
   149  	}
   150  	return nil
   151  }
   152  
   153  // NewStreamingBytesWriter returns an io.Writer for a StreamingBytesServer.
   154  func NewStreamingBytesWriter(streamingBytesServer StreamingBytesServer) io.Writer {
   155  	return &streamingBytesWriter{streamingBytesServer}
   156  }
   157  
type streamingBytesWriter struct {
	streamingBytesServer StreamingBytesServer // destination stream; Write sends chunked BytesValue messages to it
}
   161  
   162  func (s *streamingBytesWriter) Write(data []byte) (int, error) {
   163  	var bytesWritten int
   164  	for _, val := range Chunk(data) {
   165  		if err := s.streamingBytesServer.Send(&types.BytesValue{Value: val}); err != nil {
   166  			return bytesWritten, err
   167  		}
   168  		bytesWritten += len(val)
   169  	}
   170  	return bytesWritten, nil
   171  }
   172  
   173  // ReaderWrapper wraps a reader for the following reason: Go's io.CopyBuffer
   174  // has an annoying optimization wherein if the reader has the WriteTo function
   175  // defined, it doesn't actually use the given buffer.  As a result, we might
   176  // write a large chunk to the gRPC streaming server even though we intend to
   177  // use a small buffer.  Therefore we wrap readers in this wrapper so that only
   178  // Read is defined.
   179  type ReaderWrapper struct {
   180  	Reader io.Reader
   181  }
   182  
   183  func (r ReaderWrapper) Read(p []byte) (int, error) {
   184  	return r.Reader.Read(p)
   185  }
   186  
   187  // WriteToStreamingBytesServer writes the data from the io.Reader to the StreamingBytesServer.
   188  func WriteToStreamingBytesServer(reader io.Reader, streamingBytesServer StreamingBytesServer) error {
   189  	buf := GetBuffer()
   190  	defer PutBuffer(buf)
   191  	_, err := io.CopyBuffer(NewStreamingBytesWriter(streamingBytesServer), ReaderWrapper{reader}, buf)
   192  	return err
   193  }
   194  
   195  // WriteFromStreamingBytesClient writes from the StreamingBytesClient to the io.Writer.
   196  func WriteFromStreamingBytesClient(streamingBytesClient StreamingBytesClient, writer io.Writer) error {
   197  	for bytesValue, err := streamingBytesClient.Recv(); !errors.Is(err, io.EOF); bytesValue, err = streamingBytesClient.Recv() {
   198  		if err != nil {
   199  			return err
   200  		}
   201  		if _, err = writer.Write(bytesValue.Value); err != nil {
   202  			return err
   203  		}
   204  	}
   205  	return nil
   206  }