github.com/cloudwego/kitex@v0.9.0/pkg/remote/trans/nphttp2/grpc/grpcframe/frame_writer.go (about)

     1  /*
     2   * Copyright 2021 The Go Authors. All rights reserved.
     3   *
     4   * Use of this source code is governed by a BSD-style
     5   * license that can be found in the LICENSE file.
     6   *
     7   * This file may have been modified by CloudWeGo authors. All CloudWeGo
     8   * Modifications are Copyright 2022 CloudWeGo Authors.
     9   *
    10   * Code forked and modified from golang v1.17.4
    11   */
    12  
    13  package grpcframe
    14  
    15  import (
    16  	"errors"
    17  
    18  	"golang.org/x/net/http2"
    19  )
    20  
    21  const frameHeaderLen = 9
    22  
    23  var padZeros = make([]byte, 255) // zeros for padding
    24  
    25  var (
    26  	errStreamID    = errors.New("invalid stream ID")
    27  	errDepStreamID = errors.New("invalid dependent stream ID")
    28  )
    29  
    30  func (fr *Framer) startWrite(ftype http2.FrameType, flags http2.Flags, streamID uint32, payloadLen int) (err error) {
    31  	if payloadLen >= (1 << 24) {
    32  		return http2.ErrFrameTooLarge
    33  	}
    34  	fr.wbuf = append(fr.wbuf[:0],
    35  		byte(payloadLen>>16),
    36  		byte(payloadLen>>8),
    37  		byte(payloadLen),
    38  		byte(ftype),
    39  		byte(flags),
    40  		byte(streamID>>24),
    41  		byte(streamID>>16),
    42  		byte(streamID>>8),
    43  		byte(streamID))
    44  	return nil
    45  }
    46  
    47  func (fr *Framer) endWrite() (err error) {
    48  	_, err = fr.writer.Write(fr.wbuf)
    49  	return err
    50  }
    51  
    52  func (fr *Framer) writeByte(v byte)     { fr.wbuf = append(fr.wbuf, v) }
    53  func (fr *Framer) writeBytes(v []byte)  { fr.wbuf = append(fr.wbuf, v...) }
    54  func (fr *Framer) writeUint16(v uint16) { fr.wbuf = append(fr.wbuf, byte(v>>8), byte(v)) }
    55  func (fr *Framer) writeUint32(v uint32) {
    56  	fr.wbuf = append(fr.wbuf, byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
    57  }
    58  
    59  // WriteData writes a DATA frame.
    60  //
    61  // It will perform exactly one Write to the underlying Writer.
    62  // It is the caller's responsibility not to violate the maximum frame size
    63  // and to not call other Write methods concurrently.
    64  func (fr *Framer) WriteData(streamID uint32, endStream bool, data []byte) error {
    65  	if !validStreamID(streamID) && !fr.AllowIllegalWrites {
    66  		return errStreamID
    67  	}
    68  
    69  	var flags http2.Flags
    70  	if endStream {
    71  		flags |= http2.FlagDataEndStream
    72  	}
    73  
    74  	err := fr.startWrite(http2.FrameData, flags, streamID, len(data))
    75  	if err != nil {
    76  		return err
    77  	}
    78  	fr.writeBytes(data)
    79  	return fr.endWrite()
    80  }
    81  
    82  // WriteHeaders writes a single HEADERS frame.
    83  //
    84  // This is a low-level header writing method. Encoding headers and
    85  // splitting them into any necessary CONTINUATION frames is handled
    86  // elsewhere.
    87  //
    88  // It will perform exactly one Write to the underlying Writer.
    89  // It is the caller's responsibility to not call other Write methods concurrently.
    90  func (fr *Framer) WriteHeaders(p http2.HeadersFrameParam) error {
    91  	if !validStreamID(p.StreamID) && !fr.AllowIllegalWrites {
    92  		return errStreamID
    93  	}
    94  	var payloadLen int
    95  	var flags http2.Flags
    96  	if p.PadLength != 0 {
    97  		flags |= http2.FlagHeadersPadded
    98  		payloadLen += 1 + int(p.PadLength)
    99  	}
   100  	if p.EndStream {
   101  		flags |= http2.FlagHeadersEndStream
   102  	}
   103  	if p.EndHeaders {
   104  		flags |= http2.FlagHeadersEndHeaders
   105  	}
   106  	if !p.Priority.IsZero() {
   107  		v := p.Priority.StreamDep
   108  		if !validStreamIDOrZero(v) && !fr.AllowIllegalWrites {
   109  			return errDepStreamID
   110  		}
   111  		flags |= http2.FlagHeadersPriority
   112  		payloadLen += 5
   113  	}
   114  	payloadLen += len(p.BlockFragment)
   115  	err := fr.startWrite(http2.FrameHeaders, flags, p.StreamID, payloadLen)
   116  	if err != nil {
   117  		return err
   118  	}
   119  	if p.PadLength != 0 {
   120  		fr.writeByte(p.PadLength)
   121  	}
   122  	if !p.Priority.IsZero() {
   123  		v := p.Priority.StreamDep
   124  		if p.Priority.Exclusive {
   125  			v |= 1 << 31
   126  		}
   127  		fr.writeUint32(v)
   128  		fr.writeByte(p.Priority.Weight)
   129  	}
   130  	fr.writeBytes(p.BlockFragment)
   131  	fr.writeBytes(padZeros[:p.PadLength])
   132  	return fr.endWrite()
   133  }
   134  
   135  // WritePriority writes a PRIORITY frame.
   136  //
   137  // It will perform exactly one Write to the underlying Writer.
   138  // It is the caller's responsibility to not call other Write methods concurrently.
   139  func (fr *Framer) WritePriority(streamID uint32, p http2.PriorityParam) error {
   140  	if !validStreamID(streamID) && !fr.AllowIllegalWrites {
   141  		return errStreamID
   142  	}
   143  	if !validStreamIDOrZero(p.StreamDep) {
   144  		return errDepStreamID
   145  	}
   146  
   147  	err := fr.startWrite(http2.FramePriority, 0, streamID, 5)
   148  	if err != nil {
   149  		return err
   150  	}
   151  
   152  	v := p.StreamDep
   153  	if p.Exclusive {
   154  		v |= 1 << 31
   155  	}
   156  	fr.writeUint32(v)
   157  	fr.writeByte(p.Weight)
   158  	return fr.endWrite()
   159  }
   160  
   161  // WriteRSTStream writes a RST_STREAM frame.
   162  //
   163  // It will perform exactly one Write to the underlying Writer.
   164  // It is the caller's responsibility to not call other Write methods concurrently.
   165  func (fr *Framer) WriteRSTStream(streamID uint32, code http2.ErrCode) error {
   166  	if !validStreamID(streamID) && !fr.AllowIllegalWrites {
   167  		return errStreamID
   168  	}
   169  	err := fr.startWrite(http2.FrameRSTStream, 0, streamID, 4)
   170  	if err != nil {
   171  		return err
   172  	}
   173  	fr.writeUint32(uint32(code))
   174  	return fr.endWrite()
   175  }
   176  
   177  // WriteSettings writes a SETTINGS frame with zero or more settings
   178  // specified and the ACK bit not set.
   179  //
   180  // It will perform exactly one Write to the underlying Writer.
   181  // It is the caller's responsibility to not call other Write methods concurrently.
   182  func (fr *Framer) WriteSettings(settings ...http2.Setting) error {
   183  	payloadLen := 6 * len(settings)
   184  	err := fr.startWrite(http2.FrameSettings, 0, 0, payloadLen)
   185  	if err != nil {
   186  		return err
   187  	}
   188  	for _, s := range settings {
   189  		fr.writeUint16(uint16(s.ID))
   190  		fr.writeUint32(s.Val)
   191  	}
   192  	return fr.endWrite()
   193  }
   194  
   195  // WriteSettingsAck writes an empty SETTINGS frame with the ACK bit set.
   196  //
   197  // It will perform exactly one Write to the underlying Writer.
   198  // It is the caller's responsibility to not call other Write methods concurrently.
   199  func (fr *Framer) WriteSettingsAck() error {
   200  	err := fr.startWrite(http2.FrameSettings, http2.FlagSettingsAck, 0, 0)
   201  	if err != nil {
   202  		return err
   203  	}
   204  	return fr.endWrite()
   205  }
   206  
   207  // WritePushPromise writes a single PushPromise Frame.
   208  //
   209  // As with Header Frames, This is the low level call for writing
   210  // individual frames. Continuation frames are handled elsewhere.
   211  //
   212  // It will perform exactly one Write to the underlying Writer.
   213  // It is the caller's responsibility to not call other Write methods concurrently.
   214  func (fr *Framer) WritePushPromise(p http2.PushPromiseParam) error {
   215  	if !fr.AllowIllegalWrites && (!validStreamID(p.StreamID) || !validStreamID(p.PromiseID)) {
   216  		return errStreamID
   217  	}
   218  	if !validStreamID(p.PromiseID) && !fr.AllowIllegalWrites {
   219  		return errStreamID
   220  	}
   221  	var payloadLen int
   222  	var flags http2.Flags
   223  	if p.PadLength != 0 {
   224  		flags |= http2.FlagPushPromisePadded
   225  		payloadLen += 1 + int(p.PadLength)
   226  	}
   227  	if p.EndHeaders {
   228  		flags |= http2.FlagPushPromiseEndHeaders
   229  	}
   230  	payloadLen += 4 + len(p.BlockFragment)
   231  	err := fr.startWrite(http2.FramePushPromise, flags, p.StreamID, payloadLen)
   232  	if err != nil {
   233  		return err
   234  	}
   235  	if p.PadLength != 0 {
   236  		fr.writeByte(p.PadLength)
   237  	}
   238  	fr.writeUint32(p.PromiseID)
   239  	fr.writeBytes(p.BlockFragment)
   240  	fr.writeBytes(padZeros[:p.PadLength])
   241  	return fr.endWrite()
   242  }
   243  
   244  func (fr *Framer) WritePing(ack bool, data [8]byte) error {
   245  	var flags http2.Flags
   246  	if ack {
   247  		flags = http2.FlagPingAck
   248  	}
   249  	err := fr.startWrite(http2.FramePing, flags, 0, len(data))
   250  	if err != nil {
   251  		return err
   252  	}
   253  	fr.writeBytes(data[:])
   254  	return fr.endWrite()
   255  }
   256  
   257  func (fr *Framer) WriteGoAway(maxStreamID uint32, code http2.ErrCode, debugData []byte) error {
   258  	payloadLen := 8 + len(debugData)
   259  	err := fr.startWrite(http2.FrameGoAway, 0, 0, payloadLen)
   260  	if err != nil {
   261  		return err
   262  	}
   263  	fr.writeUint32(maxStreamID & (1<<31 - 1))
   264  	fr.writeUint32(uint32(code))
   265  	fr.writeBytes(debugData)
   266  	return fr.endWrite()
   267  }
   268  
   269  // WriteWindowUpdate writes a WINDOW_UPDATE frame.
   270  // The increment value must be between 1 and 2,147,483,647, inclusive.
   271  // If the Stream ID is zero, the window update applies to the
   272  // connection as a whole.
   273  func (fr *Framer) WriteWindowUpdate(streamID, incr uint32) error {
   274  	// "The legal range for the increment to the flow control window is 1 to 2^31-1 (2,147,483,647) octets."
   275  	if (incr < 1 || incr > 2147483647) && !fr.AllowIllegalWrites {
   276  		return errors.New("illegal window increment value")
   277  	}
   278  
   279  	err := fr.startWrite(http2.FrameWindowUpdate, 0, streamID, 4)
   280  	if err != nil {
   281  		return err
   282  	}
   283  	fr.writeUint32(incr)
   284  	return fr.endWrite()
   285  }
   286  
   287  // WriteContinuation writes a CONTINUATION frame.
   288  //
   289  // It will perform exactly one Write to the underlying Writer.
   290  // It is the caller's responsibility to not call other Write methods concurrently.
   291  func (fr *Framer) WriteContinuation(streamID uint32, endHeaders bool, headerBlockFragment []byte) error {
   292  	if !validStreamID(streamID) && !fr.AllowIllegalWrites {
   293  		return errStreamID
   294  	}
   295  	var flags http2.Flags
   296  	if endHeaders {
   297  		flags |= http2.FlagContinuationEndHeaders
   298  	}
   299  	err := fr.startWrite(http2.FrameContinuation, flags, streamID, len(headerBlockFragment))
   300  	if err != nil {
   301  		return err
   302  	}
   303  	fr.writeBytes(headerBlockFragment)
   304  	return fr.endWrite()
   305  }
   306  
   307  // WriteRawFrame writes a raw frame. This can be used to write
   308  // extension frames unknown to this package.
   309  func (fr *Framer) WriteRawFrame(t http2.FrameType, flags http2.Flags, streamID uint32, payload []byte) error {
   310  	err := fr.startWrite(t, flags, streamID, len(payload))
   311  	if err != nil {
   312  		return err
   313  	}
   314  	fr.writeBytes(payload)
   315  	return fr.endWrite()
   316  }
   317  
   318  func validStreamIDOrZero(streamID uint32) bool {
   319  	return streamID&(1<<31) == 0
   320  }
   321  
   322  func validStreamID(streamID uint32) bool {
   323  	return streamID != 0 && streamID&(1<<31) == 0
   324  }