github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/utils/grpc/stream/stream.go (about)

     1  // Copyright 2023 Gravitational, Inc
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package stream
    16  
    17  import (
    18  	"errors"
    19  	"io"
    20  	"net"
    21  	"sync"
    22  	"time"
    23  
    24  	"github.com/gravitational/trace"
    25  	"google.golang.org/grpc/codes"
    26  	"google.golang.org/grpc/status"
    27  )
    28  
    29  // MaxChunkSize is the maximum number of bytes to send in a single data message.
    30  // According to https://github.com/grpc/grpc.github.io/issues/371 the optimal
    31  // size is between 16KiB to 64KiB.
    32  const MaxChunkSize int = 1024 * 16
    33  
    34  // Source is a common interface for grpc client and server streams
    35  // that transport opaque data.
    36  type Source interface {
    37  	Send([]byte) error
    38  	Recv() ([]byte, error)
    39  }
    40  
    41  // ReadWriter wraps a grpc source with an [io.ReadWriter] interface.
    42  // All reads are consumed from [Source.Recv] and all writes and sent
    43  // via [Source.Send].
    44  type ReadWriter struct {
    45  	source Source
    46  
    47  	wLock  sync.Mutex
    48  	rLock  sync.Mutex
    49  	rBytes []byte
    50  }
    51  
    52  // NewReadWriter creates a new ReadWriter that leverages the provided
    53  // source to retrieve data from and write data to.
    54  func NewReadWriter(source Source) (*ReadWriter, error) {
    55  	if source == nil {
    56  		return nil, trace.BadParameter("parameter source required")
    57  	}
    58  
    59  	return &ReadWriter{
    60  		source: source,
    61  	}, nil
    62  }
    63  
    64  // Read returns data received from the stream source. Any
    65  // data received from the stream that is not consumed will
    66  // be buffered and returned on subsequent reads until there
    67  // is none left. Only then will data be sourced from the stream
    68  // again.
    69  func (c *ReadWriter) Read(b []byte) (n int, err error) {
    70  	c.rLock.Lock()
    71  	defer c.rLock.Unlock()
    72  
    73  	if len(c.rBytes) == 0 {
    74  		data, err := c.source.Recv()
    75  		if errors.Is(err, io.EOF) || status.Code(err) == codes.Canceled {
    76  			return 0, io.EOF
    77  		}
    78  
    79  		if err != nil {
    80  			return 0, trace.ConnectionProblem(trace.Wrap(err), "failed to receive from source: %v", err)
    81  		}
    82  
    83  		if data == nil {
    84  			return 0, trace.BadParameter("received invalid data from source")
    85  		}
    86  
    87  		c.rBytes = data
    88  	}
    89  
    90  	n = copy(b, c.rBytes)
    91  	c.rBytes = c.rBytes[n:]
    92  
    93  	// Stop holding onto buffer immediately
    94  	if len(c.rBytes) == 0 {
    95  		c.rBytes = nil
    96  	}
    97  
    98  	return n, nil
    99  }
   100  
   101  // Write consumes all data provided and sends it on
   102  // the grpc stream. To prevent exhausting the stream all
   103  // sends on the stream are limited to be at most MaxChunkSize.
   104  // If the data exceeds the MaxChunkSize it will be sent in
   105  // batches.
   106  func (c *ReadWriter) Write(b []byte) (int, error) {
   107  	c.wLock.Lock()
   108  	defer c.wLock.Unlock()
   109  
   110  	var sent int
   111  	for len(b) > 0 {
   112  		chunk := b
   113  		if len(chunk) > MaxChunkSize {
   114  			chunk = chunk[:MaxChunkSize]
   115  		}
   116  
   117  		if err := c.source.Send(chunk); err != nil {
   118  			return sent, trace.ConnectionProblem(trace.Wrap(err), "failed to send on source: %v", err)
   119  		}
   120  
   121  		sent += len(chunk)
   122  		b = b[len(chunk):]
   123  	}
   124  
   125  	return sent, nil
   126  }
   127  
   128  // Close cleans up resources used by the stream.
   129  func (c *ReadWriter) Close() error {
   130  	if cs, ok := c.source.(io.Closer); ok {
   131  		return trace.Wrap(cs.Close())
   132  	}
   133  
   134  	return nil
   135  }
   136  
   137  // Conn wraps [ReadWriter] in a [net.Conn] interface.
   138  type Conn struct {
   139  	*ReadWriter
   140  
   141  	src net.Addr
   142  	dst net.Addr
   143  }
   144  
   145  // NewConn creates a new Conn which transfers data via the provided ReadWriter.
   146  func NewConn(rw *ReadWriter, src net.Addr, dst net.Addr) *Conn {
   147  	return &Conn{
   148  		ReadWriter: rw,
   149  		src:        src,
   150  		dst:        dst,
   151  	}
   152  }
   153  
   154  // LocalAddr is the original source address of the client.
   155  func (c *Conn) LocalAddr() net.Addr {
   156  	return c.src
   157  }
   158  
   159  // RemoteAddr is the address of the reverse tunnel node.
   160  func (c *Conn) RemoteAddr() net.Addr {
   161  	return c.dst
   162  }
   163  
   164  func (c *Conn) SetDeadline(t time.Time) error {
   165  	return nil
   166  }
   167  
   168  func (c *Conn) SetReadDeadline(t time.Time) error {
   169  	return nil
   170  }
   171  
   172  func (c *Conn) SetWriteDeadline(t time.Time) error {
   173  	return nil
   174  }