github.com/sagernet/quic-go@v0.43.1-beta.1/http3/body.go (about)

     1  package http3
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"io"
     7  
     8  	"github.com/sagernet/quic-go"
     9  	"github.com/sagernet/quic-go/internal/utils"
    10  )
    11  
    12  // A Hijacker allows hijacking of the stream creating part of a quic.Session from a http.Response.Body.
    13  // It is used by WebTransport to create WebTransport streams after a session has been established.
    14  type Hijacker interface {
    15  	Connection() Connection
    16  }
    17  
    18  var errTooMuchData = errors.New("peer sent too much data")
    19  
    20  // The body is used in the requestBody (for a http.Request) and the responseBody (for a http.Response).
    21  type body struct {
    22  	str *stream
    23  
    24  	remainingContentLength int64
    25  	violatedContentLength  bool
    26  	hasContentLength       bool
    27  }
    28  
    29  func newBody(str *stream, contentLength int64) *body {
    30  	b := &body{str: str}
    31  	if contentLength >= 0 {
    32  		b.hasContentLength = true
    33  		b.remainingContentLength = contentLength
    34  	}
    35  	return b
    36  }
    37  
    38  func (r *body) StreamID() quic.StreamID { return r.str.StreamID() }
    39  
    40  func (r *body) checkContentLengthViolation() error {
    41  	if !r.hasContentLength {
    42  		return nil
    43  	}
    44  	if r.remainingContentLength < 0 || r.remainingContentLength == 0 && r.str.hasMoreData() {
    45  		if !r.violatedContentLength {
    46  			r.str.CancelRead(quic.StreamErrorCode(ErrCodeMessageError))
    47  			r.str.CancelWrite(quic.StreamErrorCode(ErrCodeMessageError))
    48  			r.violatedContentLength = true
    49  		}
    50  		return errTooMuchData
    51  	}
    52  	return nil
    53  }
    54  
    55  func (r *body) Read(b []byte) (int, error) {
    56  	if err := r.checkContentLengthViolation(); err != nil {
    57  		return 0, err
    58  	}
    59  	if r.hasContentLength {
    60  		b = b[:utils.Min(int64(len(b)), r.remainingContentLength)]
    61  	}
    62  	n, err := r.str.Read(b)
    63  	r.remainingContentLength -= int64(n)
    64  	if err := r.checkContentLengthViolation(); err != nil {
    65  		return n, err
    66  	}
    67  	return n, maybeReplaceError(err)
    68  }
    69  
    70  func (r *body) Close() error {
    71  	r.str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled))
    72  	return nil
    73  }
    74  
    75  type requestBody struct {
    76  	body
    77  	connCtx      context.Context
    78  	rcvdSettings <-chan struct{}
    79  	getSettings  func() *Settings
    80  }
    81  
    82  var _ io.ReadCloser = &requestBody{}
    83  
    84  func newRequestBody(str *stream, contentLength int64, connCtx context.Context, rcvdSettings <-chan struct{}, getSettings func() *Settings) *requestBody {
    85  	return &requestBody{
    86  		body:         *newBody(str, contentLength),
    87  		connCtx:      connCtx,
    88  		rcvdSettings: rcvdSettings,
    89  		getSettings:  getSettings,
    90  	}
    91  }
    92  
    93  type hijackableBody struct {
    94  	body body
    95  
    96  	// only set for the http.Response
    97  	// The channel is closed when the user is done with this response:
    98  	// either when Read() errors, or when Close() is called.
    99  	reqDone       chan<- struct{}
   100  	reqDoneClosed bool
   101  }
   102  
   103  var _ io.ReadCloser = &hijackableBody{}
   104  
   105  func newResponseBody(str *stream, contentLength int64, done chan<- struct{}) *hijackableBody {
   106  	return &hijackableBody{
   107  		body:    *newBody(str, contentLength),
   108  		reqDone: done,
   109  	}
   110  }
   111  
   112  func (r *hijackableBody) Read(b []byte) (int, error) {
   113  	n, err := r.body.Read(b)
   114  	if err != nil {
   115  		r.requestDone()
   116  	}
   117  	return n, maybeReplaceError(err)
   118  }
   119  
   120  func (r *hijackableBody) requestDone() {
   121  	if r.reqDoneClosed || r.reqDone == nil {
   122  		return
   123  	}
   124  	if r.reqDone != nil {
   125  		close(r.reqDone)
   126  	}
   127  	r.reqDoneClosed = true
   128  }
   129  
   130  func (r *hijackableBody) Close() error {
   131  	r.requestDone()
   132  	// If the EOF was read, CancelRead() is a no-op.
   133  	r.body.str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled))
   134  	return nil
   135  }