github.com/telepresenceio/telepresence/v2@v2.20.0-pro.6.0.20240517030216-236ea954e789/pkg/tunnel/stream.go (about) 1 package tunnel 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "io" 8 "net" 9 "sync" 10 "time" 11 12 "go.opentelemetry.io/otel" 13 "google.golang.org/grpc/codes" 14 "google.golang.org/grpc/status" 15 16 "github.com/datawire/dlib/dlog" 17 rpc "github.com/telepresenceio/telepresence/rpc/v2/manager" 18 ) 19 20 // Version 21 // 22 // 0 which didn't report versions and didn't do synchronization 23 // 1 used MuxTunnel instead of one tunnel per connection. 24 const Version = uint16(2) 25 26 // Endpoint is an endpoint for a Stream such as a Dialer or a bidirectional pipe. 27 type Endpoint interface { 28 Start(ctx context.Context) 29 Done() <-chan struct{} 30 } 31 32 // GRPCStream is the bare minimum needed for reading and writing TunnelMessages 33 // on a Manager_TunnelServer or Manager_TunnelClient. 34 type GRPCStream interface { 35 Recv() (*rpc.TunnelMessage, error) 36 Send(*rpc.TunnelMessage) error 37 } 38 39 // The Stream interface represents a bidirectional, synchronized connection Tunnel 40 // that sends TCP or UDP traffic over gRPC using manager.TunnelMessage messages. 41 // 42 // A Stream is closed by one of six things happening at either end (or at both ends). 43 // 44 // 1. Read from local connection fails (typically EOF) 45 // 2. Write to local connection fails (connection peer closed) 46 // 3. Idle timer timed out. 47 // 4. Context is cancelled. 48 // 5. closeSend request received from Tunnel peer. 49 // 6. Disconnect received from Tunnel peer. 50 // 51 // When #1 or #2 happens, the Stream will either call CloseSend() (if it's a client Stream) 52 // or send a closeSend request (if it's a StreamServer) to its Stream peer, shorten the 53 // Idle timer, and then continue to serve incoming data from the Stream peer until it's 54 // closed or a Disconnect is received. Once that happens, it's guaranteed that the Tunnel 55 // peer will send no more messages and the Stream is closed. 56 // 57 // When #3, #4, or #5 happens, the Tunnel will send a Disconnect to its Stream peer and close. 58 // 59 // When #6 happens, the Stream will simply close. 60 type Stream interface { 61 Tag() string 62 ID() ConnID 63 Receive(context.Context) (Message, error) 64 Send(context.Context, Message) error 65 CloseSend(ctx context.Context) error 66 PeerVersion() uint16 67 SessionID() string 68 DialTimeout() time.Duration 69 RoundtripLatency() time.Duration 70 } 71 72 // StreamCreator is a function that creats a Stream. 73 type StreamCreator func(context.Context, ConnID) (Stream, error) 74 75 // ReadLoop reads from the Stream and dispatches messages and error to the give channels. There 76 // will be max one error since the error also terminates the loop. 77 func ReadLoop(ctx context.Context, s Stream, p *CounterProbe) (<-chan Message, <-chan error) { 78 msgCh := make(chan Message, 50) 79 errCh := make(chan error, 1) // Max one message will be sent on this channel 80 dlog.Tracef(ctx, " %s %s, ReadLoop starting", s.Tag(), s.ID()) 81 go func() { 82 ctx, span := otel.GetTracerProvider().Tracer("").Start(ctx, "ReadLoop") 83 defer span.End() 84 s.ID().SpanRecord(span) 85 var endReason string 86 defer func() { 87 close(errCh) 88 close(msgCh) 89 dlog.Tracef(ctx, " %s %s, ReadLoop ended: %s", s.Tag(), s.ID(), endReason) 90 }() 91 92 for { 93 m, err := s.Receive(ctx) 94 if m != nil && p != nil { 95 p.Increment(uint64(len(m.Payload()))) 96 } 97 98 switch { 99 case err == nil: 100 select { 101 case <-ctx.Done(): 102 endReason = ctx.Err().Error() 103 case msgCh <- m: 104 continue 105 } 106 case ctx.Err() != nil: 107 endReason = ctx.Err().Error() 108 case errors.Is(err, io.EOF): 109 endReason = "EOF on input" 110 case errors.Is(err, net.ErrClosed): 111 endReason = "stream closed" 112 case errors.Is(err, context.Canceled), status.Code(err) == codes.Canceled: 113 endReason = err.Error() 114 default: 115 endReason = err.Error() 116 select { 117 case errCh <- fmt.Errorf("!! %s %s, read from grpc.ClientStream failed: %w", s.Tag(), s.ID(), err): 118 default: 119 } 120 } 121 break 122 } 123 }() 124 return msgCh, errCh 125 } 126 127 // WriteLoop reads messages from the channel and writes them to the Stream. It will call CloseSend() on the 128 // stream when the channel is closed. 129 func WriteLoop( 130 ctx context.Context, 131 s Stream, msgCh <-chan Message, 132 wg *sync.WaitGroup, 133 p *CounterProbe, 134 ) { 135 dlog.Tracef(ctx, " %s %s, WriteLoop starting", s.Tag(), s.ID()) 136 go func() { 137 ctx, span := otel.GetTracerProvider().Tracer("").Start(ctx, "WriteLoop") 138 defer span.End() 139 s.ID().SpanRecord(span) 140 var endReason string 141 defer func() { 142 dlog.Tracef(ctx, " %s %s, WriteLoop ended: %s", s.Tag(), s.ID(), endReason) 143 if err := s.CloseSend(ctx); err != nil { 144 dlog.Errorf(ctx, "!! %s %s, Send of closeSend failed: %v", s.Tag(), s.ID(), err) 145 } 146 wg.Done() 147 }() 148 for { 149 select { 150 case <-ctx.Done(): 151 endReason = ctx.Err().Error() 152 case m, ok := <-msgCh: 153 if !ok { 154 endReason = "input channel is closed" 155 break 156 } 157 158 err := s.Send(ctx, m) 159 if m != nil && p != nil { 160 p.Increment(uint64(len(m.Payload()))) 161 } 162 163 switch { 164 case err == nil: 165 continue 166 case errors.Is(err, net.ErrClosed): 167 endReason = "output stream is closed" 168 default: 169 endReason = err.Error() 170 dlog.Errorf(ctx, "!! %s %s, Send failed: %v", s.Tag(), s.ID(), err) 171 } 172 } 173 break 174 } 175 }() 176 } 177 178 type stream struct { 179 grpcStream GRPCStream 180 id ConnID 181 dialTimeout time.Duration 182 roundtripLatency time.Duration 183 sessionID string 184 tag string 185 syncRatio uint32 // send and check sync after each syncRatio message 186 ackWindow uint32 // maximum permitted difference between sent and received ack 187 peerVersion uint16 188 } 189 190 func newStream(tag string, grpcStream GRPCStream) stream { 191 return stream{tag: tag, grpcStream: grpcStream, syncRatio: 8, ackWindow: 1} 192 } 193 194 func (s *stream) Tag() string { 195 return s.tag 196 } 197 198 func (s *stream) ID() ConnID { 199 return s.id 200 } 201 202 func (s *stream) PeerVersion() uint16 { 203 return s.peerVersion 204 } 205 206 func (s *stream) DialTimeout() time.Duration { 207 return s.dialTimeout 208 } 209 210 func (s *stream) RoundtripLatency() time.Duration { 211 return s.roundtripLatency 212 } 213 214 func (s *stream) SessionID() string { 215 return s.sessionID 216 } 217 218 func (s *stream) Receive(ctx context.Context) (Message, error) { 219 cm, err := s.grpcStream.Recv() 220 if err != nil { 221 return nil, err 222 } 223 m := msg(cm.Payload) 224 switch m.Code() { 225 case closeSend: 226 dlog.Tracef(ctx, "<- %s %s, close send", s.tag, s.id) 227 return nil, net.ErrClosed 228 case streamInfo: 229 dlog.Tracef(ctx, "<- %s, %s", s.tag, m) 230 default: 231 dlog.Tracef(ctx, "<- %s %s, %s", s.tag, s.id, m) 232 } 233 return m, nil 234 } 235 236 func (s *stream) Send(ctx context.Context, m Message) error { 237 if err := s.grpcStream.Send(m.TunnelMessage()); err != nil { 238 if ctx.Err() == nil && !errors.Is(err, net.ErrClosed) { 239 dlog.Errorf(ctx, "!! %s %s, Send failed: %v", s.tag, s.id, err) 240 } 241 return err 242 } 243 dlog.Tracef(ctx, "-> %s %s, %s", s.tag, s.id, m) 244 return nil 245 } 246 247 func (s *stream) CloseSend(ctx context.Context) error { 248 if err := s.Send(ctx, NewMessage(closeSend, nil)); err != nil { 249 if ctx.Err() == nil && !(errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed)) { 250 return fmt.Errorf("send of closeSend message failed: %w", err) 251 } 252 } 253 return nil 254 }