github.com/ezoic/ws@v1.0.4-0.20220713205711-5c1d69e074c5/check.go (about)

     1  package ws
     2  
     3  import "unicode/utf8"
     4  
     5  // State represents state of websocket endpoint.
     6  // It used by some functions to be more strict when checking compatibility with RFC6455.
     7  type State uint8
     8  
     9  const (
    10  	// StateServerSide means that endpoint (caller) is a server.
    11  	StateServerSide State = 0x1 << iota
    12  	// StateClientSide means that endpoint (caller) is a client.
    13  	StateClientSide
    14  	// StateExtended means that extension was negotiated during handshake.
    15  	StateExtended
    16  	// StateFragmented means that endpoint (caller) has received fragmented
    17  	// frame and waits for continuation parts.
    18  	StateFragmented
    19  )
    20  
    21  // Is checks whether the s has v enabled.
    22  func (s State) Is(v State) bool {
    23  	return uint8(s)&uint8(v) != 0
    24  }
    25  
    26  // Set enables v state on s.
    27  func (s State) Set(v State) State {
    28  	return s | v
    29  }
    30  
    31  // Clear disables v state on s.
    32  func (s State) Clear(v State) State {
    33  	return s & (^v)
    34  }
    35  
    36  // ServerSide reports whether states represents server side.
    37  func (s State) ServerSide() bool { return s.Is(StateServerSide) }
    38  
    39  // ClientSide reports whether state represents client side.
    40  func (s State) ClientSide() bool { return s.Is(StateClientSide) }
    41  
    42  // Extended reports whether state is extended.
    43  func (s State) Extended() bool { return s.Is(StateExtended) }
    44  
    45  // Fragmented reports whether state is fragmented.
    46  func (s State) Fragmented() bool { return s.Is(StateFragmented) }
    47  
    48  // ProtocolError describes error during checking/parsing websocket frames or
    49  // headers.
    50  type ProtocolError string
    51  
    52  // Error implements error interface.
    53  func (p ProtocolError) Error() string { return string(p) }
    54  
    55  // Errors used by the protocol checkers.
    56  var (
    57  	ErrProtocolOpCodeReserved             = ProtocolError("use of reserved op code")
    58  	ErrProtocolControlPayloadOverflow     = ProtocolError("control frame payload limit exceeded")
    59  	ErrProtocolControlNotFinal            = ProtocolError("control frame is not final")
    60  	ErrProtocolNonZeroRsv                 = ProtocolError("non-zero rsv bits with no extension negotiated")
    61  	ErrProtocolMaskRequired               = ProtocolError("frames from client to server must be masked")
    62  	ErrProtocolMaskUnexpected             = ProtocolError("frames from server to client must be not masked")
    63  	ErrProtocolContinuationExpected       = ProtocolError("unexpected non-continuation data frame")
    64  	ErrProtocolContinuationUnexpected     = ProtocolError("unexpected continuation data frame")
    65  	ErrProtocolStatusCodeNotInUse         = ProtocolError("status code is not in use")
    66  	ErrProtocolStatusCodeApplicationLevel = ProtocolError("status code is only application level")
    67  	ErrProtocolStatusCodeNoMeaning        = ProtocolError("status code has no meaning yet")
    68  	ErrProtocolStatusCodeUnknown          = ProtocolError("status code is not defined in spec")
    69  	ErrProtocolInvalidUTF8                = ProtocolError("invalid utf8 sequence in close reason")
    70  )
    71  
    72  // CheckHeader checks h to contain valid header data for given state s.
    73  //
    74  // Note that zero state (0) means that state is clean,
    75  // neither server or client side, nor fragmented, nor extended.
    76  func CheckHeader(h Header, s State) error {
    77  	if h.OpCode.IsReserved() {
    78  		return ErrProtocolOpCodeReserved
    79  	}
    80  	if h.OpCode.IsControl() {
    81  		if h.Length > MaxControlFramePayloadSize {
    82  			return ErrProtocolControlPayloadOverflow
    83  		}
    84  		if !h.Fin {
    85  			return ErrProtocolControlNotFinal
    86  		}
    87  	}
    88  
    89  	switch {
    90  	// [RFC6455]: MUST be 0 unless an extension is negotiated that defines meanings for
    91  	// non-zero values. If a nonzero value is received and none of the
    92  	// negotiated extensions defines the meaning of such a nonzero value, the
    93  	// receiving endpoint MUST _Fail the WebSocket Connection_.
    94  	case h.Rsv != 0 && !s.Extended():
    95  		return ErrProtocolNonZeroRsv
    96  
    97  	// [RFC6455]: The server MUST close the connection upon receiving a frame that is not masked.
    98  	// In this case, a server MAY send a Close frame with a status code of 1002 (protocol error)
    99  	// as defined in Section 7.4.1. A server MUST NOT mask any frames that it sends to the client.
   100  	// A client MUST close a connection if it detects a masked frame. In this case, it MAY use the
   101  	// status code 1002 (protocol error) as defined in Section 7.4.1.
   102  	case s.ServerSide() && !h.Masked:
   103  		return ErrProtocolMaskRequired
   104  	case s.ClientSide() && h.Masked:
   105  		return ErrProtocolMaskUnexpected
   106  
   107  	// [RFC6455]: See detailed explanation in 5.4 section.
   108  	case s.Fragmented() && !h.OpCode.IsControl() && h.OpCode != OpContinuation:
   109  		return ErrProtocolContinuationExpected
   110  	case !s.Fragmented() && h.OpCode == OpContinuation:
   111  		return ErrProtocolContinuationUnexpected
   112  
   113  	default:
   114  		return nil
   115  	}
   116  }
   117  
   118  // CheckCloseFrameData checks received close information
   119  // to be valid RFC6455 compatible close info.
   120  //
   121  // Note that code.Empty() or code.IsAppLevel() will raise error.
   122  //
   123  // If endpoint sends close frame without status code (with frame.Length = 0),
   124  // application should not check its payload.
   125  func CheckCloseFrameData(code StatusCode, reason string) error {
   126  	switch {
   127  	case code.IsNotUsed():
   128  		return ErrProtocolStatusCodeNotInUse
   129  
   130  	case code.IsProtocolReserved():
   131  		return ErrProtocolStatusCodeApplicationLevel
   132  
   133  	case code == StatusNoMeaningYet:
   134  		return ErrProtocolStatusCodeNoMeaning
   135  
   136  	case code.IsProtocolSpec() && !code.IsProtocolDefined():
   137  		return ErrProtocolStatusCodeUnknown
   138  
   139  	case !utf8.ValidString(reason):
   140  		return ErrProtocolInvalidUTF8
   141  
   142  	default:
   143  		return nil
   144  	}
   145  }