github.com/database64128/shadowsocks-go@v1.10.2-0.20240315062903-143a773533f1/ss2022/header.go (about)

     1  package ss2022
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"errors"
     7  	"fmt"
     8  	"net/netip"
     9  	"strconv"
    10  	"time"
    11  
    12  	"github.com/database64128/shadowsocks-go/conn"
    13  	"github.com/database64128/shadowsocks-go/socks5"
    14  )
    15  
    16  const (
    17  	HeaderTypeClientStream = 0
    18  	HeaderTypeServerStream = 1
    19  
    20  	HeaderTypeClientPacket = 0
    21  	HeaderTypeServerPacket = 1
    22  
    23  	MinPaddingLength = 0
    24  	MaxPaddingLength = 900
    25  
    26  	IdentityHeaderLength = 16
    27  
    28  	// type + unix epoch timestamp + u16be length
    29  	TCPRequestFixedLengthHeaderLength = 1 + 8 + 2
    30  
    31  	// SOCKS address + padding length + padding
    32  	TCPRequestVariableLengthHeaderNoPayloadMaxLength = socks5.MaxAddrLen + 2 + MaxPaddingLength
    33  
    34  	// type + unix epoch timestamp + request salt + u16be length
    35  	TCPResponseHeaderMaxLength = 1 + 8 + 32 + 2
    36  
    37  	// session ID + packet ID
    38  	UDPSeparateHeaderLength = 8 + 8
    39  
    40  	// type + unix epoch timestamp + padding length
    41  	UDPClientMessageHeaderFixedLength = 1 + 8 + 2
    42  
    43  	// type + unix epoch timestamp + client session id + padding length
    44  	UDPServerMessageHeaderFixedLength = 1 + 8 + 8 + 2
    45  
    46  	// type + unix epoch timestamp + padding length + padding + SOCKS address
    47  	UDPClientMessageHeaderMaxLength = UDPClientMessageHeaderFixedLength + MaxPaddingLength + socks5.MaxAddrLen
    48  
    49  	// type + unix epoch timestamp + client session id + padding length + padding + SOCKS address
    50  	UDPServerMessageHeaderMaxLength = UDPServerMessageHeaderFixedLength + MaxPaddingLength + socks5.IPv6AddrLen
    51  
    52  	// MaxEpochDiff is the maximum allowed time difference between a received timestamp and system time.
    53  	MaxEpochDiff = 30
    54  
    55  	// MaxTimeDiff is the maximum allowed time difference between a received timestamp and system time.
    56  	MaxTimeDiff = MaxEpochDiff * time.Second
    57  
    58  	// ReplayWindowDuration defines the amount of time during which a salt check is necessary.
    59  	ReplayWindowDuration = MaxTimeDiff * 2
    60  
    61  	// DefaultSlidingWindowFilterSize is the default size of the sliding window filter.
    62  	DefaultSlidingWindowFilterSize = 256
    63  )
    64  
    // Sentinel errors reported while parsing and validating stream (TCP) headers.
var (
	ErrIncompleteHeaderInFirstChunk = errors.New("header in first chunk is missing or incomplete")
	ErrPaddingExceedChunkBorder     = errors.New("padding in first chunk is shorter than advertised")
	ErrClientSaltMismatch           = errors.New("client salt in response header does not match request")
)

// Sentinel errors shared by stream and packet header validation.
var (
	ErrBadTimestamp = errors.New("time diff is over 30 seconds")
	ErrTypeMismatch = errors.New("header type mismatch")
	ErrReplay       = errors.New("detected replay")
)

// Sentinel errors reported while handling packet (UDP) headers and sessions.
var (
	ErrClientSessionIDMismatch = errors.New("client session ID in server message header does not match current session")
	ErrTooManyServerSessions   = errors.New("server session changed more than once during the last minute")
	ErrPacketIncompleteHeader  = errors.New("packet contains incomplete header")
)

// ErrIdentityHeaderUserPSKNotFound reports that a decrypted identity header
// matched no known user PSK.
var ErrIdentityHeaderUserPSKNotFound = errors.New("decrypted identity header does not match any known uPSK")
    77  
    // HeaderError reports a header field whose value differs from the expected one.
//
// Err is the sentinel describing which check failed; Expected and Got carry the
// expected and actual values of the field.
type HeaderError[T any] struct {
	Err      error
	Expected T
	Got      T
}

// Unwrap returns the sentinel error, so errors.Is and errors.As can see through a HeaderError.
func (he *HeaderError[T]) Unwrap() error { return he.Err }

// Error renders the error as "<sentinel message>: expected <Expected>, got <Got>",
// formatting both values with their default (%v) representation.
func (he *HeaderError[T]) Error() string {
	return he.Err.Error() + ": expected " + fmt.Sprint(he.Expected) + ", got " + fmt.Sprint(he.Got)
}
    91  
    92  // ValidateUnixEpochTimestamp validates the Unix Epoch timestamp in the buffer
    93  // and returns an error if the timestamp exceeds the allowed time difference from system time.
    94  //
    95  // This function does not check buffer length. Make sure it's exactly 8 bytes long.
    96  func ValidateUnixEpochTimestamp(b []byte) error {
    97  	tsEpoch := int64(binary.BigEndian.Uint64(b))
    98  	nowEpoch := time.Now().Unix()
    99  	diff := tsEpoch - nowEpoch
   100  	if diff < -MaxEpochDiff || diff > MaxEpochDiff {
   101  		return &HeaderError[int64]{ErrBadTimestamp, nowEpoch, tsEpoch}
   102  	}
   103  	return nil
   104  }
   105  
   // intToUint16 converts i to a uint16 for writing u16be header length fields.
// It panics if i does not fit, i.e. if i is negative or greater than 65535.
func intToUint16(i int) uint16 {
	if u := uint16(i); int(u) == i {
		return u
	}
	panic("int -> uint16 overflowed: " + strconv.Itoa(i))
}
   113  
   114  // ParseTCPRequestFixedLengthHeader parses a TCP request fixed-length header and returns the length
   115  // of the variable-length header, or an error if header validation fails.
   116  //
   117  // The buffer must be exactly 11 bytes long. No buffer length checks are performed.
   118  //
   119  // Request fixed-length header:
   120  //
   121  //	+------+---------------+--------+
   122  //	| type |   timestamp   | length |
   123  //	+------+---------------+--------+
   124  //	|  1B  | 8B unix epoch |  u16be |
   125  //	+------+---------------+--------+
   126  func ParseTCPRequestFixedLengthHeader(b []byte) (n int, err error) {
   127  	// Type
   128  	if b[0] != HeaderTypeClientStream {
   129  		err = &HeaderError[byte]{ErrTypeMismatch, HeaderTypeClientStream, b[0]}
   130  		return
   131  	}
   132  
   133  	// Timestamp
   134  	err = ValidateUnixEpochTimestamp(b[1:])
   135  	if err != nil {
   136  		return
   137  	}
   138  
   139  	// Length
   140  	n = int(binary.BigEndian.Uint16(b[1+8:]))
   141  
   142  	return
   143  }
   144  
   145  // WriteTCPRequestFixedLengthHeader writes a TCP request fixed-length header into the buffer.
   146  //
   147  // The buffer must be at least 11 bytes long. No buffer length checks are performed.
   148  func WriteTCPRequestFixedLengthHeader(b []byte, length uint16) {
   149  	// Type
   150  	b[0] = HeaderTypeClientStream
   151  
   152  	// Timestamp
   153  	binary.BigEndian.PutUint64(b[1:], uint64(time.Now().Unix()))
   154  
   155  	// Length
   156  	binary.BigEndian.PutUint16(b[1+8:], length)
   157  }
   158  
   159  // ParseTCPRequestVariableLengthHeader parses a TCP request variable-length header and returns
   160  // the target address, the initial payload if available, or an error if header validation fails.
   161  //
   162  // This function does buffer length checks and returns ErrIncompleteHeaderInFirstChunk if the buffer is too short.
   163  //
   164  // Request variable-length header:
   165  //
   166  //	+------+----------+-------+----------------+----------+-----------------+
   167  //	| ATYP |  address |  port | padding length |  padding | initial payload |
   168  //	+------+----------+-------+----------------+----------+-----------------+
   169  //	|  1B  | variable | u16be |     u16be      | variable |    variable     |
   170  //	+------+----------+-------+----------------+----------+-----------------+
   171  func ParseTCPRequestVariableLengthHeader(b []byte) (targetAddr conn.Addr, payload []byte, err error) {
   172  	// SOCKS address
   173  	targetAddr, n, err := socks5.ConnAddrFromSlice(b)
   174  	if err != nil {
   175  		return
   176  	}
   177  	b = b[n:]
   178  
   179  	// Make sure the remaining length > 2 (padding length + either padding or payload)
   180  	if len(b) <= 2 {
   181  		err = ErrIncompleteHeaderInFirstChunk
   182  		return
   183  	}
   184  
   185  	// Padding length
   186  	paddingLen := int(binary.BigEndian.Uint16(b))
   187  
   188  	// Padding
   189  	if 2+paddingLen > len(b) {
   190  		err = &HeaderError[int]{ErrPaddingExceedChunkBorder, len(b), 2 + paddingLen}
   191  		return
   192  	}
   193  
   194  	// Initial payload
   195  	payload = b[2+paddingLen:]
   196  
   197  	return
   198  }
   199  
   200  // WriteTCPRequestVariableLengthHeader writes a TCP request variable-length header into the buffer.
   201  //
   202  // The header fills the whole buffer. Excess bytes are used as padding.
   203  //
   204  // The buffer size can be calculated with:
   205  //
   206  //	socks5.LengthOfAddrFromConnAddr(targetAddr) + 2 + len(payload) + paddingLen
   207  //
   208  // The buffer size must not exceed [MaxPayloadSize].
   209  // The excess space in the buffer must not be larger than [MaxPaddingLength] bytes.
   210  func WriteTCPRequestVariableLengthHeader(b []byte, targetAddr conn.Addr, payload []byte) {
   211  	// SOCKS address
   212  	n := socks5.WriteAddrFromConnAddr(b, targetAddr)
   213  
   214  	// Padding length
   215  	paddingLen := len(b) - n - 2 - len(payload)
   216  	binary.BigEndian.PutUint16(b[n:], intToUint16(paddingLen))
   217  	n += 2 + paddingLen
   218  
   219  	// Initial payload
   220  	copy(b[n:], payload)
   221  }
   222  
   223  // ParseTCPResponseHeader parses a TCP response fixed-length header and returns the length
   224  // of the next payload chunk, or an error if header validation fails.
   225  //
   226  // The buffer must be exactly 1 + 8 + salt length + 2 bytes long. No buffer length checks are performed.
   227  //
   228  // Response fixed-length header:
   229  //
   230  //	+------+---------------+----------------+--------+
   231  //	| type |   timestamp   |  request salt  | length |
   232  //	+------+---------------+----------------+--------+
   233  //	|  1B  | 8B unix epoch |     16/32B     |  u16be |
   234  //	+------+---------------+----------------+--------+
   235  func ParseTCPResponseHeader(b []byte, requestSalt []byte) (n int, err error) {
   236  	// Type
   237  	if b[0] != HeaderTypeServerStream {
   238  		err = &HeaderError[byte]{ErrTypeMismatch, HeaderTypeServerStream, b[0]}
   239  		return
   240  	}
   241  
   242  	// Timestamp
   243  	err = ValidateUnixEpochTimestamp(b[1 : 1+8])
   244  	if err != nil {
   245  		return
   246  	}
   247  
   248  	// Request salt
   249  	rSalt := b[1+8 : 1+8+len(requestSalt)]
   250  	if !bytes.Equal(requestSalt, rSalt) {
   251  		err = &HeaderError[[]byte]{ErrClientSaltMismatch, requestSalt, rSalt}
   252  		return
   253  	}
   254  
   255  	// Length
   256  	n = int(binary.BigEndian.Uint16(b[1+8+len(requestSalt):]))
   257  
   258  	return
   259  }
   260  
   261  // WriteTCPResponseHeader writes a TCP response fixed-length header into the buffer.
   262  //
   263  // The buffer size must be exactly 1 + 8 + len(requestSalt) + 2 bytes.
   264  func WriteTCPResponseHeader(b []byte, requestSalt []byte, length uint16) {
   265  	// Type
   266  	b[0] = HeaderTypeServerStream
   267  
   268  	// Timestamp
   269  	binary.BigEndian.PutUint64(b[1:], uint64(time.Now().Unix()))
   270  
   271  	// Request salt
   272  	copy(b[1+8:], requestSalt)
   273  
   274  	// Length
   275  	binary.BigEndian.PutUint16(b[1+8+len(requestSalt):], length)
   276  }
   277  
   // ParseSessionIDAndPacketID decodes the session ID and packet ID segment of a
// decrypted UDP packet. Both fields are read as big-endian uint64 values.
//
// b must hold at least 16 bytes, of which only the first 16 are read.
// No buffer length checks are performed.
//
// Session ID and packet ID segment:
//
//	+------------+-----------+
//	| session ID | packet ID |
//	+------------+-----------+
//	|     8B     |   u64be   |
//	+------------+-----------+
func ParseSessionIDAndPacketID(b []byte) (sid, pid uint64) {
	be := binary.BigEndian
	return be.Uint64(b), be.Uint64(b[8:])
}
   294  
   // WriteSessionIDAndPacketID encodes sid followed by pid as consecutive
// big-endian uint64 values at the start of b.
//
// b must be at least 16 bytes long; only its first 16 bytes are written.
// No buffer length checks are performed.
func WriteSessionIDAndPacketID(b []byte, sid, pid uint64) {
	for i, v := range [2]uint64{sid, pid} {
		binary.BigEndian.PutUint64(b[8*i:], v)
	}
}
   302  
   303  // ParseUDPClientMessageHeader parses a UDP client message header and returns the target address
   304  // and payload, or an error if header validation fails or no payload is in the buffer.
   305  //
   306  // This function accepts buffers of arbitrary lengths.
   307  //
   308  // The buffer is expected to contain a decrypted client message in the following format:
   309  //
   310  //	+------+---------------+----------------+----------+------+----------+-------+----------+
   311  //	| type |   timestamp   | padding length |  padding | ATYP |  address |  port |  payload |
   312  //	+------+---------------+----------------+----------+------+----------+-------+----------+
   313  //	|  1B  | 8B unix epoch |     u16be      | variable |  1B  | variable | u16be | variable |
   314  //	+------+---------------+----------------+----------+------+----------+-------+----------+
   315  func ParseUDPClientMessageHeader(b []byte, cachedDomain string) (targetAddr conn.Addr, updatedCachedDomain string, payloadStart, payloadLen int, err error) {
   316  	updatedCachedDomain = cachedDomain
   317  
   318  	// Make sure buffer has type + timestamp + padding length.
   319  	if len(b) < UDPClientMessageHeaderFixedLength {
   320  		err = ErrPacketIncompleteHeader
   321  		return
   322  	}
   323  
   324  	// Type
   325  	if b[0] != HeaderTypeClientPacket {
   326  		err = &HeaderError[byte]{ErrTypeMismatch, HeaderTypeClientPacket, b[0]}
   327  		return
   328  	}
   329  
   330  	// Timestamp
   331  	err = ValidateUnixEpochTimestamp(b[1 : 1+8])
   332  	if err != nil {
   333  		return
   334  	}
   335  
   336  	// Padding length
   337  	paddingLen := int(binary.BigEndian.Uint16(b[1+8:]))
   338  
   339  	// Padding
   340  	payloadStart = UDPClientMessageHeaderFixedLength + paddingLen
   341  	if payloadStart > len(b) {
   342  		err = ErrPacketIncompleteHeader
   343  		return
   344  	}
   345  
   346  	// SOCKS address
   347  	var n int
   348  	targetAddr, n, updatedCachedDomain, err = socks5.ConnAddrFromSliceWithDomainCache(b[payloadStart:], cachedDomain)
   349  	if err != nil {
   350  		return
   351  	}
   352  
   353  	// Payload
   354  	payloadStart += n
   355  	payloadLen = len(b) - payloadStart
   356  	return
   357  }
   358  
   359  // WriteUDPClientMessageHeader writes a UDP client message header into the buffer.
   360  //
   361  // The buffer size must be exactly 1 + 8 + 2 + paddingLen + socks5.LengthOfAddrFromConnAddr(targetAddr) bytes.
   362  func WriteUDPClientMessageHeader(b []byte, paddingLen int, targetAddr conn.Addr) {
   363  	// Type
   364  	b[0] = HeaderTypeClientPacket
   365  
   366  	// Timestamp
   367  	binary.BigEndian.PutUint64(b[1:], uint64(time.Now().Unix()))
   368  
   369  	// Padding length
   370  	binary.BigEndian.PutUint16(b[1+8:], intToUint16(paddingLen))
   371  
   372  	// SOCKS address
   373  	socks5.WriteAddrFromConnAddr(b[1+8+2+paddingLen:], targetAddr)
   374  }
   375  
   376  // ParseUDPServerMessageHeader parses a UDP server message header and returns the payload source address
   377  // and payload, or an error if header validation fails or no payload is in the buffer.
   378  //
   379  // This function accepts buffers of arbitrary lengths.
   380  //
   381  // The buffer is expected to contain a decrypted server message in the following format:
   382  //
   383  //	+------+---------------+-------------------+----------------+----------+------+----------+-------+----------+
   384  //	| type |   timestamp   | client session ID | padding length |  padding | ATYP |  address |  port |  payload |
   385  //	+------+---------------+-------------------+----------------+----------+------+----------+-------+----------+
   386  //	|  1B  | 8B unix epoch |         8B        |     u16be      | variable |  1B  | variable | u16be | variable |
   387  //	+------+---------------+-------------------+----------------+----------+------+----------+-------+----------+
   388  func ParseUDPServerMessageHeader(b []byte, csid uint64) (payloadSourceAddrPort netip.AddrPort, payloadStart, payloadLen int, err error) {
   389  	// Make sure buffer has type + timestamp + client session ID + padding length.
   390  	if len(b) < UDPServerMessageHeaderFixedLength {
   391  		err = ErrPacketIncompleteHeader
   392  		return
   393  	}
   394  
   395  	// Type
   396  	if b[0] != HeaderTypeServerPacket {
   397  		err = &HeaderError[byte]{ErrTypeMismatch, HeaderTypeServerPacket, b[0]}
   398  		return
   399  	}
   400  
   401  	// Timestamp
   402  	err = ValidateUnixEpochTimestamp(b[1 : 1+8])
   403  	if err != nil {
   404  		return
   405  	}
   406  
   407  	// Client session ID
   408  	pcsid := binary.BigEndian.Uint64(b[1+8:])
   409  	if pcsid != csid {
   410  		err = &HeaderError[uint64]{ErrClientSessionIDMismatch, csid, pcsid}
   411  		return
   412  	}
   413  
   414  	// Padding length
   415  	paddingLen := int(binary.BigEndian.Uint16(b[1+8+8:]))
   416  
   417  	// Padding
   418  	payloadStart = UDPServerMessageHeaderFixedLength + paddingLen
   419  	if payloadStart > len(b) {
   420  		err = ErrPacketIncompleteHeader
   421  		return
   422  	}
   423  
   424  	// SOCKS address
   425  	payloadSourceAddrPort, n, err := socks5.AddrPortFromSlice(b[payloadStart:])
   426  	if err != nil {
   427  		return
   428  	}
   429  
   430  	// Payload
   431  	payloadStart += n
   432  	payloadLen = len(b) - payloadStart
   433  	return
   434  }
   435  
   436  // WriteUDPServerMessageHeader writes a UDP server message header into the buffer.
   437  //
   438  // The buffer size must be exactly 1 + 8 + 8 + 2 + paddingLen + socks5.LengthOfAddrFromAddrPort(sourceAddrPort) bytes.
   439  func WriteUDPServerMessageHeader(b []byte, csid uint64, paddingLen int, sourceAddrPort netip.AddrPort) {
   440  	// Type
   441  	b[0] = HeaderTypeServerPacket
   442  
   443  	// Timestamp
   444  	binary.BigEndian.PutUint64(b[1:], uint64(time.Now().Unix()))
   445  
   446  	// Client session ID
   447  	binary.BigEndian.PutUint64(b[1+8:], csid)
   448  
   449  	// Padding length
   450  	binary.BigEndian.PutUint16(b[1+8+8:], intToUint16(paddingLen))
   451  
   452  	// SOCKS address
   453  	socks5.WriteAddrFromAddrPort(b[1+8+8+2+paddingLen:], sourceAddrPort)
   454  }