github.com/simonmittag/ws@v1.1.0-rc.5.0.20210419231947-82b846128245/wsflate/extension.go (about)

     1  package wsflate
     2  
     3  import (
     4  	"bytes"
     5  
     6  	"github.com/gobwas/httphead"
     7  	"github.com/simonmittag/ws"
     8  )
     9  
    10  // Extension contains logic of compression extension parameters negotiation
    11  // made during HTTP WebSocket handshake.
    12  // It might be reused between different upgrades (but not concurrently) with
    13  // Reset() being called after each.
    14  type Extension struct {
    15  	// Parameters is specification of extension parameters server is going to
    16  	// accept.
    17  	Parameters Parameters
    18  
    19  	accepted bool
    20  	params   Parameters
    21  }
    22  
    23  // Negotiate parses given HTTP header option and returns (if any) header option
    24  // which describes accepted parameters.
    25  //
    26  // It may return zero option (i.e. one which Size() returns 0) alongside with
    27  // nil error.
    28  func (n *Extension) Negotiate(opt httphead.Option) (accept httphead.Option, err error) {
    29  	if !bytes.Equal(opt.Name, ExtensionNameBytes) {
    30  		return
    31  	}
    32  	if n.accepted {
    33  		// Negotiate might be called multiple times during upgrade.
    34  		// We stick to first one accepted extension since they must be passed
    35  		// in ordered by preference.
    36  		return
    37  	}
    38  
    39  	want := n.Parameters
    40  
    41  	// NOTE: Parse() resets params inside, so no worries.
    42  	if err = n.params.Parse(opt); err != nil {
    43  		return
    44  	}
    45  	{
    46  		offer := n.params.ServerMaxWindowBits
    47  		want := want.ServerMaxWindowBits
    48  		if offer > want {
    49  			// A server declines an extension negotiation offer
    50  			// with this parameter if the server doesn't support
    51  			// it.
    52  			return
    53  		}
    54  	}
    55  	{
    56  		// If a received extension negotiation offer has the
    57  		// "client_max_window_bits" extension parameter, the server MAY
    58  		// include the "client_max_window_bits" extension parameter in the
    59  		// corresponding extension negotiation response to the offer.
    60  		offer := n.params.ClientMaxWindowBits
    61  		want := want.ClientMaxWindowBits
    62  		if want > offer {
    63  			return
    64  		}
    65  	}
    66  	{
    67  		offer := n.params.ServerNoContextTakeover
    68  		want := want.ServerNoContextTakeover
    69  		if offer && !want {
    70  			return
    71  		}
    72  	}
    73  
    74  	n.accepted = true
    75  
    76  	return want.Option(), nil
    77  }
    78  
    79  // Accepted returns parameters parsed during last negotiation and a flag that
    80  // reports whether they were accepted.
    81  func (n *Extension) Accepted() (_ Parameters, accepted bool) {
    82  	return n.params, n.accepted
    83  }
    84  
    85  // Reset resets extension for further reuse.
    86  func (n *Extension) Reset() {
    87  	n.accepted = false
    88  	n.params = Parameters{}
    89  }
    90  
    91  var ErrUnexpectedCompressionBit = ws.ProtocolError(
    92  	"control frame or non-first fragment of data contains compression bit set",
    93  )
    94  
    95  // UnsetBit clears the Per-Message Compression bit in header h and returns its
    96  // modified copy. It reports whether compression bit was set in header h.
    97  // It returns non-nil error if compression bit has unexpected value.
    98  //
    99  // This function's main purpose is to be compatible with "Framing" section of
   100  // the Compression Extensions for WebSocket RFC. If you don't need to work with
   101  // chains of extensions then IsCompressed() could be enough to check if
   102  // message is compressed.
   103  // See https://tools.ietf.org/html/rfc7692#section-6.2
   104  func UnsetBit(h ws.Header) (_ ws.Header, wasSet bool, err error) {
   105  	var s MessageState
   106  	h, err = s.UnsetBits(h)
   107  	return h, s.IsCompressed(), err
   108  }
   109  
   110  // SetBit sets the Per-Message Compression bit in header h and returns its
   111  // modified copy.
   112  // It returns non-nil error if compression bit has unexpected value.
   113  func SetBit(h ws.Header) (_ ws.Header, err error) {
   114  	var s MessageState
   115  	s.SetCompressed(true)
   116  	return s.SetBits(h)
   117  }
   118  
   119  // IsCompressed reports whether the Per-Message Compression bit is set in
   120  // header h.
   121  // It returns non-nil error if compression bit has unexpected value.
   122  //
   123  // If you need to be fully compatible with Compression Extensions for WebSocket
   124  // RFC and work with chains of extensions, take a look at the UnsetBit()
   125  // instead. That is, IsCompressed() is a shortcut for UnsetBit() with reduced
   126  // number of return values.
   127  func IsCompressed(h ws.Header) (bool, error) {
   128  	_, isSet, err := UnsetBit(h)
   129  	return isSet, err
   130  }
   131  
   132  // MessageState holds message compression state.
   133  //
   134  // It is consulted during SetBits(h) call to make a decision whether we must
   135  // set the Per-Message Compression bit for given header h argument.
   136  // It is updated during UnsetBits(h) to reflect compression state of a message
   137  // represented by header h argument.
   138  // It can also be consulted/updated directly by calling
   139  // IsCompressed()/SetCompressed().
   140  //
   141  // In general MessageState should be used when there is no direct access to
   142  // connection to read frame from, but it is still needed to know if message
   143  // being read is compressed. For other cases SetBit() and UnsetBit() should be
   144  // used instead.
   145  //
   146  // NOTE: the compression state is updated during UnsetBits(h) only when header
   147  // h argument represents data (text or binary) frame.
   148  type MessageState struct {
   149  	compressed bool
   150  }
   151  
   152  // SetCompressed marks message as "compressed" or "uncompressed".
   153  // See https://tools.ietf.org/html/rfc7692#section-6
   154  func (s *MessageState) SetCompressed(v bool) {
   155  	s.compressed = v
   156  }
   157  
   158  // IsCompressed reports whether message is "compressed".
   159  // See https://tools.ietf.org/html/rfc7692#section-6
   160  func (s *MessageState) IsCompressed() bool {
   161  	return s.compressed
   162  }
   163  
   164  // UnsetBits changes RSV bits of the given frame header h as if compression
   165  // extension was negotiated. It returns modified copy of h and error if header
   166  // is malformed from the RFC perspective.
   167  func (s *MessageState) UnsetBits(h ws.Header) (ws.Header, error) {
   168  	r1, r2, r3 := ws.RsvBits(h.Rsv)
   169  	switch {
   170  	case h.OpCode.IsData() && h.OpCode != ws.OpContinuation:
   171  		h.Rsv = ws.Rsv(false, r2, r3)
   172  		s.SetCompressed(r1)
   173  		return h, nil
   174  
   175  	case r1:
   176  		// An endpoint MUST NOT set the "Per-Message Compressed"
   177  		// bit of control frames and non-first fragments of a data
   178  		// message. An endpoint receiving such a frame MUST _Fail
   179  		// the WebSocket Connection_.
   180  		return h, ErrUnexpectedCompressionBit
   181  
   182  	default:
   183  		// NOTE: do not change the state of s.compressed since UnsetBits()
   184  		// might also be called for (intermediate) control frames.
   185  		return h, nil
   186  	}
   187  }
   188  
   189  // SetBits changes RSV bits of the frame header h which is being send as if
   190  // compression extension was negotiated. It returns modified copy of h and
   191  // error if header is malformed from the RFC perspective.
   192  func (s *MessageState) SetBits(h ws.Header) (ws.Header, error) {
   193  	r1, r2, r3 := ws.RsvBits(h.Rsv)
   194  	if r1 {
   195  		return h, ErrUnexpectedCompressionBit
   196  	}
   197  	if !h.OpCode.IsData() || h.OpCode == ws.OpContinuation {
   198  		// An endpoint MUST NOT set the "Per-Message Compressed"
   199  		// bit of control frames and non-first fragments of a data
   200  		// message. An endpoint receiving such a frame MUST _Fail
   201  		// the WebSocket Connection_.
   202  		return h, nil
   203  	}
   204  	if s.IsCompressed() {
   205  		h.Rsv = ws.Rsv(true, r2, r3)
   206  	}
   207  	return h, nil
   208  }