github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/vendor_skip/nhooyr.io/websocket/frame.go (about)

     1  package websocket
     2  
     3  import (
     4  	"bufio"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"io"
     8  	"math"
     9  	"math/bits"
    10  
    11  	"nhooyr.io/websocket/internal/errd"
    12  )
    13  
    14  // opcode represents a WebSocket opcode.
    15  type opcode int
    16  
    17  // https://tools.ietf.org/html/rfc6455#section-11.8.
    18  const (
    19  	opContinuation opcode = iota
    20  	opText
    21  	opBinary
    22  	// 3 - 7 are reserved for further non-control frames.
    23  	_
    24  	_
    25  	_
    26  	_
    27  	_
    28  	opClose
    29  	opPing
    30  	opPong
    31  	// 11-16 are reserved for further control frames.
    32  )
    33  
    34  // header represents a WebSocket frame header.
    35  // See https://tools.ietf.org/html/rfc6455#section-5.2.
    36  type header struct {
    37  	fin    bool
    38  	rsv1   bool
    39  	rsv2   bool
    40  	rsv3   bool
    41  	opcode opcode
    42  
    43  	payloadLength int64
    44  
    45  	masked  bool
    46  	maskKey uint32
    47  }
    48  
    49  // readFrameHeader reads a header from the reader.
    50  // See https://tools.ietf.org/html/rfc6455#section-5.2.
    51  func readFrameHeader(r *bufio.Reader, readBuf []byte) (h header, err error) {
    52  	defer errd.Wrap(&err, "failed to read frame header")
    53  
    54  	b, err := r.ReadByte()
    55  	if err != nil {
    56  		return header{}, err
    57  	}
    58  
    59  	h.fin = b&(1<<7) != 0
    60  	h.rsv1 = b&(1<<6) != 0
    61  	h.rsv2 = b&(1<<5) != 0
    62  	h.rsv3 = b&(1<<4) != 0
    63  
    64  	h.opcode = opcode(b & 0xf)
    65  
    66  	b, err = r.ReadByte()
    67  	if err != nil {
    68  		return header{}, err
    69  	}
    70  
    71  	h.masked = b&(1<<7) != 0
    72  
    73  	payloadLength := b &^ (1 << 7)
    74  	switch {
    75  	case payloadLength < 126:
    76  		h.payloadLength = int64(payloadLength)
    77  	case payloadLength == 126:
    78  		_, err = io.ReadFull(r, readBuf[:2])
    79  		h.payloadLength = int64(binary.BigEndian.Uint16(readBuf))
    80  	case payloadLength == 127:
    81  		_, err = io.ReadFull(r, readBuf)
    82  		h.payloadLength = int64(binary.BigEndian.Uint64(readBuf))
    83  	}
    84  	if err != nil {
    85  		return header{}, err
    86  	}
    87  
    88  	if h.payloadLength < 0 {
    89  		return header{}, fmt.Errorf("received negative payload length: %v", h.payloadLength)
    90  	}
    91  
    92  	if h.masked {
    93  		_, err = io.ReadFull(r, readBuf[:4])
    94  		if err != nil {
    95  			return header{}, err
    96  		}
    97  		h.maskKey = binary.LittleEndian.Uint32(readBuf)
    98  	}
    99  
   100  	return h, nil
   101  }
   102  
// maxControlPayload is the maximum length, in bytes, of a control frame
// payload.
// See https://tools.ietf.org/html/rfc6455#section-5.5.
const maxControlPayload = 125
   106  
   107  // writeFrameHeader writes the bytes of the header to w.
   108  // See https://tools.ietf.org/html/rfc6455#section-5.2
   109  func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) {
   110  	defer errd.Wrap(&err, "failed to write frame header")
   111  
   112  	var b byte
   113  	if h.fin {
   114  		b |= 1 << 7
   115  	}
   116  	if h.rsv1 {
   117  		b |= 1 << 6
   118  	}
   119  	if h.rsv2 {
   120  		b |= 1 << 5
   121  	}
   122  	if h.rsv3 {
   123  		b |= 1 << 4
   124  	}
   125  
   126  	b |= byte(h.opcode)
   127  
   128  	err = w.WriteByte(b)
   129  	if err != nil {
   130  		return err
   131  	}
   132  
   133  	lengthByte := byte(0)
   134  	if h.masked {
   135  		lengthByte |= 1 << 7
   136  	}
   137  
   138  	switch {
   139  	case h.payloadLength > math.MaxUint16:
   140  		lengthByte |= 127
   141  	case h.payloadLength > 125:
   142  		lengthByte |= 126
   143  	case h.payloadLength >= 0:
   144  		lengthByte |= byte(h.payloadLength)
   145  	}
   146  	err = w.WriteByte(lengthByte)
   147  	if err != nil {
   148  		return err
   149  	}
   150  
   151  	switch {
   152  	case h.payloadLength > math.MaxUint16:
   153  		binary.BigEndian.PutUint64(buf, uint64(h.payloadLength))
   154  		_, err = w.Write(buf)
   155  	case h.payloadLength > 125:
   156  		binary.BigEndian.PutUint16(buf, uint16(h.payloadLength))
   157  		_, err = w.Write(buf[:2])
   158  	}
   159  	if err != nil {
   160  		return err
   161  	}
   162  
   163  	if h.masked {
   164  		binary.LittleEndian.PutUint32(buf, h.maskKey)
   165  		_, err = w.Write(buf[:4])
   166  		if err != nil {
   167  			return err
   168  		}
   169  	}
   170  
   171  	return nil
   172  }
   173  
   174  // mask applies the WebSocket masking algorithm to p
   175  // with the given key.
   176  // See https://tools.ietf.org/html/rfc6455#section-5.3
   177  //
   178  // The returned value is the correctly rotated key to
   179  // to continue to mask/unmask the message.
   180  //
   181  // It is optimized for LittleEndian and expects the key
   182  // to be in little endian.
   183  //
   184  // See https://github.com/golang/go/issues/31586
   185  func mask(key uint32, b []byte) uint32 {
   186  	if len(b) >= 8 {
   187  		key64 := uint64(key)<<32 | uint64(key)
   188  
   189  		// At some point in the future we can clean these unrolled loops up.
   190  		// See https://github.com/golang/go/issues/31586#issuecomment-487436401
   191  
   192  		// Then we xor until b is less than 128 bytes.
   193  		for len(b) >= 128 {
   194  			v := binary.LittleEndian.Uint64(b)
   195  			binary.LittleEndian.PutUint64(b, v^key64)
   196  			v = binary.LittleEndian.Uint64(b[8:16])
   197  			binary.LittleEndian.PutUint64(b[8:16], v^key64)
   198  			v = binary.LittleEndian.Uint64(b[16:24])
   199  			binary.LittleEndian.PutUint64(b[16:24], v^key64)
   200  			v = binary.LittleEndian.Uint64(b[24:32])
   201  			binary.LittleEndian.PutUint64(b[24:32], v^key64)
   202  			v = binary.LittleEndian.Uint64(b[32:40])
   203  			binary.LittleEndian.PutUint64(b[32:40], v^key64)
   204  			v = binary.LittleEndian.Uint64(b[40:48])
   205  			binary.LittleEndian.PutUint64(b[40:48], v^key64)
   206  			v = binary.LittleEndian.Uint64(b[48:56])
   207  			binary.LittleEndian.PutUint64(b[48:56], v^key64)
   208  			v = binary.LittleEndian.Uint64(b[56:64])
   209  			binary.LittleEndian.PutUint64(b[56:64], v^key64)
   210  			v = binary.LittleEndian.Uint64(b[64:72])
   211  			binary.LittleEndian.PutUint64(b[64:72], v^key64)
   212  			v = binary.LittleEndian.Uint64(b[72:80])
   213  			binary.LittleEndian.PutUint64(b[72:80], v^key64)
   214  			v = binary.LittleEndian.Uint64(b[80:88])
   215  			binary.LittleEndian.PutUint64(b[80:88], v^key64)
   216  			v = binary.LittleEndian.Uint64(b[88:96])
   217  			binary.LittleEndian.PutUint64(b[88:96], v^key64)
   218  			v = binary.LittleEndian.Uint64(b[96:104])
   219  			binary.LittleEndian.PutUint64(b[96:104], v^key64)
   220  			v = binary.LittleEndian.Uint64(b[104:112])
   221  			binary.LittleEndian.PutUint64(b[104:112], v^key64)
   222  			v = binary.LittleEndian.Uint64(b[112:120])
   223  			binary.LittleEndian.PutUint64(b[112:120], v^key64)
   224  			v = binary.LittleEndian.Uint64(b[120:128])
   225  			binary.LittleEndian.PutUint64(b[120:128], v^key64)
   226  			b = b[128:]
   227  		}
   228  
   229  		// Then we xor until b is less than 64 bytes.
   230  		for len(b) >= 64 {
   231  			v := binary.LittleEndian.Uint64(b)
   232  			binary.LittleEndian.PutUint64(b, v^key64)
   233  			v = binary.LittleEndian.Uint64(b[8:16])
   234  			binary.LittleEndian.PutUint64(b[8:16], v^key64)
   235  			v = binary.LittleEndian.Uint64(b[16:24])
   236  			binary.LittleEndian.PutUint64(b[16:24], v^key64)
   237  			v = binary.LittleEndian.Uint64(b[24:32])
   238  			binary.LittleEndian.PutUint64(b[24:32], v^key64)
   239  			v = binary.LittleEndian.Uint64(b[32:40])
   240  			binary.LittleEndian.PutUint64(b[32:40], v^key64)
   241  			v = binary.LittleEndian.Uint64(b[40:48])
   242  			binary.LittleEndian.PutUint64(b[40:48], v^key64)
   243  			v = binary.LittleEndian.Uint64(b[48:56])
   244  			binary.LittleEndian.PutUint64(b[48:56], v^key64)
   245  			v = binary.LittleEndian.Uint64(b[56:64])
   246  			binary.LittleEndian.PutUint64(b[56:64], v^key64)
   247  			b = b[64:]
   248  		}
   249  
   250  		// Then we xor until b is less than 32 bytes.
   251  		for len(b) >= 32 {
   252  			v := binary.LittleEndian.Uint64(b)
   253  			binary.LittleEndian.PutUint64(b, v^key64)
   254  			v = binary.LittleEndian.Uint64(b[8:16])
   255  			binary.LittleEndian.PutUint64(b[8:16], v^key64)
   256  			v = binary.LittleEndian.Uint64(b[16:24])
   257  			binary.LittleEndian.PutUint64(b[16:24], v^key64)
   258  			v = binary.LittleEndian.Uint64(b[24:32])
   259  			binary.LittleEndian.PutUint64(b[24:32], v^key64)
   260  			b = b[32:]
   261  		}
   262  
   263  		// Then we xor until b is less than 16 bytes.
   264  		for len(b) >= 16 {
   265  			v := binary.LittleEndian.Uint64(b)
   266  			binary.LittleEndian.PutUint64(b, v^key64)
   267  			v = binary.LittleEndian.Uint64(b[8:16])
   268  			binary.LittleEndian.PutUint64(b[8:16], v^key64)
   269  			b = b[16:]
   270  		}
   271  
   272  		// Then we xor until b is less than 8 bytes.
   273  		for len(b) >= 8 {
   274  			v := binary.LittleEndian.Uint64(b)
   275  			binary.LittleEndian.PutUint64(b, v^key64)
   276  			b = b[8:]
   277  		}
   278  	}
   279  
   280  	// Then we xor until b is less than 4 bytes.
   281  	for len(b) >= 4 {
   282  		v := binary.LittleEndian.Uint32(b)
   283  		binary.LittleEndian.PutUint32(b, v^key)
   284  		b = b[4:]
   285  	}
   286  
   287  	// xor remaining bytes.
   288  	for i := range b {
   289  		b[i] ^= byte(key)
   290  		key = bits.RotateLeft32(key, -8)
   291  	}
   292  
   293  	return key
   294  }