github.com/ice-blockchain/go/src@v0.0.0-20240403114104-1564d284e521/net/http/h2_wt_handshake.go (about) 1 // SPDX-License-Identifier: ice License 1.0 2 3 package http 4 5 import ( 6 "bufio" 7 "context" 8 "io" 9 "net/http" 10 "strconv" 11 stdlibtime "time" 12 13 "github.com/hashicorp/go-multierror" 14 "github.com/pkg/errors" 15 "github.com/quic-go/quic-go" 16 "github.com/quic-go/quic-go/http3" 17 "github.com/quic-go/quic-go/quicvarint" 18 "github.com/quic-go/webtransport-go" 19 ) 20 21 const ( 22 wtCapsuleResetStream = 0x190B4D39 23 wtCapsuleStopSending = 0x190B4D3A 24 wtCapsuleStream = 0x190B4D3B 25 wtCapsuleStreamFin = 0x190B4D3C 26 wtCapsuleMaxData = 0x190B4D3D 27 wtCapsuleMaxStreamData = 0x190B4D3E 28 wtCapsuleMaxStreams = 0x190B4D3F 29 wtCapsuleMaxStreamsUni = 0x190B4D40 30 wtCapsuleCloseWebTransportSession = 0x2843 31 wtCapsuleDrainWebTransportSession = 0x78ae 32 ) 33 34 type ( 35 Session interface { 36 AcceptStream(ctx context.Context) webtransport.Stream 37 } 38 WebTransportUpgrader interface { 39 UpgradeWebTransport() (Session, error) 40 } 41 webtransportStream struct { 42 rw *http2responseWriter 43 streamReceivedCh chan io.Reader 44 readFinished chan struct{} 45 closedCh chan struct{} 46 reader io.Reader 47 streamReceived bool 48 finReceived bool 49 } 50 wtMaxData struct { 51 MaxData uint64 52 } 53 wtMaxStreams struct { 54 MaxStreams uint64 55 } 56 wtMaxStreamData struct { 57 StreamID uint64 58 MaxData uint64 59 } 60 wtStream struct { 61 ws *webtransportStream 62 StreamID uint32 63 StreamData []byte 64 } 65 ) 66 67 func (rw *http2responseWriter) UpgradeWebTransport() (Session, error) { 68 if !(rw.rws.req.Method == http.MethodConnect && rw.rws.req.Proto == "webtransport") { 69 rw.WriteHeader(400) 70 return nil, errors.New("invalid protocol") 71 } 72 rw.Header().Add(headerCapsuleProtocol, strconv.FormatBool(true)) 73 74 rw.WriteHeader(http.StatusOK) 75 wts := &webtransportStream{ 76 rw: rw, 77 streamReceivedCh: make(chan io.Reader, 1), 78 readFinished: make(chan struct{}, 1), 79 closedCh: make(chan struct{}), 80 } 81 rw.rws.conn.webtransportSessions.Store(rw.rws.stream.id, wts) 82 83 return rw, nil 84 } 85 86 func (rw *http2responseWriter) AcceptStream(ctx context.Context) webtransport.Stream { 87 var stream *webtransportStream 88 if s, ok := rw.rws.conn.webtransportSessions.Load(rw.rws.stream.id); ok { 89 stream = s.(*webtransportStream) 90 } 91 go stream.handleWebTransportStream() 92 return stream 93 } 94 95 func (s *webtransportStream) handleWebTransportStream() { 96 defer func() { close(s.closedCh) }() 97 for { 98 if s.finReceived || s.rw.rws == nil || s.rw.rws.handlerDone || s.rw.rws.stream.state == http2stateClosed { 99 return 100 } 101 if s.streamReceived { 102 select { 103 case <-s.readFinished: 104 } 105 } 106 cType, data, err := http3.ParseCapsule(quicvarint.NewReader(s.rw.rws.req.Body)) 107 cData := bufio.NewReader(data) 108 if err != nil { 109 if !errors.Is(err, http2errClientDisconnected) { 110 if s.rw.rws != nil { 111 s.rw.rws.conn.logf("failed to parse capsule error (http2/wt/rfc9297) %v", err) 112 } 113 } 114 break 115 } 116 if cType > 0 { 117 //log.Printf(fmt.Sprintf("http2/wt: Received capsule 0x%v", strconv.FormatUint(uint64(cType), 16))) 118 switch cType { 119 case wtCapsuleMaxStreamData: 120 md := new(wtMaxStreamData) 121 if err = md.Deserialize(cData); err == nil { 122 s.rw.rws.stream.sc.sendWindowUpdate(s.rw.rws.stream, int(md.MaxData)) 123 s.rw.Flush() 124 } 125 case wtCapsuleMaxData: 126 md := new(wtMaxData) 127 if err = md.Deserialize(cData); err == nil { 128 s.rw.rws.stream.sc.sendWindowUpdate(s.rw.rws.stream, int(md.MaxData)) 129 s.rw.Flush() 130 } 131 case wtCapsuleMaxStreams: 132 ms := new(wtMaxStreams) 133 if err = ms.Deserialize(cData); err == nil { 134 s.rw.rws.stream.sc.advMaxStreams = uint32(ms.MaxStreams) 135 } 136 case wtCapsuleStreamFin: 137 case wtCapsuleStream: 138 s.finReceived = cType == wtCapsuleStreamFin 139 str := &wtStream{ws: s} 140 err = str.Deserialize(cData) 141 case wtCapsuleResetStream: 142 s.rw.rws.stream.endStream() 143 case wtCapsuleStopSending: 144 s.rw.handlerDone() 145 return 146 case wtCapsuleDrainWebTransportSession: 147 s.rw.rws.conn.startGracefulShutdown() 148 case wtCapsuleCloseWebTransportSession: 149 s.rw.handlerDone() 150 return 151 default: 152 _, err = io.ReadAll(cData) 153 } 154 155 if err != nil { 156 if s.rw.rws != nil { 157 s.rw.rws.conn.logf("failed to process capsule (http2/wt/rfc9297): %v", err) 158 } 159 } 160 } 161 } 162 } 163 164 func (s *wtStream) Deserialize(dataReader quicvarint.Reader) (err error) { 165 var sID uint64 166 sID, err = quicvarint.Read(dataReader) 167 if err != nil { 168 err = errors.Wrapf(err, "failed to parse WT_STREAM/StreamID") 169 return err 170 } 171 s.StreamID = uint32(sID) 172 s.ws.streamReceivedCh <- dataReader 173 return errors.Wrapf(err, "failed to copy content from WT_STREAM") 174 } 175 func (s *wtStream) Serialize() []byte { 176 b := make([]byte, 0, 4+len(s.StreamData)) 177 b = quicvarint.Append(b, uint64(s.StreamID)) 178 b = append(b, s.StreamData...) 179 180 return b 181 } 182 183 func (s *webtransportStream) Write(p []byte) (n int, err error) { 184 err = s.rw.WriteCapsule(wtCapsuleStream, &wtStream{StreamData: p, StreamID: s.rw.rws.stream.id}) 185 186 return len(p), err 187 } 188 189 func (s *webtransportStream) Close() error { 190 if s.rw.rws != nil { 191 s.rw.handlerDone() 192 } 193 return nil 194 } 195 196 func (s *webtransportStream) StreamID() quic.StreamID { 197 return quic.StreamID(s.rw.rws.stream.id) 198 } 199 200 func (s *webtransportStream) CancelWrite(code webtransport.StreamErrorCode) { 201 202 } 203 204 func (s *webtransportStream) SetWriteDeadline(t stdlibtime.Time) error { 205 return nil 206 } 207 208 func (s *webtransportStream) Read(p []byte) (n int, err error) { 209 if s.finReceived || s.rw.rws.handlerDone || s.rw.rws.stream.state == http2stateClosed { 210 return 0, io.EOF 211 } 212 var r io.Reader 213 if s.reader == nil { 214 select { 215 case r = <-s.streamReceivedCh: 216 s.reader = r 217 case <-s.closedCh: 218 return 0, io.EOF 219 case <-s.rw.CloseNotify(): 220 return 0, io.EOF 221 } 222 } 223 if r != s.reader { 224 s.reader = nil 225 return 0, nil 226 } 227 n, err = s.reader.Read(p) 228 if errors.Is(err, io.EOF) { 229 s.reader = nil 230 s.readFinished <- struct{}{} 231 err = nil 232 } 233 return n, err 234 } 235 236 func (s *webtransportStream) CancelRead(code webtransport.StreamErrorCode) { 237 } 238 239 func (s *webtransportStream) SetReadDeadline(t stdlibtime.Time) error { 240 return nil 241 } 242 243 func (s *webtransportStream) SetDeadline(t stdlibtime.Time) error { 244 return multierror.Append( 245 s.rw.SetReadDeadline(t), 246 s.rw.SetWriteDeadline(t), 247 ).ErrorOrNil() 248 } 249 250 func (md *wtMaxData) Deserialize(dataReader quicvarint.Reader) (err error) { 251 if md.MaxData, err = quicvarint.Read(dataReader); err != nil { 252 err = errors.Wrapf(err, "failed to parse WT_MAX_DATA/MaxData") 253 } 254 return err 255 } 256 func (ms *wtMaxStreams) Deserialize(dataReader quicvarint.Reader) (err error) { 257 if ms.MaxStreams, err = quicvarint.Read(dataReader); err != nil { 258 err = errors.Wrapf(err, "failed to parse WT_MAX_STREAMS/Maximum Streams") 259 } 260 return err 261 } 262 func (md *wtMaxStreamData) Deserialize(dataReader quicvarint.Reader) (err error) { 263 if md.StreamID, err = quicvarint.Read(dataReader); err != nil { 264 err = errors.Wrapf(err, "failed to parse WT_MAX_STREAM_DATA/StreamID") 265 return err 266 } 267 if md.MaxData, err = quicvarint.Read(dataReader); err != nil { 268 return errors.Wrapf(err, "failed to parse WT_MAX_STREAM_DATA/MaxData") 269 } 270 return err 271 }