github.com/apernet/quic-go@v0.43.1-0.20240515053213-5e9e635fd9f0/http3/body.go (about)

     1  package http3
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"io"
     7  
     8  	"github.com/apernet/quic-go"
     9  )
    10  
    11  // A Hijacker allows hijacking of the stream creating part of a quic.Session from a http.Response.Body.
    12  // It is used by WebTransport to create WebTransport streams after a session has been established.
    13  type Hijacker interface {
    14  	Connection() Connection
    15  }
    16  
    17  var errTooMuchData = errors.New("peer sent too much data")
    18  
    19  // The body is used in the requestBody (for a http.Request) and the responseBody (for a http.Response).
    20  type body struct {
    21  	str *stream
    22  
    23  	remainingContentLength int64
    24  	violatedContentLength  bool
    25  	hasContentLength       bool
    26  }
    27  
    28  func newBody(str *stream, contentLength int64) *body {
    29  	b := &body{str: str}
    30  	if contentLength >= 0 {
    31  		b.hasContentLength = true
    32  		b.remainingContentLength = contentLength
    33  	}
    34  	return b
    35  }
    36  
    37  func (r *body) StreamID() quic.StreamID { return r.str.StreamID() }
    38  
    39  func (r *body) checkContentLengthViolation() error {
    40  	if !r.hasContentLength {
    41  		return nil
    42  	}
    43  	if r.remainingContentLength < 0 || r.remainingContentLength == 0 && r.str.hasMoreData() {
    44  		if !r.violatedContentLength {
    45  			r.str.CancelRead(quic.StreamErrorCode(ErrCodeMessageError))
    46  			r.str.CancelWrite(quic.StreamErrorCode(ErrCodeMessageError))
    47  			r.violatedContentLength = true
    48  		}
    49  		return errTooMuchData
    50  	}
    51  	return nil
    52  }
    53  
    54  func (r *body) Read(b []byte) (int, error) {
    55  	if err := r.checkContentLengthViolation(); err != nil {
    56  		return 0, err
    57  	}
    58  	if r.hasContentLength {
    59  		b = b[:min(int64(len(b)), r.remainingContentLength)]
    60  	}
    61  	n, err := r.str.Read(b)
    62  	r.remainingContentLength -= int64(n)
    63  	if err := r.checkContentLengthViolation(); err != nil {
    64  		return n, err
    65  	}
    66  	return n, maybeReplaceError(err)
    67  }
    68  
    69  func (r *body) Close() error {
    70  	r.str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled))
    71  	return nil
    72  }
    73  
    74  type requestBody struct {
    75  	body
    76  	connCtx      context.Context
    77  	rcvdSettings <-chan struct{}
    78  	getSettings  func() *Settings
    79  }
    80  
    81  var _ io.ReadCloser = &requestBody{}
    82  
    83  func newRequestBody(str *stream, contentLength int64, connCtx context.Context, rcvdSettings <-chan struct{}, getSettings func() *Settings) *requestBody {
    84  	return &requestBody{
    85  		body:         *newBody(str, contentLength),
    86  		connCtx:      connCtx,
    87  		rcvdSettings: rcvdSettings,
    88  		getSettings:  getSettings,
    89  	}
    90  }
    91  
    92  type hijackableBody struct {
    93  	body body
    94  
    95  	// only set for the http.Response
    96  	// The channel is closed when the user is done with this response:
    97  	// either when Read() errors, or when Close() is called.
    98  	reqDone       chan<- struct{}
    99  	reqDoneClosed bool
   100  }
   101  
   102  var _ io.ReadCloser = &hijackableBody{}
   103  
   104  func newResponseBody(str *stream, contentLength int64, done chan<- struct{}) *hijackableBody {
   105  	return &hijackableBody{
   106  		body:    *newBody(str, contentLength),
   107  		reqDone: done,
   108  	}
   109  }
   110  
   111  func (r *hijackableBody) Read(b []byte) (int, error) {
   112  	n, err := r.body.Read(b)
   113  	if err != nil {
   114  		r.requestDone()
   115  	}
   116  	return n, maybeReplaceError(err)
   117  }
   118  
   119  func (r *hijackableBody) requestDone() {
   120  	if r.reqDoneClosed || r.reqDone == nil {
   121  		return
   122  	}
   123  	if r.reqDone != nil {
   124  		close(r.reqDone)
   125  	}
   126  	r.reqDoneClosed = true
   127  }
   128  
   129  func (r *hijackableBody) Close() error {
   130  	r.requestDone()
   131  	// If the EOF was read, CancelRead() is a no-op.
   132  	r.body.str.CancelRead(quic.StreamErrorCode(ErrCodeRequestCanceled))
   133  	return nil
   134  }