github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/utils/grpc/stream/stream.go (about) 1 // Copyright 2023 Gravitational, Inc 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package stream 16 17 import ( 18 "errors" 19 "io" 20 "net" 21 "sync" 22 "time" 23 24 "github.com/gravitational/trace" 25 "google.golang.org/grpc/codes" 26 "google.golang.org/grpc/status" 27 ) 28 29 // MaxChunkSize is the maximum number of bytes to send in a single data message. 30 // According to https://github.com/grpc/grpc.github.io/issues/371 the optimal 31 // size is between 16KiB to 64KiB. 32 const MaxChunkSize int = 1024 * 16 33 34 // Source is a common interface for grpc client and server streams 35 // that transport opaque data. 36 type Source interface { 37 Send([]byte) error 38 Recv() ([]byte, error) 39 } 40 41 // ReadWriter wraps a grpc source with an [io.ReadWriter] interface. 42 // All reads are consumed from [Source.Recv] and all writes and sent 43 // via [Source.Send]. 44 type ReadWriter struct { 45 source Source 46 47 wLock sync.Mutex 48 rLock sync.Mutex 49 rBytes []byte 50 } 51 52 // NewReadWriter creates a new ReadWriter that leverages the provided 53 // source to retrieve data from and write data to. 54 func NewReadWriter(source Source) (*ReadWriter, error) { 55 if source == nil { 56 return nil, trace.BadParameter("parameter source required") 57 } 58 59 return &ReadWriter{ 60 source: source, 61 }, nil 62 } 63 64 // Read returns data received from the stream source. Any 65 // data received from the stream that is not consumed will 66 // be buffered and returned on subsequent reads until there 67 // is none left. Only then will data be sourced from the stream 68 // again. 69 func (c *ReadWriter) Read(b []byte) (n int, err error) { 70 c.rLock.Lock() 71 defer c.rLock.Unlock() 72 73 if len(c.rBytes) == 0 { 74 data, err := c.source.Recv() 75 if errors.Is(err, io.EOF) || status.Code(err) == codes.Canceled { 76 return 0, io.EOF 77 } 78 79 if err != nil { 80 return 0, trace.ConnectionProblem(trace.Wrap(err), "failed to receive from source: %v", err) 81 } 82 83 if data == nil { 84 return 0, trace.BadParameter("received invalid data from source") 85 } 86 87 c.rBytes = data 88 } 89 90 n = copy(b, c.rBytes) 91 c.rBytes = c.rBytes[n:] 92 93 // Stop holding onto buffer immediately 94 if len(c.rBytes) == 0 { 95 c.rBytes = nil 96 } 97 98 return n, nil 99 } 100 101 // Write consumes all data provided and sends it on 102 // the grpc stream. To prevent exhausting the stream all 103 // sends on the stream are limited to be at most MaxChunkSize. 104 // If the data exceeds the MaxChunkSize it will be sent in 105 // batches. 106 func (c *ReadWriter) Write(b []byte) (int, error) { 107 c.wLock.Lock() 108 defer c.wLock.Unlock() 109 110 var sent int 111 for len(b) > 0 { 112 chunk := b 113 if len(chunk) > MaxChunkSize { 114 chunk = chunk[:MaxChunkSize] 115 } 116 117 if err := c.source.Send(chunk); err != nil { 118 return sent, trace.ConnectionProblem(trace.Wrap(err), "failed to send on source: %v", err) 119 } 120 121 sent += len(chunk) 122 b = b[len(chunk):] 123 } 124 125 return sent, nil 126 } 127 128 // Close cleans up resources used by the stream. 129 func (c *ReadWriter) Close() error { 130 if cs, ok := c.source.(io.Closer); ok { 131 return trace.Wrap(cs.Close()) 132 } 133 134 return nil 135 } 136 137 // Conn wraps [ReadWriter] in a [net.Conn] interface. 138 type Conn struct { 139 *ReadWriter 140 141 src net.Addr 142 dst net.Addr 143 } 144 145 // NewConn creates a new Conn which transfers data via the provided ReadWriter. 146 func NewConn(rw *ReadWriter, src net.Addr, dst net.Addr) *Conn { 147 return &Conn{ 148 ReadWriter: rw, 149 src: src, 150 dst: dst, 151 } 152 } 153 154 // LocalAddr is the original source address of the client. 155 func (c *Conn) LocalAddr() net.Addr { 156 return c.src 157 } 158 159 // RemoteAddr is the address of the reverse tunnel node. 160 func (c *Conn) RemoteAddr() net.Addr { 161 return c.dst 162 } 163 164 func (c *Conn) SetDeadline(t time.Time) error { 165 return nil 166 } 167 168 func (c *Conn) SetReadDeadline(t time.Time) error { 169 return nil 170 } 171 172 func (c *Conn) SetWriteDeadline(t time.Time) error { 173 return nil 174 }