github.com/database64128/shadowsocks-go@v1.10.2-0.20240315062903-143a773533f1/ss2022/tcp.go (about)

     1  package ss2022
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/cipher"
     7  	"crypto/rand"
     8  	"io"
     9  	mrand "math/rand/v2"
    10  
    11  	"github.com/database64128/shadowsocks-go/conn"
    12  	"github.com/database64128/shadowsocks-go/socks5"
    13  	"github.com/database64128/shadowsocks-go/zerocopy"
    14  )
    15  
    16  // TCPClient implements the zerocopy TCPClient interface.
    17  type TCPClient struct {
    18  	name                       string
    19  	rwo                        zerocopy.DirectReadWriteCloserOpener
    20  	readOnceOrFull             func(io.Reader, []byte) (int, error)
    21  	cipherConfig               *ClientCipherConfig
    22  	unsafeRequestStreamPrefix  []byte
    23  	unsafeResponseStreamPrefix []byte
    24  }
    25  
    26  func NewTCPClient(name, network, address string, dialer conn.Dialer, allowSegmentedFixedLengthHeader bool, cipherConfig *ClientCipherConfig, unsafeRequestStreamPrefix, unsafeResponseStreamPrefix []byte) *TCPClient {
    27  	return &TCPClient{
    28  		name:                       name,
    29  		rwo:                        zerocopy.NewTCPConnOpener(dialer, network, address),
    30  		readOnceOrFull:             readOnceOrFullFunc(allowSegmentedFixedLengthHeader),
    31  		cipherConfig:               cipherConfig,
    32  		unsafeRequestStreamPrefix:  unsafeRequestStreamPrefix,
    33  		unsafeResponseStreamPrefix: unsafeResponseStreamPrefix,
    34  	}
    35  }
    36  
    37  // Info implements the zerocopy.TCPClient Info method.
    38  func (c *TCPClient) Info() zerocopy.TCPClientInfo {
    39  	return zerocopy.TCPClientInfo{
    40  		Name:                 c.name,
    41  		NativeInitialPayload: true,
    42  	}
    43  }
    44  
    45  // Dial implements the zerocopy.TCPClient Dial method.
    46  func (c *TCPClient) Dial(ctx context.Context, targetAddr conn.Addr, payload []byte) (rawRW zerocopy.DirectReadWriteCloser, rw zerocopy.ReadWriter, err error) {
    47  	var (
    48  		paddingPayloadLen int
    49  		excessPayload     []byte
    50  	)
    51  
    52  	targetAddrLen := socks5.LengthOfAddrFromConnAddr(targetAddr)
    53  	payloadLen := len(payload)
    54  	roomForPayload := MaxPayloadSize - targetAddrLen - 2
    55  
    56  	switch {
    57  	case payloadLen > roomForPayload:
    58  		paddingPayloadLen = roomForPayload
    59  		excessPayload = payload[roomForPayload:]
    60  		payload = payload[:roomForPayload]
    61  	case payloadLen >= MaxPaddingLength:
    62  		paddingPayloadLen = payloadLen
    63  	case payloadLen > 0:
    64  		paddingPayloadLen = payloadLen + mrand.IntN(MaxPaddingLength-payloadLen+1)
    65  	default:
    66  		paddingPayloadLen = 1 + mrand.IntN(MaxPaddingLength)
    67  	}
    68  
    69  	urspLen := len(c.unsafeRequestStreamPrefix)
    70  	saltLen := len(c.cipherConfig.PSK)
    71  	eihPSKHashes := c.cipherConfig.EIHPSKHashes()
    72  	identityHeadersLen := IdentityHeaderLength * len(eihPSKHashes)
    73  	identityHeadersStart := urspLen + saltLen
    74  	fixedLengthHeaderStart := identityHeadersStart + identityHeadersLen
    75  	fixedLengthHeaderEnd := fixedLengthHeaderStart + TCPRequestFixedLengthHeaderLength
    76  	variableLengthHeaderStart := fixedLengthHeaderEnd + 16
    77  	variableLengthHeaderLen := targetAddrLen + 2 + paddingPayloadLen
    78  	variableLengthHeaderEnd := variableLengthHeaderStart + variableLengthHeaderLen
    79  	bufferLen := variableLengthHeaderEnd + 16
    80  	b := make([]byte, bufferLen)
    81  	ursp := b[:urspLen]
    82  	salt := b[urspLen:identityHeadersStart]
    83  	identityHeaders := b[identityHeadersStart:fixedLengthHeaderStart]
    84  	fixedLengthHeaderPlaintext := b[fixedLengthHeaderStart:fixedLengthHeaderEnd]
    85  	variableLengthHeaderPlaintext := b[variableLengthHeaderStart:variableLengthHeaderEnd]
    86  
    87  	// Write unsafe request stream prefix.
    88  	copy(ursp, c.unsafeRequestStreamPrefix)
    89  
    90  	// Random salt.
    91  	_, err = rand.Read(salt)
    92  	if err != nil {
    93  		return
    94  	}
    95  
    96  	// Write and encrypt identity headers.
    97  	eihCiphers, err := c.cipherConfig.TCPIdentityHeaderCiphers(salt)
    98  	if err != nil {
    99  		return
   100  	}
   101  
   102  	for i := range eihPSKHashes {
   103  		identityHeader := identityHeaders[i*IdentityHeaderLength : (i+1)*IdentityHeaderLength]
   104  		eihCiphers[i].Encrypt(identityHeader, eihPSKHashes[i][:])
   105  	}
   106  
   107  	// Write variable-length header.
   108  	WriteTCPRequestVariableLengthHeader(variableLengthHeaderPlaintext, targetAddr, payload)
   109  
   110  	// Write fixed-length header.
   111  	WriteTCPRequestFixedLengthHeader(fixedLengthHeaderPlaintext, uint16(variableLengthHeaderLen))
   112  
   113  	// Create AEAD cipher.
   114  	shadowStreamCipher, err := c.cipherConfig.ShadowStreamCipher(salt)
   115  	if err != nil {
   116  		return
   117  	}
   118  
   119  	// Seal fixed-length header.
   120  	shadowStreamCipher.EncryptInPlace(fixedLengthHeaderPlaintext)
   121  
   122  	// Seal variable-length header.
   123  	shadowStreamCipher.EncryptInPlace(variableLengthHeaderPlaintext)
   124  
   125  	// Write out.
   126  	rawRW, err = c.rwo.Open(ctx, b)
   127  	if err != nil {
   128  		return
   129  	}
   130  
   131  	w := ShadowStreamWriter{
   132  		writer: rawRW,
   133  		ssc:    shadowStreamCipher,
   134  	}
   135  
   136  	// Write excess payload, reusing the variable-length header buffer.
   137  	for len(excessPayload) > 0 {
   138  		n := copy(variableLengthHeaderPlaintext, excessPayload)
   139  		excessPayload = excessPayload[n:]
   140  		if _, err = w.WriteZeroCopy(b, variableLengthHeaderStart, n); err != nil {
   141  			rawRW.Close()
   142  			return
   143  		}
   144  	}
   145  
   146  	rw = &ShadowStreamClientReadWriter{
   147  		ShadowStreamWriter:         &w,
   148  		rawRW:                      rawRW,
   149  		readOnceOrFull:             c.readOnceOrFull,
   150  		cipherConfig:               c.cipherConfig,
   151  		requestSalt:                salt,
   152  		unsafeResponseStreamPrefix: c.unsafeResponseStreamPrefix,
   153  	}
   154  
   155  	return
   156  }
   157  
   158  // TCPServer implements the zerocopy TCPServer interface.
   159  type TCPServer struct {
   160  	CredStore
   161  	saltPool                   *SaltPool[string]
   162  	readOnceOrFull             func(io.Reader, []byte) (int, error)
   163  	userCipherConfig           UserCipherConfig
   164  	identityCipherConfig       ServerIdentityCipherConfig
   165  	unsafeRequestStreamPrefix  []byte
   166  	unsafeResponseStreamPrefix []byte
   167  }
   168  
   169  func NewTCPServer(allowSegmentedFixedLengthHeader bool, userCipherConfig UserCipherConfig, identityCipherConfig ServerIdentityCipherConfig, unsafeRequestStreamPrefix, unsafeResponseStreamPrefix []byte) *TCPServer {
   170  	return &TCPServer{
   171  		saltPool:                   NewSaltPool[string](ReplayWindowDuration),
   172  		readOnceOrFull:             readOnceOrFullFunc(allowSegmentedFixedLengthHeader),
   173  		userCipherConfig:           userCipherConfig,
   174  		identityCipherConfig:       identityCipherConfig,
   175  		unsafeRequestStreamPrefix:  unsafeRequestStreamPrefix,
   176  		unsafeResponseStreamPrefix: unsafeResponseStreamPrefix,
   177  	}
   178  }
   179  
   180  // Info implements the zerocopy.TCPServer Info method.
   181  func (s *TCPServer) Info() zerocopy.TCPServerInfo {
   182  	return zerocopy.TCPServerInfo{
   183  		NativeInitialPayload: true,
   184  		DefaultTCPConnCloser: zerocopy.ForceReset,
   185  	}
   186  }
   187  
   188  // Accept implements the zerocopy.TCPServer Accept method.
   189  func (s *TCPServer) Accept(rawRW zerocopy.DirectReadWriteCloser) (rw zerocopy.ReadWriter, targetAddr conn.Addr, payload []byte, username string, err error) {
   190  	var identityHeaderLen int
   191  	userCipherConfig := s.userCipherConfig
   192  	saltLen := len(userCipherConfig.PSK)
   193  	if saltLen == 0 {
   194  		saltLen = len(s.identityCipherConfig.IPSK)
   195  		identityHeaderLen = IdentityHeaderLength
   196  	}
   197  
   198  	urspLen := len(s.unsafeRequestStreamPrefix)
   199  	identityHeaderStart := urspLen + saltLen
   200  	fixedLengthHeaderStart := identityHeaderStart + identityHeaderLen
   201  	bufferLen := fixedLengthHeaderStart + TCPRequestFixedLengthHeaderLength + 16
   202  	b := make([]byte, bufferLen)
   203  
   204  	// Read unsafe request stream prefix, salt, identity header, fixed-length header.
   205  	n, err := s.readOnceOrFull(rawRW, b)
   206  	if err != nil {
   207  		payload = b[:n]
   208  		return
   209  	}
   210  
   211  	ursp := b[:urspLen]
   212  	salt := b[urspLen:identityHeaderStart]
   213  	ciphertext := b[fixedLengthHeaderStart:]
   214  
   215  	s.Lock()
   216  
   217  	// Check but not add request salt to pool.
   218  	if !s.saltPool.Check(string(salt)) { // Is the compiler smart enough to not incur an allocation here?
   219  		s.Unlock()
   220  		payload = b[:n]
   221  		err = ErrRepeatedSalt
   222  		return
   223  	}
   224  
   225  	// Check unsafe request stream prefix.
   226  	if !bytes.Equal(ursp, s.unsafeRequestStreamPrefix) {
   227  		s.Unlock()
   228  		payload = b[:n]
   229  		err = &HeaderError[[]byte]{ErrUnsafeStreamPrefixMismatch, s.unsafeRequestStreamPrefix, ursp}
   230  		return
   231  	}
   232  
   233  	// Process identity header.
   234  	if identityHeaderLen != 0 {
   235  		var identityHeaderCipher cipher.Block
   236  		identityHeaderCipher, err = s.identityCipherConfig.TCP(salt)
   237  		if err != nil {
   238  			s.Unlock()
   239  			return
   240  		}
   241  
   242  		var uPSKHash [IdentityHeaderLength]byte
   243  		identityHeader := b[identityHeaderStart:fixedLengthHeaderStart]
   244  		identityHeaderCipher.Decrypt(uPSKHash[:], identityHeader)
   245  
   246  		serverUserCipherConfig := s.ulm[uPSKHash]
   247  		if serverUserCipherConfig == nil {
   248  			s.Unlock()
   249  			payload = b[:n]
   250  			err = ErrIdentityHeaderUserPSKNotFound
   251  			return
   252  		}
   253  		userCipherConfig = serverUserCipherConfig.UserCipherConfig
   254  		username = serverUserCipherConfig.Name
   255  	}
   256  
   257  	// Derive key and create cipher.
   258  	shadowStreamCipher, err := userCipherConfig.ShadowStreamCipher(salt)
   259  	if err != nil {
   260  		s.Unlock()
   261  		return
   262  	}
   263  
   264  	// AEAD open.
   265  	plaintext, err := shadowStreamCipher.DecryptTo(nil, ciphertext)
   266  	if err != nil {
   267  		s.Unlock()
   268  		payload = b[:n]
   269  		return
   270  	}
   271  
   272  	// Parse fixed-length header.
   273  	vhlen, err := ParseTCPRequestFixedLengthHeader(plaintext)
   274  	if err != nil {
   275  		s.Unlock()
   276  		return
   277  	}
   278  
   279  	// Add request salt to pool.
   280  	s.saltPool.Add(string(salt))
   281  
   282  	s.Unlock()
   283  
   284  	b = make([]byte, vhlen+16)
   285  
   286  	// Read variable-length header.
   287  	_, err = io.ReadFull(rawRW, b)
   288  	if err != nil {
   289  		return
   290  	}
   291  
   292  	// AEAD open.
   293  	plaintext, err = shadowStreamCipher.DecryptInPlace(b)
   294  	if err != nil {
   295  		return
   296  	}
   297  
   298  	// Parse variable-length header.
   299  	targetAddr, payload, err = ParseTCPRequestVariableLengthHeader(plaintext)
   300  	if err != nil {
   301  		return
   302  	}
   303  
   304  	r := ShadowStreamReader{
   305  		reader: rawRW,
   306  		ssc:    shadowStreamCipher,
   307  	}
   308  	rw = &ShadowStreamServerReadWriter{
   309  		ShadowStreamReader:         &r,
   310  		rawRW:                      rawRW,
   311  		cipherConfig:               userCipherConfig,
   312  		requestSalt:                salt,
   313  		unsafeResponseStreamPrefix: s.unsafeResponseStreamPrefix,
   314  	}
   315  	return
   316  }