// Source: github.com/telepresenceio/telepresence/v2@v2.20.0-pro.6.0.20240517030216-236ea954e789/pkg/tunnel/dialer.go

package tunnel

import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"sync"
	"sync/atomic"
	"time"

	"go.opentelemetry.io/otel"
	"go.opentelemetry.io/otel/codes"
	"go.opentelemetry.io/otel/propagation"

	"github.com/datawire/dlib/dlog"
	rpc "github.com/telepresenceio/telepresence/rpc/v2/manager"
	"github.com/telepresenceio/telepresence/v2/pkg/ipproto"
)

// tcpConnTTL and udpConnTTL are the default idle times that NewConnEndpoint selects, based on the protocol
// of the stream's ID. They control how long a dialer for a specific proto+from-to address combination remains
// alive without reading or writing any messages. The dialer is normally closed by one of the peers.
const (
	tcpConnTTL = 2 * time.Hour // Default tcp_keepalive_time on Linux
	udpConnTTL = 1 * time.Minute
)

// Values for the dialer's connected field.
const (
	// notConnected is the initial state of a dialer created without a connection, and the state
	// that startDisconnect moves a connected dialer to.
	notConnected = int32(iota)

	// connecting is the initial state of a dialer created with a connection, and the state
	// that Start uses while it is dialing.
	connecting

	// connected is set by Start just before the two message loops are started.
	connected
)

// streamReader is implemented by the dialer and udpListener so that they can share the
// readLoop function.
type streamReader interface {
	// Idle returns a channel that readLoop selects on. The loop ends with the reason
	// "it was idle for too long" when a value arrives on it.
	Idle() <-chan time.Time

	// ResetIdle is called each time a message is processed. The loops treat a false
	// return value as an indication that the reader has been idle for too long.
	ResetIdle() bool

	// Stop is called by handleControl when the peer sends a DialReject or Disconnect.
	Stop(context.Context)

	// getStream returns the stream that readLoop reads its messages from.
	getStream() Stream

	// reply writes the given data and returns the number of bytes written. The readLoop
	// calls it repeatedly until the full payload of a message has been written.
	reply([]byte) (int, error)

	// startDisconnect is called with the reason for ending when readLoop returns.
	startDisconnect(context.Context, string)
}

// The dialer takes care of dispatching messages between a gRPC stream and a TCP or UDP connection.
type dialer struct {
	TimedHandler
	stream Stream             // the stream that messages are read from and written to
	cancel context.CancelFunc // called by Stop
	conn   net.Conn           // nil until Start has dialed, unless a connection was given at creation
	// connected holds one of notConnected, connecting, or connected.
	connected int32
	done      chan struct{} // closed when the goroutine spawned by Start ends

	ingressBytesProbe *CounterProbe // passed to readLoop (stream-to-conn direction)
	egressBytesProbe  *CounterProbe // passed to WriteLoop (conn-to-stream direction)
}

// NewDialer creates a new handler that dispatches messages in both directions between the given gRPC stream
// and a connection that the handler dials when it is started. The idle TTL is chosen from the stream's protocol.
61 func NewDialer( 62 stream Stream, 63 cancel context.CancelFunc, 64 ingressBytesProbe, egressBytesProbe *CounterProbe, 65 ) Endpoint { 66 return NewConnEndpoint(stream, nil, cancel, ingressBytesProbe, egressBytesProbe) 67 } 68 69 // NewDialerTTL creates a new handler that dispatches messages in both directions between the given gRPC stream 70 // and the given connection. The TTL decides how long the connection can be idle before it's closed. 71 // 72 // The handler remains active until it's been idle for the ttl duration, at which time it will automatically close 73 // and call the release function it got from the tunnel.Pool to ensure that it gets properly released. 74 func NewDialerTTL(stream Stream, cancel context.CancelFunc, ttl time.Duration, ingressBytesProbe, egressBytesProbe *CounterProbe) Endpoint { 75 return NewConnEndpointTTL(stream, nil, cancel, ttl, ingressBytesProbe, egressBytesProbe) 76 } 77 78 func NewConnEndpoint(stream Stream, conn net.Conn, cancel context.CancelFunc, ingressBytesProbe, egressBytesProbe *CounterProbe) Endpoint { 79 ttl := tcpConnTTL 80 if stream.ID().Protocol() == ipproto.UDP { 81 ttl = udpConnTTL 82 } 83 return NewConnEndpointTTL(stream, conn, cancel, ttl, ingressBytesProbe, egressBytesProbe) 84 } 85 86 func NewConnEndpointTTL( 87 stream Stream, 88 conn net.Conn, 89 cancel context.CancelFunc, 90 ttl time.Duration, 91 ingressBytesProbe, egressBytesProbe *CounterProbe, 92 ) Endpoint { 93 state := notConnected 94 if conn != nil { 95 state = connecting 96 } 97 return &dialer{ 98 TimedHandler: NewTimedHandler(stream.ID(), ttl, nil), 99 stream: stream, 100 cancel: cancel, 101 conn: conn, 102 connected: state, 103 done: make(chan struct{}), 104 105 ingressBytesProbe: ingressBytesProbe, 106 egressBytesProbe: egressBytesProbe, 107 } 108 } 109 110 func (h *dialer) Start(ctx context.Context) { 111 go func() { 112 ctx, span := otel.Tracer("").Start(ctx, "dialer") 113 defer span.End() 114 defer close(h.done) 115 116 id := h.stream.ID() 117 
id.SpanRecord(span) 118 119 switch h.connected { 120 case notConnected: 121 // Set up the idle timer to close and release this handler when it's been idle for a while. 122 h.connected = connecting 123 124 dlog.Tracef(ctx, " CONN %s, dialing", id) 125 d := net.Dialer{Timeout: h.stream.DialTimeout()} 126 conn, err := d.DialContext(ctx, id.DestinationProtocolString(), id.DestinationAddr().String()) 127 if err != nil { 128 dlog.Errorf(ctx, "!! CONN %s, failed to establish connection: %v", id, err) 129 span.SetStatus(codes.Error, err.Error()) 130 if err = h.stream.Send(ctx, NewMessage(DialReject, nil)); err != nil { 131 dlog.Errorf(ctx, "!! CONN %s, failed to send DialReject: %v", id, err) 132 } 133 if err = h.stream.CloseSend(ctx); err != nil { 134 dlog.Errorf(ctx, "!! CONN %s, stream.CloseSend failed: %v", id, err) 135 } 136 h.connected = notConnected 137 return 138 } 139 if err = h.stream.Send(ctx, NewMessage(DialOK, nil)); err != nil { 140 _ = conn.Close() 141 dlog.Errorf(ctx, "!! CONN %s, failed to send DialOK: %v", id, err) 142 span.SetStatus(codes.Error, err.Error()) 143 return 144 } 145 dlog.Tracef(ctx, " CONN %s, dial answered", id) 146 h.conn = conn 147 148 case connecting: 149 default: 150 dlog.Errorf(ctx, "!! CONN %s, start called in invalid state", id) 151 return 152 } 153 154 // Set up the idle timer to close and release this endpoint when it's been idle for a while. 155 h.TimedHandler.Start(ctx) 156 h.connected = connected 157 158 wg := sync.WaitGroup{} 159 wg.Add(2) 160 go h.connToStreamLoop(ctx, &wg) 161 go h.streamToConnLoop(ctx, &wg) 162 wg.Wait() 163 h.Stop(ctx) 164 }() 165 } 166 167 func (h *dialer) Done() <-chan struct{} { 168 return h.done 169 } 170 171 // Stop will close the underlying TCP/UDP connection. 
172 func (h *dialer) Stop(ctx context.Context) { 173 h.startDisconnect(ctx, "explicit close") 174 h.cancel() 175 } 176 177 func (h *dialer) startDisconnect(ctx context.Context, reason string) { 178 if !atomic.CompareAndSwapInt32(&h.connected, connected, notConnected) { 179 return 180 } 181 id := h.stream.ID() 182 dlog.Tracef(ctx, " CONN %s closing connection: %s", id, reason) 183 if err := h.conn.Close(); err != nil { 184 dlog.Tracef(ctx, "!! CONN %s, Close failed: %v", id, err) 185 } 186 } 187 188 func (h *dialer) connToStreamLoop(ctx context.Context, wg *sync.WaitGroup) { 189 var endReason string 190 endLevel := dlog.LogLevelTrace 191 id := h.stream.ID() 192 193 outgoing := make(chan Message, 50) 194 defer func() { 195 if !h.ResetIdle() { 196 // Hard close of peer. We don't want any more data 197 select { 198 case outgoing <- NewMessage(Disconnect, nil): 199 default: 200 } 201 } 202 close(outgoing) 203 dlog.Logf(ctx, endLevel, " CONN %s conn-to-stream loop ended because %s", id, endReason) 204 wg.Done() 205 }() 206 207 wg.Add(1) 208 WriteLoop(ctx, h.stream, outgoing, wg, h.egressBytesProbe) 209 210 buf := make([]byte, 0x100000) 211 dlog.Tracef(ctx, " CONN %s conn-to-stream loop started", id) 212 for { 213 n, err := h.conn.Read(buf) 214 if n > 0 { 215 dlog.Tracef(ctx, "<- CONN %s, len %d", id, n) 216 select { 217 case <-ctx.Done(): 218 endReason = ctx.Err().Error() 219 return 220 case outgoing <- NewMessage(Normal, buf[:n]): 221 } 222 } 223 224 if err != nil { 225 switch { 226 case errors.Is(err, io.EOF): 227 endReason = "EOF was encountered" 228 case errors.Is(err, net.ErrClosed): 229 endReason = "the connection was closed" 230 h.startDisconnect(ctx, endReason) 231 default: 232 endReason = fmt.Sprintf("a read error occurred: %v", err) 233 endLevel = dlog.LogLevelError 234 } 235 return 236 } 237 238 if !h.ResetIdle() { 239 endReason = "it was idle for too long" 240 return 241 } 242 } 243 } 244 245 func (h *dialer) getStream() Stream { 246 return h.stream 247 } 248 
249 func (h *dialer) reply(data []byte) (int, error) { 250 return h.conn.Write(data) 251 } 252 253 func (h *dialer) streamToConnLoop(ctx context.Context, wg *sync.WaitGroup) { 254 defer func() { 255 wg.Done() 256 }() 257 readLoop(ctx, h, h.ingressBytesProbe) 258 } 259 260 func handleControl(ctx context.Context, h streamReader, cm Message) { 261 switch cm.Code() { 262 case DialReject, Disconnect: // Peer wants to hard-close. No more messages will arrive 263 h.Stop(ctx) 264 case KeepAlive: 265 h.ResetIdle() 266 case DialOK: 267 // So how can a dialer get a DialOK from a peer? Surely, there cannot be a dialer at both ends? 268 // Well, the story goes like this: 269 // 1. A request to the service is made on the workstation. 270 // 2. This agent's listener receives a connection. 271 // 3. Since an intercept is active, the agent creates a tunnel to the workstation 272 // 4. A new dialer is attached to that tunnel (reused as a tunnel endpoint) 273 // 5. The dialer at the workstation dials and responds with DialOK, and here we are. 274 default: 275 dlog.Errorf(ctx, "!! CONN %s: unhandled connection control message: %s", h.getStream().ID(), cm) 276 } 277 } 278 279 func readLoop(ctx context.Context, h streamReader, trafficProbe *CounterProbe) { 280 var endReason string 281 endLevel := dlog.LogLevelTrace 282 id := h.getStream().ID() 283 defer func() { 284 h.startDisconnect(ctx, endReason) 285 dlog.Logf(ctx, endLevel, " CONN %s stream-to-conn loop ended because %s", id, endReason) 286 }() 287 288 incoming, errCh := ReadLoop(ctx, h.getStream(), trafficProbe) 289 dlog.Tracef(ctx, " CONN %s stream-to-conn loop started", id) 290 for { 291 select { 292 case <-ctx.Done(): 293 endReason = ctx.Err().Error() 294 return 295 case <-h.Idle(): 296 endReason = "it was idle for too long" 297 return 298 case err, ok := <-errCh: 299 if ok { 300 dlog.Error(ctx, err) 301 } 302 case dg, ok := <-incoming: 303 if !ok { 304 // h.incoming was closed by the reader and is now drained. 
305 endReason = "there was no more input" 306 return 307 } 308 if !h.ResetIdle() { 309 endReason = "it was idle for too long" 310 return 311 } 312 if dg.Code() != Normal { 313 handleControl(ctx, h, dg) 314 continue 315 } 316 payload := dg.Payload() 317 pn := len(payload) 318 for n := 0; n < pn; { 319 wn, err := h.reply(payload[n:]) 320 if err != nil { 321 endReason = fmt.Sprintf("a write error occurred: %v", err) 322 endLevel = dlog.LogLevelError 323 return 324 } 325 dlog.Tracef(ctx, "-> CONN %s, len %d", id, wn) 326 n += wn 327 } 328 } 329 } 330 } 331 332 // DialWaitLoop reads from the given dialStream. A new goroutine that creates a Tunnel to the manager and then 333 // attaches a dialer Endpoint to that tunnel is spawned for each request that arrives. The method blocks until 334 // the dialStream is closed. 335 func DialWaitLoop( 336 ctx context.Context, 337 tunnelProvider Provider, 338 dialStream rpc.Manager_WatchDialClient, 339 sessionID string, 340 ) error { 341 // create ctx to cleanup leftover dialRespond if waitloop dies 342 ctx, cancel := context.WithCancel(ctx) 343 defer cancel() 344 for ctx.Err() == nil { 345 dr, err := dialStream.Recv() 346 if err != nil { 347 if ctx.Err() == nil && !(errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed)) { 348 return fmt.Errorf("dial request stream recv: %w", err) // May be io.EOF 349 } 350 return nil 351 } 352 go dialRespond(ctx, tunnelProvider, dr, sessionID) 353 } 354 return nil 355 } 356 357 func dialRespond(ctx context.Context, tunnelProvider Provider, dr *rpc.DialRequest, sessionID string) { 358 if tc := dr.GetTraceContext(); tc != nil { 359 carrier := propagation.MapCarrier(tc) 360 propagator := otel.GetTextMapPropagator() 361 ctx = propagator.Extract(ctx, carrier) 362 } 363 ctx, span := otel.Tracer("").Start(ctx, "dialRespond") 364 defer span.End() 365 id := ConnID(dr.ConnId) 366 id.SpanRecord(span) 367 mt, err := tunnelProvider.Tunnel(ctx) 368 if err != nil { 369 dlog.Errorf(ctx, "!! 
CONN %s, call to manager Tunnel failed: %v", id, err) 370 return 371 } 372 ctx, cancel := context.WithCancel(ctx) 373 s, err := NewClientStream(ctx, mt, id, sessionID, time.Duration(dr.RoundtripLatency), time.Duration(dr.DialTimeout)) 374 if err != nil { 375 dlog.Error(ctx, err) 376 cancel() 377 return 378 } 379 d := NewDialer(s, cancel, nil, nil) 380 d.Start(ctx) 381 <-d.Done() 382 }