tractor.dev/toolkit-go@v0.0.0-20241010005851-214d91207d07/duplex/mux/session.go (about) 1 package mux 2 3 import ( 4 "context" 5 "fmt" 6 "io" 7 "net" 8 "sync" 9 "time" 10 11 "tractor.dev/toolkit-go/duplex/mux/frame" 12 ) 13 14 const ( 15 minPacketLength = 9 16 maxPacketLength = 1 << 31 17 18 // channelMaxPacket contains the maximum number of bytes that will be 19 // sent in a single packet. 20 channelMaxPacket = 1 << 24 // ~16MB, arbitrary 21 // We follow OpenSSH here. 22 channelWindowSize = 64 * channelMaxPacket 23 24 // chanSize sets the amount of buffering qmux connections. This is 25 // primarily for testing: setting chanSize=0 uncovers deadlocks more 26 // quickly. 27 chanSize = 16 28 ) 29 30 var ( 31 // timeout for queuing a new channel to be `Accept`ed 32 // use a `var` so that this can be overridden in tests 33 openTimeout = 30 * time.Second 34 ) 35 36 // Session is a bi-directional channel muxing session on a given transport. 37 type Session interface { 38 io.Closer 39 Accept() (Channel, error) 40 Open(ctx context.Context) (Channel, error) 41 Wait() error 42 } 43 44 type session struct { 45 t io.ReadWriteCloser 46 chans chanList 47 48 enc *frame.Encoder 49 dec *frame.Decoder 50 51 inbox chan Channel 52 53 errCond *sync.Cond 54 err error 55 closeCh chan bool 56 } 57 58 // NewSession returns a session that runs over the given transport. 59 func New(t io.ReadWriteCloser) Session { 60 if t == nil { 61 return nil 62 } 63 s := &session{ 64 t: t, 65 enc: frame.NewEncoder(t), 66 dec: frame.NewDecoder(t), 67 inbox: make(chan Channel), 68 errCond: sync.NewCond(new(sync.Mutex)), 69 closeCh: make(chan bool, 1), 70 } 71 go s.loop() 72 return s 73 } 74 75 // Close closes the underlying transport. 76 func (s *session) Close() error { 77 s.t.Close() 78 return nil 79 } 80 81 // Wait blocks until the transport has shut down, and returns the 82 // error causing the shutdown. 83 func (s *session) Wait() error { 84 s.errCond.L.Lock() 85 defer s.errCond.L.Unlock() 86 for s.err == nil { 87 s.errCond.Wait() 88 } 89 return s.err 90 } 91 92 // Accept waits for and returns the next incoming channel. 93 func (s *session) Accept() (Channel, error) { 94 select { 95 case ch := <-s.inbox: 96 return ch, nil 97 case <-s.closeCh: 98 return nil, io.EOF 99 } 100 } 101 102 // Open establishes a new channel with the other end. 103 func (s *session) Open(ctx context.Context) (Channel, error) { 104 ch := s.newChannel(channelOutbound) 105 ch.maxIncomingPayload = channelMaxPacket 106 107 if err := s.enc.Encode(frame.OpenMessage{ 108 WindowSize: ch.myWindow, 109 MaxPacketSize: ch.maxIncomingPayload, 110 SenderID: ch.localId, 111 }); err != nil { 112 return nil, err 113 } 114 115 var m frame.Message 116 117 select { 118 case <-ctx.Done(): 119 return nil, ctx.Err() 120 case m = <-ch.msg: 121 if m == nil { 122 // channel was closed before open got a response, 123 // typically meaning the session/conn was closed. 124 return nil, net.ErrClosed 125 } 126 } 127 128 switch msg := m.(type) { 129 case *frame.OpenConfirmMessage: 130 return ch, nil 131 case *frame.OpenFailureMessage: 132 return nil, fmt.Errorf("qmux: channel open failed on remote side") 133 default: 134 return nil, fmt.Errorf("qmux: unexpected packet in response to channel open: %v", msg) 135 } 136 } 137 138 func (s *session) newChannel(direction channelDirection) *channel { 139 ch := &channel{ 140 remoteWin: window{Cond: sync.NewCond(new(sync.Mutex))}, 141 myWindow: channelWindowSize, 142 pending: newBuffer(), 143 direction: direction, 144 msg: make(chan frame.Message, chanSize), 145 session: s, 146 packetBuf: make([]byte, 0), 147 } 148 ch.localId = s.chans.add(ch) 149 return ch 150 } 151 152 // loop runs the connection machine. It will process packets until an 153 // error is encountered. To synchronize on loop exit, use session.Wait. 154 func (s *session) loop() { 155 var err error 156 for err == nil { 157 err = s.onePacket() 158 } 159 160 for _, ch := range s.chans.dropAll() { 161 ch.close() 162 } 163 164 s.t.Close() 165 s.closeCh <- true 166 167 s.errCond.L.Lock() 168 s.err = err 169 s.errCond.Broadcast() 170 s.errCond.L.Unlock() 171 } 172 173 // onePacket reads and processes one packet. 174 func (s *session) onePacket() error { 175 var err error 176 var msg frame.Message 177 178 msg, err = s.dec.Decode() 179 if err != nil { 180 return err 181 } 182 183 id, isChan := msg.Channel() 184 if !isChan { 185 return s.handleOpen(msg.(*frame.OpenMessage)) 186 } 187 188 ch := s.chans.getChan(id) 189 if ch == nil { 190 return fmt.Errorf("qmux: invalid channel %d", id) 191 } 192 193 return ch.handle(msg) 194 } 195 196 // handleChannelOpen schedules a channel to be Accept()ed. 197 func (s *session) handleOpen(msg *frame.OpenMessage) error { 198 if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > maxPacketLength { 199 return s.enc.Encode(frame.OpenFailureMessage{ 200 ChannelID: msg.SenderID, 201 }) 202 } 203 204 c := s.newChannel(channelInbound) 205 c.remoteId = msg.SenderID 206 c.maxRemotePayload = msg.MaxPacketSize 207 c.remoteWin.add(msg.WindowSize) 208 c.maxIncomingPayload = channelMaxPacket 209 t := time.NewTimer(openTimeout) 210 defer t.Stop() 211 select { 212 case s.inbox <- c: 213 return s.enc.Encode(frame.OpenConfirmMessage{ 214 ChannelID: c.remoteId, 215 SenderID: c.localId, 216 WindowSize: c.myWindow, 217 MaxPacketSize: c.maxIncomingPayload, 218 }) 219 case <-t.C: 220 return s.enc.Encode(frame.OpenFailureMessage{ 221 ChannelID: msg.SenderID, 222 }) 223 } 224 }