github.com/ice-blockchain/go/src@v0.0.0-20240403114104-1564d284e521/net/http/h2_wt_handshake.go (about)

     1  // SPDX-License-Identifier: ice License 1.0
     2  
     3  package http
     4  
     5  import (
     6  	"bufio"
     7  	"context"
     8  	"io"
     9  	"net/http"
    10  	"strconv"
    11  	stdlibtime "time"
    12  
    13  	"github.com/hashicorp/go-multierror"
    14  	"github.com/pkg/errors"
    15  	"github.com/quic-go/quic-go"
    16  	"github.com/quic-go/quic-go/http3"
    17  	"github.com/quic-go/quic-go/quicvarint"
    18  	"github.com/quic-go/webtransport-go"
    19  )
    20  
    21  const (
    22  	wtCapsuleResetStream              = 0x190B4D39
    23  	wtCapsuleStopSending              = 0x190B4D3A
    24  	wtCapsuleStream                   = 0x190B4D3B
    25  	wtCapsuleStreamFin                = 0x190B4D3C
    26  	wtCapsuleMaxData                  = 0x190B4D3D
    27  	wtCapsuleMaxStreamData            = 0x190B4D3E
    28  	wtCapsuleMaxStreams               = 0x190B4D3F
    29  	wtCapsuleMaxStreamsUni            = 0x190B4D40
    30  	wtCapsuleCloseWebTransportSession = 0x2843
    31  	wtCapsuleDrainWebTransportSession = 0x78ae
    32  )
    33  
    34  type (
    35  	Session interface {
    36  		AcceptStream(ctx context.Context) webtransport.Stream
    37  	}
    38  	WebTransportUpgrader interface {
    39  		UpgradeWebTransport() (Session, error)
    40  	}
    41  	webtransportStream struct {
    42  		rw               *http2responseWriter
    43  		streamReceivedCh chan io.Reader
    44  		readFinished     chan struct{}
    45  		closedCh         chan struct{}
    46  		reader           io.Reader
    47  		streamReceived   bool
    48  		finReceived      bool
    49  	}
    50  	wtMaxData struct {
    51  		MaxData uint64
    52  	}
    53  	wtMaxStreams struct {
    54  		MaxStreams uint64
    55  	}
    56  	wtMaxStreamData struct {
    57  		StreamID uint64
    58  		MaxData  uint64
    59  	}
    60  	wtStream struct {
    61  		ws         *webtransportStream
    62  		StreamID   uint32
    63  		StreamData []byte
    64  	}
    65  )
    66  
    67  func (rw *http2responseWriter) UpgradeWebTransport() (Session, error) {
    68  	if !(rw.rws.req.Method == http.MethodConnect && rw.rws.req.Proto == "webtransport") {
    69  		rw.WriteHeader(400)
    70  		return nil, errors.New("invalid protocol")
    71  	}
    72  	rw.Header().Add(headerCapsuleProtocol, strconv.FormatBool(true))
    73  
    74  	rw.WriteHeader(http.StatusOK)
    75  	wts := &webtransportStream{
    76  		rw:               rw,
    77  		streamReceivedCh: make(chan io.Reader, 1),
    78  		readFinished:     make(chan struct{}, 1),
    79  		closedCh:         make(chan struct{}),
    80  	}
    81  	rw.rws.conn.webtransportSessions.Store(rw.rws.stream.id, wts)
    82  
    83  	return rw, nil
    84  }
    85  
    86  func (rw *http2responseWriter) AcceptStream(ctx context.Context) webtransport.Stream {
    87  	var stream *webtransportStream
    88  	if s, ok := rw.rws.conn.webtransportSessions.Load(rw.rws.stream.id); ok {
    89  		stream = s.(*webtransportStream)
    90  	}
    91  	go stream.handleWebTransportStream()
    92  	return stream
    93  }
    94  
    95  func (s *webtransportStream) handleWebTransportStream() {
    96  	defer func() { close(s.closedCh) }()
    97  	for {
    98  		if s.finReceived || s.rw.rws == nil || s.rw.rws.handlerDone || s.rw.rws.stream.state == http2stateClosed {
    99  			return
   100  		}
   101  		if s.streamReceived {
   102  			select {
   103  			case <-s.readFinished:
   104  			}
   105  		}
   106  		cType, data, err := http3.ParseCapsule(quicvarint.NewReader(s.rw.rws.req.Body))
   107  		cData := bufio.NewReader(data)
   108  		if err != nil {
   109  			if !errors.Is(err, http2errClientDisconnected) {
   110  				if s.rw.rws != nil {
   111  					s.rw.rws.conn.logf("failed to parse capsule error (http2/wt/rfc9297) %v", err)
   112  				}
   113  			}
   114  			break
   115  		}
   116  		if cType > 0 {
   117  			//log.Printf(fmt.Sprintf("http2/wt: Received capsule 0x%v", strconv.FormatUint(uint64(cType), 16)))
   118  			switch cType {
   119  			case wtCapsuleMaxStreamData:
   120  				md := new(wtMaxStreamData)
   121  				if err = md.Deserialize(cData); err == nil {
   122  					s.rw.rws.stream.sc.sendWindowUpdate(s.rw.rws.stream, int(md.MaxData))
   123  					s.rw.Flush()
   124  				}
   125  			case wtCapsuleMaxData:
   126  				md := new(wtMaxData)
   127  				if err = md.Deserialize(cData); err == nil {
   128  					s.rw.rws.stream.sc.sendWindowUpdate(s.rw.rws.stream, int(md.MaxData))
   129  					s.rw.Flush()
   130  				}
   131  			case wtCapsuleMaxStreams:
   132  				ms := new(wtMaxStreams)
   133  				if err = ms.Deserialize(cData); err == nil {
   134  					s.rw.rws.stream.sc.advMaxStreams = uint32(ms.MaxStreams)
   135  				}
   136  			case wtCapsuleStreamFin:
   137  			case wtCapsuleStream:
   138  				s.finReceived = cType == wtCapsuleStreamFin
   139  				str := &wtStream{ws: s}
   140  				err = str.Deserialize(cData)
   141  			case wtCapsuleResetStream:
   142  				s.rw.rws.stream.endStream()
   143  			case wtCapsuleStopSending:
   144  				s.rw.handlerDone()
   145  				return
   146  			case wtCapsuleDrainWebTransportSession:
   147  				s.rw.rws.conn.startGracefulShutdown()
   148  			case wtCapsuleCloseWebTransportSession:
   149  				s.rw.handlerDone()
   150  				return
   151  			default:
   152  				_, err = io.ReadAll(cData)
   153  			}
   154  
   155  			if err != nil {
   156  				if s.rw.rws != nil {
   157  					s.rw.rws.conn.logf("failed to process capsule (http2/wt/rfc9297): %v", err)
   158  				}
   159  			}
   160  		}
   161  	}
   162  }
   163  
   164  func (s *wtStream) Deserialize(dataReader quicvarint.Reader) (err error) {
   165  	var sID uint64
   166  	sID, err = quicvarint.Read(dataReader)
   167  	if err != nil {
   168  		err = errors.Wrapf(err, "failed to parse WT_STREAM/StreamID")
   169  		return err
   170  	}
   171  	s.StreamID = uint32(sID)
   172  	s.ws.streamReceivedCh <- dataReader
   173  	return errors.Wrapf(err, "failed to copy content from WT_STREAM")
   174  }
   175  func (s *wtStream) Serialize() []byte {
   176  	b := make([]byte, 0, 4+len(s.StreamData))
   177  	b = quicvarint.Append(b, uint64(s.StreamID))
   178  	b = append(b, s.StreamData...)
   179  
   180  	return b
   181  }
   182  
   183  func (s *webtransportStream) Write(p []byte) (n int, err error) {
   184  	err = s.rw.WriteCapsule(wtCapsuleStream, &wtStream{StreamData: p, StreamID: s.rw.rws.stream.id})
   185  
   186  	return len(p), err
   187  }
   188  
   189  func (s *webtransportStream) Close() error {
   190  	if s.rw.rws != nil {
   191  		s.rw.handlerDone()
   192  	}
   193  	return nil
   194  }
   195  
   196  func (s *webtransportStream) StreamID() quic.StreamID {
   197  	return quic.StreamID(s.rw.rws.stream.id)
   198  }
   199  
   200  func (s *webtransportStream) CancelWrite(code webtransport.StreamErrorCode) {
   201  
   202  }
   203  
   204  func (s *webtransportStream) SetWriteDeadline(t stdlibtime.Time) error {
   205  	return nil
   206  }
   207  
   208  func (s *webtransportStream) Read(p []byte) (n int, err error) {
   209  	if s.finReceived || s.rw.rws.handlerDone || s.rw.rws.stream.state == http2stateClosed {
   210  		return 0, io.EOF
   211  	}
   212  	var r io.Reader
   213  	if s.reader == nil {
   214  		select {
   215  		case r = <-s.streamReceivedCh:
   216  			s.reader = r
   217  		case <-s.closedCh:
   218  			return 0, io.EOF
   219  		case <-s.rw.CloseNotify():
   220  			return 0, io.EOF
   221  		}
   222  	}
   223  	if r != s.reader {
   224  		s.reader = nil
   225  		return 0, nil
   226  	}
   227  	n, err = s.reader.Read(p)
   228  	if errors.Is(err, io.EOF) {
   229  		s.reader = nil
   230  		s.readFinished <- struct{}{}
   231  		err = nil
   232  	}
   233  	return n, err
   234  }
   235  
   236  func (s *webtransportStream) CancelRead(code webtransport.StreamErrorCode) {
   237  }
   238  
   239  func (s *webtransportStream) SetReadDeadline(t stdlibtime.Time) error {
   240  	return nil
   241  }
   242  
   243  func (s *webtransportStream) SetDeadline(t stdlibtime.Time) error {
   244  	return multierror.Append(
   245  		s.rw.SetReadDeadline(t),
   246  		s.rw.SetWriteDeadline(t),
   247  	).ErrorOrNil()
   248  }
   249  
   250  func (md *wtMaxData) Deserialize(dataReader quicvarint.Reader) (err error) {
   251  	if md.MaxData, err = quicvarint.Read(dataReader); err != nil {
   252  		err = errors.Wrapf(err, "failed to parse WT_MAX_DATA/MaxData")
   253  	}
   254  	return err
   255  }
   256  func (ms *wtMaxStreams) Deserialize(dataReader quicvarint.Reader) (err error) {
   257  	if ms.MaxStreams, err = quicvarint.Read(dataReader); err != nil {
   258  		err = errors.Wrapf(err, "failed to parse WT_MAX_STREAMS/Maximum Streams")
   259  	}
   260  	return err
   261  }
   262  func (md *wtMaxStreamData) Deserialize(dataReader quicvarint.Reader) (err error) {
   263  	if md.StreamID, err = quicvarint.Read(dataReader); err != nil {
   264  		err = errors.Wrapf(err, "failed to parse WT_MAX_STREAM_DATA/StreamID")
   265  		return err
   266  	}
   267  	if md.MaxData, err = quicvarint.Read(dataReader); err != nil {
   268  		return errors.Wrapf(err, "failed to parse WT_MAX_STREAM_DATA/MaxData")
   269  	}
   270  	return err
   271  }