github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/websocket/x/frame.go (about)

     1  package websocket
     2  
     3  import (
     4  	"encoding/binary"
     5  	"fmt"
     6  	"io"
     7  	"math"
     8  	"math/bits"
     9  
    10  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    11  	"github.com/Asutorufa/yuhaiin/pkg/net/proxy/yuubinsya/types"
    12  )
    13  
    14  // opcode represents a WebSocket opcode.
    15  type opcode byte
    16  
    17  // https://tools.ietf.org/html/rfc6455#section-11.8.
    18  const (
    19  	opContinuation opcode = iota
    20  	opText
    21  	opBinary
    22  	// 3 - 7 are reserved for further non-control frames.
    23  	_
    24  	_
    25  	_
    26  	_
    27  	_
    28  	opClose
    29  	opPing
    30  	opPong
    31  	// 11-16 are reserved for further control frames.
    32  )
    33  
    34  // Header represents a WebSocket frame Header.
    35  // See https://tools.ietf.org/html/rfc6455#section-5.2.
    36  type Header struct {
    37  	fin    bool
    38  	rsv1   bool
    39  	rsv2   bool
    40  	rsv3   bool
    41  	opcode opcode
    42  
    43  	payloadLength int64
    44  
    45  	masked  bool
    46  	maskKey uint32
    47  }
    48  
    49  // readFrameHeader reads a header from the reader.
    50  // See https://tools.ietf.org/html/rfc6455#section-5.2.
    51  func readFrameHeader(r *netapi.Reader, readBuf []byte) (h Header, err error) {
    52  	b, err := r.ReadByte()
    53  	if err != nil {
    54  		return Header{}, err
    55  	}
    56  
    57  	h.fin = b&(1<<7) != 0
    58  	h.rsv1 = b&(1<<6) != 0
    59  	h.rsv2 = b&(1<<5) != 0
    60  	h.rsv3 = b&(1<<4) != 0
    61  
    62  	h.opcode = opcode(b & 0xf)
    63  
    64  	b, err = r.ReadByte()
    65  	if err != nil {
    66  		return Header{}, err
    67  	}
    68  
    69  	h.masked = b&(1<<7) != 0
    70  
    71  	payloadLength := b &^ (1 << 7)
    72  	switch {
    73  	case payloadLength < 126:
    74  		h.payloadLength = int64(payloadLength)
    75  	case payloadLength == 126:
    76  		_, err = io.ReadFull(r, readBuf[:2])
    77  		h.payloadLength = int64(binary.BigEndian.Uint16(readBuf))
    78  	case payloadLength == 127:
    79  		_, err = io.ReadFull(r, readBuf)
    80  		h.payloadLength = int64(binary.BigEndian.Uint64(readBuf))
    81  	}
    82  	if err != nil {
    83  		return Header{}, err
    84  	}
    85  
    86  	if h.payloadLength < 0 {
    87  		return Header{}, fmt.Errorf("received negative payload length: %v", h.payloadLength)
    88  	}
    89  
    90  	if h.masked {
    91  		_, err = io.ReadFull(r, readBuf[:4])
    92  		if err != nil {
    93  			return Header{}, err
    94  		}
    95  		h.maskKey = binary.LittleEndian.Uint32(readBuf)
    96  	}
    97  
    98  	return h, nil
    99  }
   100  
   101  // maxControlPayload is the maximum length of a control frame payload.
   102  // See https://tools.ietf.org/html/rfc6455#section-5.5.
   103  const maxControlPayload = 125
   104  
   105  // writeFrameHeader writes the bytes of the header to w.
   106  // See https://tools.ietf.org/html/rfc6455#section-5.2
   107  func writeFrameHeader(h Header, w types.Buffer, buf []byte) (err error) {
   108  	var b byte
   109  	if h.fin {
   110  		b |= 1 << 7
   111  	}
   112  	if h.rsv1 {
   113  		b |= 1 << 6
   114  	}
   115  	if h.rsv2 {
   116  		b |= 1 << 5
   117  	}
   118  	if h.rsv3 {
   119  		b |= 1 << 4
   120  	}
   121  
   122  	b |= byte(h.opcode)
   123  
   124  	err = w.WriteByte(b)
   125  	if err != nil {
   126  		return err
   127  	}
   128  
   129  	lengthByte := byte(0)
   130  	if h.masked {
   131  		lengthByte |= 1 << 7
   132  	}
   133  
   134  	switch {
   135  	case h.payloadLength > math.MaxUint16:
   136  		lengthByte |= 127
   137  	case h.payloadLength > 125:
   138  		lengthByte |= 126
   139  	case h.payloadLength >= 0:
   140  		lengthByte |= byte(h.payloadLength)
   141  	}
   142  	err = w.WriteByte(lengthByte)
   143  	if err != nil {
   144  		return err
   145  	}
   146  
   147  	switch {
   148  	case h.payloadLength > math.MaxUint16:
   149  		binary.BigEndian.PutUint64(buf, uint64(h.payloadLength))
   150  		_, err = w.Write(buf)
   151  	case h.payloadLength > 125:
   152  		binary.BigEndian.PutUint16(buf, uint16(h.payloadLength))
   153  		_, err = w.Write(buf[:2])
   154  	}
   155  	if err != nil {
   156  		return err
   157  	}
   158  
   159  	if h.masked {
   160  		binary.LittleEndian.PutUint32(buf, h.maskKey)
   161  		_, err = w.Write(buf[:4])
   162  		if err != nil {
   163  			return err
   164  		}
   165  	}
   166  
   167  	return nil
   168  }
   169  
   170  // mask applies the WebSocket masking algorithm to p
   171  // with the given key.
   172  // See https://tools.ietf.org/html/rfc6455#section-5.3
   173  //
   174  // The returned value is the correctly rotated key to
   175  // to continue to mask/unmask the message.
   176  //
   177  // It is optimized for LittleEndian and expects the key
   178  // to be in little endian.
   179  //
   180  // See https://github.com/golang/go/issues/31586
   181  func mask(key uint32, b []byte) uint32 {
   182  	if len(b) >= 8 {
   183  		key64 := uint64(key)<<32 | uint64(key)
   184  
   185  		// At some point in the future we can clean these unrolled loops up.
   186  		// See https://github.com/golang/go/issues/31586#issuecomment-487436401
   187  
   188  		// Then we xor until b is less than 128 bytes.
   189  		for len(b) >= 128 {
   190  			v := binary.LittleEndian.Uint64(b)
   191  			binary.LittleEndian.PutUint64(b, v^key64)
   192  			v = binary.LittleEndian.Uint64(b[8:16])
   193  			binary.LittleEndian.PutUint64(b[8:16], v^key64)
   194  			v = binary.LittleEndian.Uint64(b[16:24])
   195  			binary.LittleEndian.PutUint64(b[16:24], v^key64)
   196  			v = binary.LittleEndian.Uint64(b[24:32])
   197  			binary.LittleEndian.PutUint64(b[24:32], v^key64)
   198  			v = binary.LittleEndian.Uint64(b[32:40])
   199  			binary.LittleEndian.PutUint64(b[32:40], v^key64)
   200  			v = binary.LittleEndian.Uint64(b[40:48])
   201  			binary.LittleEndian.PutUint64(b[40:48], v^key64)
   202  			v = binary.LittleEndian.Uint64(b[48:56])
   203  			binary.LittleEndian.PutUint64(b[48:56], v^key64)
   204  			v = binary.LittleEndian.Uint64(b[56:64])
   205  			binary.LittleEndian.PutUint64(b[56:64], v^key64)
   206  			v = binary.LittleEndian.Uint64(b[64:72])
   207  			binary.LittleEndian.PutUint64(b[64:72], v^key64)
   208  			v = binary.LittleEndian.Uint64(b[72:80])
   209  			binary.LittleEndian.PutUint64(b[72:80], v^key64)
   210  			v = binary.LittleEndian.Uint64(b[80:88])
   211  			binary.LittleEndian.PutUint64(b[80:88], v^key64)
   212  			v = binary.LittleEndian.Uint64(b[88:96])
   213  			binary.LittleEndian.PutUint64(b[88:96], v^key64)
   214  			v = binary.LittleEndian.Uint64(b[96:104])
   215  			binary.LittleEndian.PutUint64(b[96:104], v^key64)
   216  			v = binary.LittleEndian.Uint64(b[104:112])
   217  			binary.LittleEndian.PutUint64(b[104:112], v^key64)
   218  			v = binary.LittleEndian.Uint64(b[112:120])
   219  			binary.LittleEndian.PutUint64(b[112:120], v^key64)
   220  			v = binary.LittleEndian.Uint64(b[120:128])
   221  			binary.LittleEndian.PutUint64(b[120:128], v^key64)
   222  			b = b[128:]
   223  		}
   224  
   225  		// Then we xor until b is less than 64 bytes.
   226  		for len(b) >= 64 {
   227  			v := binary.LittleEndian.Uint64(b)
   228  			binary.LittleEndian.PutUint64(b, v^key64)
   229  			v = binary.LittleEndian.Uint64(b[8:16])
   230  			binary.LittleEndian.PutUint64(b[8:16], v^key64)
   231  			v = binary.LittleEndian.Uint64(b[16:24])
   232  			binary.LittleEndian.PutUint64(b[16:24], v^key64)
   233  			v = binary.LittleEndian.Uint64(b[24:32])
   234  			binary.LittleEndian.PutUint64(b[24:32], v^key64)
   235  			v = binary.LittleEndian.Uint64(b[32:40])
   236  			binary.LittleEndian.PutUint64(b[32:40], v^key64)
   237  			v = binary.LittleEndian.Uint64(b[40:48])
   238  			binary.LittleEndian.PutUint64(b[40:48], v^key64)
   239  			v = binary.LittleEndian.Uint64(b[48:56])
   240  			binary.LittleEndian.PutUint64(b[48:56], v^key64)
   241  			v = binary.LittleEndian.Uint64(b[56:64])
   242  			binary.LittleEndian.PutUint64(b[56:64], v^key64)
   243  			b = b[64:]
   244  		}
   245  
   246  		// Then we xor until b is less than 32 bytes.
   247  		for len(b) >= 32 {
   248  			v := binary.LittleEndian.Uint64(b)
   249  			binary.LittleEndian.PutUint64(b, v^key64)
   250  			v = binary.LittleEndian.Uint64(b[8:16])
   251  			binary.LittleEndian.PutUint64(b[8:16], v^key64)
   252  			v = binary.LittleEndian.Uint64(b[16:24])
   253  			binary.LittleEndian.PutUint64(b[16:24], v^key64)
   254  			v = binary.LittleEndian.Uint64(b[24:32])
   255  			binary.LittleEndian.PutUint64(b[24:32], v^key64)
   256  			b = b[32:]
   257  		}
   258  
   259  		// Then we xor until b is less than 16 bytes.
   260  		for len(b) >= 16 {
   261  			v := binary.LittleEndian.Uint64(b)
   262  			binary.LittleEndian.PutUint64(b, v^key64)
   263  			v = binary.LittleEndian.Uint64(b[8:16])
   264  			binary.LittleEndian.PutUint64(b[8:16], v^key64)
   265  			b = b[16:]
   266  		}
   267  
   268  		// Then we xor until b is less than 8 bytes.
   269  		for len(b) >= 8 {
   270  			v := binary.LittleEndian.Uint64(b)
   271  			binary.LittleEndian.PutUint64(b, v^key64)
   272  			b = b[8:]
   273  		}
   274  	}
   275  
   276  	// Then we xor until b is less than 4 bytes.
   277  	for len(b) >= 4 {
   278  		v := binary.LittleEndian.Uint32(b)
   279  		binary.LittleEndian.PutUint32(b, v^key)
   280  		b = b[4:]
   281  	}
   282  
   283  	// xor remaining bytes.
   284  	for i := range b {
   285  		b[i] ^= byte(key)
   286  		key = bits.RotateLeft32(key, -8)
   287  	}
   288  
   289  	return key
   290  }