github.com/cloudwego/kitex@v0.9.0/pkg/remote/trans/nphttp2/grpc/grpcframe/frame_reader.go (about)

     1  /*
     2   * Copyright 2021 The Go Authors. All rights reserved.
     3   *
     4   * Use of this source code is governed by a BSD-style
     5   * license that can be found in the LICENSE file.
     6   *
     7   * This file may have been modified by CloudWeGo authors. All CloudWeGo
     8   * Modifications are Copyright 2022 CloudWeGo Authors.
     9   *
    10   * Code forked and modified from golang v1.17.4
    11   */
    12  
    13  package grpcframe
    14  
    15  import (
    16  	"encoding/binary"
    17  	"errors"
    18  	"fmt"
    19  	"io"
    20  	"strings"
    21  
    22  	"github.com/cloudwego/netpoll"
    23  	"golang.org/x/net/http/httpguts"
    24  	"golang.org/x/net/http2"
    25  	"golang.org/x/net/http2/hpack"
    26  )
    27  
    28  // A Framer reads and writes Frames.
    29  type Framer struct {
    30  	errDetail error
    31  
    32  	// lastHeaderStream is non-zero if the last frame was an
    33  	// unfinished HEADERS/CONTINUATION.
    34  	lastHeaderStream uint32
    35  	lastFrame        http2.Frame
    36  
    37  	reader      netpoll.Reader
    38  	maxReadSize uint32
    39  
    40  	writer io.Writer
    41  	wbuf   []byte
    42  	// maxWriteSize uint32 // zero means unlimited; TODO: implement
    43  
    44  	// AllowIllegalWrites permits the Framer's Write methods to
    45  	// write frames that do not conform to the HTTP/2 spec. This
    46  	// permits using the Framer to test other HTTP/2
    47  	// implementations' conformance to the spec.
    48  	// If false, the Write methods will prefer to return an error
    49  	// rather than comply.
    50  	AllowIllegalWrites bool
    51  
    52  	// AllowIllegalReads permits the Framer's ReadFrame method
    53  	// to return non-compliant frames or frame orders.
    54  	// This is for testing and permits using the Framer to test
    55  	// other HTTP/2 implementations' conformance to the spec.
    56  	// It is not compatible with ReadMetaHeaders.
    57  	AllowIllegalReads bool
    58  
    59  	// ReadMetaHeaders if non-nil causes ReadFrame to merge
    60  	// HEADERS and CONTINUATION frames together and return
    61  	// MetaHeadersFrame instead.
    62  	ReadMetaHeaders *hpack.Decoder
    63  
    64  	// MaxHeaderListSize is the http2 MAX_HEADER_LIST_SIZE.
    65  	// It's used only if ReadMetaHeaders is set; 0 means a sane default
    66  	// (currently 16MB)
    67  	// If the limit is hit, MetaHeadersFrame.Truncated is set true.
    68  	MaxHeaderListSize uint32
    69  
    70  	frameCache *frameCache // nil if frames aren't reused (default)
    71  }
    72  
    73  func (fr *Framer) maxHeaderListSize() uint32 {
    74  	if fr.MaxHeaderListSize == 0 {
    75  		return 16 << 20 // sane default, per docs
    76  	}
    77  	return fr.MaxHeaderListSize
    78  }
    79  
    80  const (
    81  	minMaxFrameSize = 1 << 14
    82  	maxFrameSize    = 1<<24 - 1
    83  )
    84  
    85  // SetReuseFrames allows the Framer to reuse Frames.
    86  // If called on a Framer, Frames returned by calls to ReadFrame are only
    87  // valid until the next call to ReadFrame.
    88  func (fr *Framer) SetReuseFrames() {
    89  	if fr.frameCache != nil {
    90  		return
    91  	}
    92  	fr.frameCache = &frameCache{}
    93  }
    94  
    95  type frameCache struct {
    96  	dataFrame DataFrame
    97  }
    98  
    99  func (fc *frameCache) getDataFrame() *DataFrame {
   100  	if fc == nil {
   101  		return &DataFrame{}
   102  	}
   103  	return &fc.dataFrame
   104  }
   105  
   106  // NewFramer returns a Framer that writes frames to w and reads them from r.
   107  func NewFramer(w io.Writer, r netpoll.Reader) *Framer {
   108  	fr := &Framer{
   109  		writer: w,
   110  		reader: r,
   111  	}
   112  	fr.SetMaxReadFrameSize(maxFrameSize)
   113  	return fr
   114  }
   115  
   116  // SetMaxReadFrameSize sets the maximum size of a frame
   117  // that will be read by a subsequent call to ReadFrame.
   118  // It is the caller's responsibility to advertise this
   119  // limit with a SETTINGS frame.
   120  func (fr *Framer) SetMaxReadFrameSize(v uint32) {
   121  	if v > maxFrameSize {
   122  		v = maxFrameSize
   123  	}
   124  	fr.maxReadSize = v
   125  }
   126  
   127  // ErrorDetail returns a more detailed error of the last error
   128  // returned by Framer.ReadFrame. For instance, if ReadFrame
   129  // returns a StreamError with code PROTOCOL_ERROR, ErrorDetail
   130  // will say exactly what was invalid. ErrorDetail is not guaranteed
   131  // to return a non-nil value and like the rest of the http2 package,
   132  // its return value is not protected by an API compatibility promise.
   133  // ErrorDetail is reset after the next call to ReadFrame.
   134  func (fr *Framer) ErrorDetail() error {
   135  	return fr.errDetail
   136  }
   137  
   138  // ReadFrame reads a single frame. The returned Frame is only valid
   139  // until the next call to ReadFrame.
   140  //
   141  // If the frame is larger than previously set with SetMaxReadFrameSize, the
   142  // returned error is ErrFrameTooLarge. Other errors may be of type
   143  // ConnectionError, StreamError, or anything else from the underlying
   144  // reader.
   145  func (fr *Framer) ReadFrame() (http2.Frame, error) {
   146  	fr.errDetail = nil
   147  
   148  	fh, err := readFrameHeader(fr.reader)
   149  	if err != nil {
   150  		return nil, err
   151  	}
   152  	if fh.Length > fr.maxReadSize {
   153  		return nil, http2.ErrFrameTooLarge
   154  	}
   155  	// payload := fr.getReadBuf(fh.Length)
   156  	var f http2.Frame
   157  	if fh.Type == http2.FrameData {
   158  		f, err = parseDataFrame(fr.frameCache, fh, fr.reader)
   159  	} else {
   160  		var payload []byte
   161  		payload, err = fr.reader.Next(int(fh.Length))
   162  		if err != nil {
   163  			return nil, err
   164  		}
   165  		f, err = typeFrameParser(fh.Type)(fr.frameCache, fh, payload)
   166  	}
   167  	if err != nil {
   168  		if ce, ok := err.(connError); ok {
   169  			return nil, fr.connError(ce.Code, ce.Reason)
   170  		}
   171  		return nil, err
   172  	}
   173  	if err = fr.checkFrameOrder(f); err != nil {
   174  		return nil, err
   175  	}
   176  	if fh.Type == http2.FrameHeaders && fr.ReadMetaHeaders != nil {
   177  		return fr.readMetaFrame(f.(*HeadersFrame))
   178  	}
   179  	return f, nil
   180  }
   181  
   182  // connError returns ConnectionError(code) but first
   183  // stashes away a public reason to the caller can optionally relay it
   184  // to the peer before hanging up on them. This might help others debug
   185  // their implementations.
   186  func (fr *Framer) connError(code http2.ErrCode, reason string) error {
   187  	fr.errDetail = errors.New(reason)
   188  	return http2.ConnectionError(code)
   189  }
   190  
   191  // checkFrameOrder reports an error if f is an invalid frame to return
   192  // next from ReadFrame. Mostly it checks whether HEADERS and
   193  // CONTINUATION frames are contiguous.
   194  func (fr *Framer) checkFrameOrder(f http2.Frame) error {
   195  	last := fr.lastFrame
   196  	fr.lastFrame = f
   197  	if fr.AllowIllegalReads {
   198  		return nil
   199  	}
   200  
   201  	fh := f.Header()
   202  	if fr.lastHeaderStream != 0 {
   203  		if fh.Type != http2.FrameContinuation {
   204  			return fr.connError(http2.ErrCodeProtocol,
   205  				fmt.Sprintf("got %s for stream %d; expected CONTINUATION following %s for stream %d",
   206  					fh.Type, fh.StreamID,
   207  					last.Header().Type, fr.lastHeaderStream))
   208  		}
   209  		if fh.StreamID != fr.lastHeaderStream {
   210  			return fr.connError(http2.ErrCodeProtocol,
   211  				fmt.Sprintf("got CONTINUATION for stream %d; expected stream %d",
   212  					fh.StreamID, fr.lastHeaderStream))
   213  		}
   214  	} else if fh.Type == http2.FrameContinuation {
   215  		return fr.connError(http2.ErrCodeProtocol, fmt.Sprintf("unexpected CONTINUATION for stream %d", fh.StreamID))
   216  	}
   217  
   218  	switch fh.Type {
   219  	case http2.FrameHeaders, http2.FrameContinuation:
   220  		if fh.Flags.Has(http2.FlagHeadersEndHeaders) {
   221  			fr.lastHeaderStream = 0
   222  		} else {
   223  			fr.lastHeaderStream = fh.StreamID
   224  		}
   225  	}
   226  
   227  	return nil
   228  }
   229  
   230  type headersEnder interface {
   231  	HeadersEnded() bool
   232  }
   233  
   234  type headersOrContinuation interface {
   235  	headersEnder
   236  	HeaderBlockFragment() []byte
   237  }
   238  
   239  func (fr *Framer) maxHeaderStringLen() int {
   240  	v := fr.maxHeaderListSize()
   241  	if uint32(int(v)) == v {
   242  		return int(v)
   243  	}
   244  	// They had a crazy big number for MaxHeaderBytes anyway,
   245  	// so give them unlimited header lengths:
   246  	return 0
   247  }
   248  
   249  // readMetaFrame returns 0 or more CONTINUATION frames from fr and
   250  // merge them into the provided hf and returns a MetaHeadersFrame
   251  // with the decoded hpack values.
   252  func (fr *Framer) readMetaFrame(hf *HeadersFrame) (*MetaHeadersFrame, error) {
   253  	if fr.AllowIllegalReads {
   254  		return nil, errors.New("illegal use of AllowIllegalReads with ReadMetaHeaders")
   255  	}
   256  	mh := &MetaHeadersFrame{
   257  		HeadersFrame: hf,
   258  		Fields:       make([]hpack.HeaderField, 0, 2),
   259  	}
   260  	remainSize := fr.maxHeaderListSize()
   261  	var sawRegular bool
   262  
   263  	var invalid error // pseudo header field errors
   264  	hdec := fr.ReadMetaHeaders
   265  	hdec.SetEmitEnabled(true)
   266  	hdec.SetMaxStringLength(fr.maxHeaderStringLen())
   267  	hdec.SetEmitFunc(func(hf hpack.HeaderField) {
   268  		if !httpguts.ValidHeaderFieldValue(hf.Value) {
   269  			invalid = headerFieldValueError(hf.Value)
   270  		}
   271  		isPseudo := strings.HasPrefix(hf.Name, ":")
   272  		if isPseudo {
   273  			if sawRegular {
   274  				invalid = errPseudoAfterRegular
   275  			}
   276  		} else {
   277  			sawRegular = true
   278  			if !validWireHeaderFieldName(hf.Name) {
   279  				invalid = headerFieldNameError(hf.Name)
   280  			}
   281  		}
   282  
   283  		if invalid != nil {
   284  			hdec.SetEmitEnabled(false)
   285  			return
   286  		}
   287  
   288  		size := hf.Size()
   289  		if size > remainSize {
   290  			hdec.SetEmitEnabled(false)
   291  			mh.Truncated = true
   292  			return
   293  		}
   294  		remainSize -= size
   295  
   296  		mh.Fields = append(mh.Fields, hf)
   297  	})
   298  	// Lose reference to MetaHeadersFrame:
   299  	defer hdec.SetEmitFunc(func(hf hpack.HeaderField) {})
   300  
   301  	var hc headersOrContinuation = hf
   302  	for {
   303  		frag := hc.HeaderBlockFragment()
   304  		if _, err := hdec.Write(frag); err != nil {
   305  			return nil, http2.ConnectionError(http2.ErrCodeCompression)
   306  		}
   307  
   308  		if hc.HeadersEnded() {
   309  			break
   310  		}
   311  		if f, err := fr.ReadFrame(); err != nil {
   312  			return nil, err
   313  		} else {
   314  			hc = f.(*ContinuationFrame) // guaranteed by checkFrameOrder
   315  		}
   316  	}
   317  
   318  	mh.HeadersFrame.headerFragBuf = nil
   319  	// mh.HeadersFrame.invalidate()
   320  
   321  	if err := hdec.Close(); err != nil {
   322  		return nil, http2.ConnectionError(http2.ErrCodeCompression)
   323  	}
   324  	if invalid != nil {
   325  		fr.errDetail = invalid
   326  		return nil, http2.StreamError{StreamID: mh.StreamID, Code: http2.ErrCodeProtocol, Cause: invalid}
   327  	}
   328  	if err := mh.checkPseudos(); err != nil {
   329  		fr.errDetail = err
   330  		return nil, http2.StreamError{StreamID: mh.StreamID, Code: http2.ErrCodeProtocol, Cause: err}
   331  	}
   332  	return mh, nil
   333  }
   334  
   335  // validWireHeaderFieldName reports whether v is a valid header field
   336  // name (key). See httpguts.ValidHeaderName for the base rules.
   337  //
   338  // Further, http2 says:
   339  //
   340  //	"Just as in HTTP/1.x, header field names are strings of ASCII
   341  //	characters that are compared in a case-insensitive
   342  //	fashion. However, header field names MUST be converted to
   343  //	lowercase prior to their encoding in HTTP/2. "
   344  func validWireHeaderFieldName(v string) bool {
   345  	if len(v) == 0 {
   346  		return false
   347  	}
   348  	for _, r := range v {
   349  		if !httpguts.IsTokenRune(r) {
   350  			return false
   351  		}
   352  		if 'A' <= r && r <= 'Z' {
   353  			return false
   354  		}
   355  	}
   356  	return true
   357  }
   358  
   359  func readByte(p []byte) (remain []byte, b byte, err error) {
   360  	if len(p) == 0 {
   361  		return nil, 0, io.ErrUnexpectedEOF
   362  	}
   363  	return p[1:], p[0], nil
   364  }
   365  
   366  func readUint32(p []byte) (remain []byte, v uint32, err error) {
   367  	if len(p) < 4 {
   368  		return nil, 0, io.ErrUnexpectedEOF
   369  	}
   370  	return p[4:], binary.BigEndian.Uint32(p[:4]), nil
   371  }
   372  
   373  func readFrameHeader(r netpoll.Reader) (http2.FrameHeader, error) {
   374  	buf, err := r.Next(frameHeaderLen)
   375  	if err != nil {
   376  		return http2.FrameHeader{}, err
   377  	}
   378  	return http2.FrameHeader{
   379  		Length:   uint32(buf[0])<<16 | uint32(buf[1])<<8 | uint32(buf[2]),
   380  		Type:     http2.FrameType(buf[3]),
   381  		Flags:    http2.Flags(buf[4]),
   382  		StreamID: binary.BigEndian.Uint32(buf[5:]) & (1<<31 - 1),
   383  		// valid:    true,
   384  	}, nil
   385  }