tractor.dev/toolkit-go@v0.0.0-20241010005851-214d91207d07/duplex/mux/channel.go (about) 1 package mux 2 3 import ( 4 "errors" 5 "fmt" 6 "io" 7 "sync" 8 9 "tractor.dev/toolkit-go/duplex/mux/frame" 10 ) 11 12 type channelDirection uint8 13 14 const ( 15 channelInbound channelDirection = iota 16 channelOutbound 17 ) 18 19 func min(a uint32, b int) uint32 { 20 if a < uint32(b) { 21 return a 22 } 23 return uint32(b) 24 } 25 26 type Channel interface { 27 io.ReadWriteCloser 28 ID() uint32 29 CloseWrite() error 30 } 31 32 // channel is an implementation of the Channel interface that works 33 // with the session class. 34 type channel struct { 35 36 // R/O after creation 37 localId, remoteId uint32 38 39 // maxIncomingPayload and maxRemotePayload are the maximum 40 // payload sizes of normal and extended data packets for 41 // receiving and sending, respectively. The wire packet will 42 // be 9 or 13 bytes larger (excluding encryption overhead). 43 maxIncomingPayload uint32 44 maxRemotePayload uint32 45 46 session *session 47 48 // direction contains either channelOutbound, for channels created 49 // locally, or channelInbound, for channels created by the peer. 50 direction channelDirection 51 52 // Pending internal channel messages. 53 msg chan frame.Message 54 55 sentEOF bool 56 57 // thread-safe data 58 remoteWin window 59 pending *buffer 60 61 // windowMu protects myWindow, the flow-control window. 62 windowMu sync.Mutex 63 myWindow uint32 64 65 // writeMu serializes calls to session.conn.Write() and 66 // protects sentClose and packetPool. This mutex must be 67 // different from windowMu, as writePacket can block if there 68 // is a key exchange pending. 69 writeMu sync.Mutex 70 sentClose bool 71 72 // packet buffer for writing 73 packetBuf []byte 74 } 75 76 // ID returns the unique identifier of this channel 77 // within the session 78 func (ch *channel) ID() uint32 { 79 return ch.localId 80 } 81 82 // CloseWrite signals the end of sending data. 83 // The other side may still send data 84 func (ch *channel) CloseWrite() error { 85 ch.sentEOF = true 86 return ch.send(frame.EOFMessage{ 87 ChannelID: ch.remoteId}) 88 } 89 90 // Close signals end of channel use. No data may be sent after this 91 // call. 92 func (ch *channel) Close() error { 93 return ch.send(frame.CloseMessage{ 94 ChannelID: ch.remoteId}) 95 } 96 97 // Write writes len(data) bytes to the channel. 98 func (ch *channel) Write(data []byte) (n int, err error) { 99 if ch.sentEOF { 100 return 0, io.EOF 101 } 102 103 for len(data) > 0 { 104 space := min(ch.maxRemotePayload, len(data)) 105 if space, err = ch.remoteWin.reserve(space); err != nil { 106 return n, err 107 } 108 109 toSend := data[:space] 110 111 if err = ch.session.enc.Encode(frame.DataMessage{ 112 ChannelID: ch.remoteId, 113 Length: uint32(len(toSend)), 114 Data: toSend, 115 }); err != nil { 116 return n, err 117 } 118 119 n += len(toSend) 120 data = data[len(toSend):] 121 } 122 123 return n, err 124 } 125 126 // Read reads up to len(data) bytes from the channel. 127 func (c *channel) Read(data []byte) (n int, err error) { 128 n, err = c.pending.Read(data) 129 130 if n > 0 { 131 err = c.adjustWindow(uint32(n)) 132 // sendWindowAdjust can return io.EOF if the remote 133 // peer has closed the connection, however we want to 134 // defer forwarding io.EOF to the caller of Read until 135 // the buffer has been drained. 136 if n > 0 && err == io.EOF { 137 err = nil 138 } 139 } 140 return n, err 141 } 142 143 // sends writes a message frame. If the message is a channel close, it updates 144 // sentClose. This method takes the lock c.writeMu. 145 func (ch *channel) send(msg frame.Message) error { 146 ch.writeMu.Lock() 147 defer ch.writeMu.Unlock() 148 149 if ch.sentClose { 150 return io.EOF 151 } 152 153 if _, ok := msg.(frame.CloseMessage); ok { 154 ch.sentClose = true 155 } 156 157 return ch.session.enc.Encode(msg) 158 } 159 160 func (c *channel) adjustWindow(n uint32) error { 161 c.windowMu.Lock() 162 // Since myWindow is managed on our side, and can never exceed 163 // the initial window setting, we don't worry about overflow. 164 c.myWindow += uint32(n) 165 c.windowMu.Unlock() 166 return c.send(frame.WindowAdjustMessage{ 167 ChannelID: c.remoteId, 168 AdditionalBytes: uint32(n), 169 }) 170 } 171 172 func (c *channel) close() { 173 c.pending.eof() 174 close(c.msg) 175 c.writeMu.Lock() 176 // This is not necessary for a normal channel teardown, but if 177 // there was another error, it is. 178 c.sentClose = true 179 c.writeMu.Unlock() 180 // Unblock writers. 181 c.remoteWin.close() 182 } 183 184 // responseMessageReceived is called when a success or failure message is 185 // received on a channel to check that such a message is reasonable for the 186 // given channel. 187 func (ch *channel) responseMessageReceived() error { 188 if ch.direction == channelInbound { 189 return errors.New("qmux: channel response message received on inbound channel") 190 } 191 return nil 192 } 193 194 func (ch *channel) handle(msg frame.Message) error { 195 switch m := msg.(type) { 196 case *frame.DataMessage: 197 return ch.handleData(m) 198 199 case *frame.CloseMessage: 200 ch.send(frame.CloseMessage{ 201 ChannelID: ch.remoteId, 202 }) 203 ch.session.chans.remove(ch.localId) 204 ch.close() 205 return nil 206 207 case *frame.EOFMessage: 208 ch.pending.eof() 209 return nil 210 211 case *frame.WindowAdjustMessage: 212 if !ch.remoteWin.add(m.AdditionalBytes) { 213 return fmt.Errorf("qmux: invalid window update for %d bytes", m.AdditionalBytes) 214 } 215 return nil 216 217 case *frame.OpenConfirmMessage: 218 if err := ch.responseMessageReceived(); err != nil { 219 return err 220 } 221 if m.MaxPacketSize < minPacketLength || m.MaxPacketSize > maxPacketLength { 222 return fmt.Errorf("qmux: invalid MaxPacketSize %d from peer", m.MaxPacketSize) 223 } 224 ch.remoteId = m.SenderID 225 ch.maxRemotePayload = m.MaxPacketSize 226 ch.remoteWin.add(m.WindowSize) 227 ch.msg <- m 228 return nil 229 230 case *frame.OpenFailureMessage: 231 if err := ch.responseMessageReceived(); err != nil { 232 return err 233 } 234 ch.session.chans.remove(m.ChannelID) 235 ch.msg <- m 236 return nil 237 238 default: 239 return fmt.Errorf("qmux: invalid channel message %v", msg) 240 } 241 } 242 243 func (ch *channel) handleData(msg *frame.DataMessage) error { 244 if msg.Length > ch.maxIncomingPayload { 245 // TODO(hanwen): should send Disconnect? 246 return errors.New("qmux: incoming packet exceeds maximum payload size") 247 } 248 249 if msg.Length != uint32(len(msg.Data)) { 250 return errors.New("qmux: wrong packet length") 251 } 252 253 ch.windowMu.Lock() 254 if ch.myWindow < msg.Length { 255 ch.windowMu.Unlock() 256 // TODO(hanwen): should send Disconnect with reason? 257 return errors.New("qmux: remote side wrote too much") 258 } 259 ch.myWindow -= msg.Length 260 ch.windowMu.Unlock() 261 262 ch.pending.write(msg.Data) 263 return nil 264 }