github.com/database64128/shadowsocks-go@v1.7.0/ss2022/tcp.go (about)

     1  package ss2022
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/cipher"
     6  	"crypto/rand"
     7  	"io"
     8  
     9  	"github.com/database64128/shadowsocks-go/conn"
    10  	"github.com/database64128/shadowsocks-go/magic"
    11  	"github.com/database64128/shadowsocks-go/socks5"
    12  	"github.com/database64128/shadowsocks-go/zerocopy"
    13  	"github.com/database64128/tfo-go/v2"
    14  )
    15  
    16  // TCPClient implements the zerocopy TCPClient interface.
    17  type TCPClient struct {
    18  	name                       string
    19  	rwo                        zerocopy.DirectReadWriteCloserOpener
    20  	cipherConfig               *ClientCipherConfig
    21  	unsafeRequestStreamPrefix  []byte
    22  	unsafeResponseStreamPrefix []byte
    23  }
    24  
    25  func NewTCPClient(name, address string, dialer tfo.Dialer, cipherConfig *ClientCipherConfig, unsafeRequestStreamPrefix, unsafeResponseStreamPrefix []byte) *TCPClient {
    26  	return &TCPClient{
    27  		name:                       name,
    28  		rwo:                        zerocopy.NewTCPConnOpener(dialer, "tcp", address),
    29  		cipherConfig:               cipherConfig,
    30  		unsafeRequestStreamPrefix:  unsafeRequestStreamPrefix,
    31  		unsafeResponseStreamPrefix: unsafeResponseStreamPrefix,
    32  	}
    33  }
    34  
    35  // Info implements the zerocopy.TCPClient Info method.
    36  func (c *TCPClient) Info() zerocopy.TCPClientInfo {
    37  	return zerocopy.TCPClientInfo{
    38  		Name:                 c.name,
    39  		NativeInitialPayload: true,
    40  	}
    41  }
    42  
// Dial implements the zerocopy.TCPClient Dial method.
//
// It builds the whole request header in a single buffer b, laid out as:
//
//	unsafe request stream prefix | salt | identity headers |
//	fixed-length header | 16 bytes |
//	variable-length header | 16 bytes
//
// The 16-byte gaps are presumably room for the AEAD tags that EncryptInPlace
// appends (TODO confirm). Both headers are sealed with the shadow stream
// cipher derived from the salt. Then b is passed to c.rwo.Open as the initial
// data of the new connection.
//
// As much of payload as fits is carried inside the variable-length header.
// Any remainder is written afterwards through a ShadowStreamWriter.
//
// On success it returns the raw connection and a ShadowStreamClientReadWriter
// over it. If writing the remainder fails, rawRW is closed (its Close error is
// ignored) and the write error is returned; rawRW is still set in that case.
func (c *TCPClient) Dial(targetAddr conn.Addr, payload []byte) (rawRW zerocopy.DirectReadWriteCloser, rw zerocopy.ReadWriter, err error) {
	// paddingPayloadLen is the combined length of padding and payload carried
	// in the variable-length header. excessPayload is the part of payload that
	// does not fit there and is written after the connection is open.
	var (
		paddingPayloadLen int
		excessPayload     []byte
	)

	targetAddrLen := socks5.LengthOfAddrFromConnAddr(targetAddr)
	payloadLen := len(payload)
	// Room left for padding+payload when the variable-length header
	// (target address + 2 bytes + padding+payload) is capped at MaxPayloadSize.
	// NOTE(review): the 2 bytes are presumably the padding length field.
	roomForPayload := MaxPayloadSize - targetAddrLen - 2

	// Decide how much padding to add. The ranges below assume
	// magic.Fastrandn(n) returns a value in [0, n) - TODO confirm.
	switch {
	case payloadLen > roomForPayload:
		// Too large: fill the header with payload only (no padding) and keep
		// the rest for later.
		paddingPayloadLen = roomForPayload
		excessPayload = payload[roomForPayload:]
		payload = payload[:roomForPayload]
	case payloadLen >= MaxPaddingLength:
		// Already at least MaxPaddingLength bytes: no padding.
		paddingPayloadLen = payloadLen
	case payloadLen > 0:
		// Pad up to a random total length in [payloadLen, MaxPaddingLength].
		paddingPayloadLen = payloadLen + int(magic.Fastrandn(MaxPaddingLength-uint32(payloadLen)+1))
	default:
		// No payload: between 1 and MaxPaddingLength bytes of padding.
		paddingPayloadLen = 1 + int(magic.Fastrandn(MaxPaddingLength))
	}

	// Offsets of each region of b (see the layout in the doc comment).
	// The salt is as long as the PSK.
	urspLen := len(c.unsafeRequestStreamPrefix)
	saltLen := len(c.cipherConfig.PSK)
	eihPSKHashes := c.cipherConfig.EIHPSKHashes()
	identityHeadersLen := IdentityHeaderLength * len(eihPSKHashes)
	identityHeadersStart := urspLen + saltLen
	fixedLengthHeaderStart := identityHeadersStart + identityHeadersLen
	fixedLengthHeaderEnd := fixedLengthHeaderStart + TCPRequestFixedLengthHeaderLength
	variableLengthHeaderStart := fixedLengthHeaderEnd + 16
	variableLengthHeaderLen := targetAddrLen + 2 + paddingPayloadLen
	variableLengthHeaderEnd := variableLengthHeaderStart + variableLengthHeaderLen
	bufferLen := variableLengthHeaderEnd + 16
	b := make([]byte, bufferLen)
	ursp := b[:urspLen]
	salt := b[urspLen:identityHeadersStart]
	identityHeaders := b[identityHeadersStart:fixedLengthHeaderStart]
	fixedLengthHeaderPlaintext := b[fixedLengthHeaderStart:fixedLengthHeaderEnd]
	variableLengthHeaderPlaintext := b[variableLengthHeaderStart:variableLengthHeaderEnd]

	// Write unsafe request stream prefix.
	copy(ursp, c.unsafeRequestStreamPrefix)

	// Random salt.
	_, err = rand.Read(salt)
	if err != nil {
		return
	}

	// Write and encrypt identity headers.
	// Header i is eihPSKHashes[i] encrypted with eihCiphers[i], which is
	// derived from the salt.
	eihCiphers, err := c.cipherConfig.TCPIdentityHeaderCiphers(salt)
	if err != nil {
		return
	}

	for i := range eihPSKHashes {
		identityHeader := identityHeaders[i*IdentityHeaderLength : (i+1)*IdentityHeaderLength]
		eihCiphers[i].Encrypt(identityHeader, eihPSKHashes[i][:])
	}

	// Write variable-length header.
	WriteTCPRequestVariableLengthHeader(variableLengthHeaderPlaintext, targetAddr, payload)

	// Write fixed-length header.
	// It records the variable-length header's plaintext length. That length
	// fits in uint16 because it is capped at MaxPayloadSize above (assuming
	// MaxPayloadSize <= 65535).
	WriteTCPRequestFixedLengthHeader(fixedLengthHeaderPlaintext, uint16(variableLengthHeaderLen))

	// Create AEAD cipher.
	shadowStreamCipher, err := c.cipherConfig.ShadowStreamCipher(salt)
	if err != nil {
		return
	}

	// Seal fixed-length header.
	// Keep this before sealing the variable-length header. Both use the same
	// cipher, and Accept opens them in this same order. NOTE(review): the
	// order presumably fixes the nonce sequence.
	shadowStreamCipher.EncryptInPlace(fixedLengthHeaderPlaintext)

	// Seal variable-length header.
	shadowStreamCipher.EncryptInPlace(variableLengthHeaderPlaintext)

	// Write out.
	// The whole sealed request header is the initial data of the connection.
	rawRW, err = c.rwo.Open(b)
	if err != nil {
		return
	}

	w := ShadowStreamWriter{
		writer: rawRW,
		ssc:    shadowStreamCipher,
	}

	// Write excess payload, reusing the variable-length header buffer.
	// Each chunk holds at most variableLengthHeaderLen bytes at
	// variableLengthHeaderStart. All of b is passed, presumably so that
	// WriteZeroCopy can use the bytes around the chunk as headroom.
	// NOTE(review): confirm WriteZeroCopy's headroom stays within
	// b[fixedLengthHeaderStart:], since salt below still aliases b.
	for len(excessPayload) > 0 {
		n := copy(variableLengthHeaderPlaintext, excessPayload)
		excessPayload = excessPayload[n:]
		if _, err = w.WriteZeroCopy(b, variableLengthHeaderStart, n); err != nil {
			rawRW.Close()
			return
		}
	}

	// requestSalt aliases b[urspLen:identityHeadersStart].
	rw = &ShadowStreamClientReadWriter{
		ShadowStreamWriter:         &w,
		rawRW:                      rawRW,
		cipherConfig:               c.cipherConfig,
		requestSalt:                salt,
		unsafeResponseStreamPrefix: c.unsafeResponseStreamPrefix,
	}

	return
}
   154  
   155  // TCPServer implements the zerocopy TCPServer interface.
   156  type TCPServer struct {
   157  	CredStore
   158  	saltPool                   *SaltPool[string]
   159  	userCipherConfig           UserCipherConfig
   160  	identityCipherConfig       ServerIdentityCipherConfig
   161  	unsafeRequestStreamPrefix  []byte
   162  	unsafeResponseStreamPrefix []byte
   163  }
   164  
   165  func NewTCPServer(userCipherConfig UserCipherConfig, identityCipherConfig ServerIdentityCipherConfig, unsafeRequestStreamPrefix, unsafeResponseStreamPrefix []byte) *TCPServer {
   166  	return &TCPServer{
   167  		saltPool:                   NewSaltPool[string](ReplayWindowDuration),
   168  		userCipherConfig:           userCipherConfig,
   169  		identityCipherConfig:       identityCipherConfig,
   170  		unsafeRequestStreamPrefix:  unsafeRequestStreamPrefix,
   171  		unsafeResponseStreamPrefix: unsafeResponseStreamPrefix,
   172  	}
   173  }
   174  
   175  // Info implements the zerocopy.TCPServer Info method.
   176  func (s *TCPServer) Info() zerocopy.TCPServerInfo {
   177  	return zerocopy.TCPServerInfo{
   178  		NativeInitialPayload: true,
   179  		DefaultTCPConnCloser: zerocopy.ForceReset,
   180  	}
   181  }
   182  
// Accept implements the zerocopy.TCPServer Accept method.
//
// It performs one Read that must return at least the following, in order:
//   - the unsafe request stream prefix,
//   - the salt,
//   - the identity header (only when userCipherConfig has no PSK),
//   - the sealed fixed-length header.
//
// After validating those, it reads and opens the variable-length header. It
// returns the parsed target address and initial payload, plus a
// ShadowStreamServerReadWriter over rawRW.
//
// username is set only when the user was identified through the identity
// header. In these failure cases, payload is set to the bytes returned by the
// first Read alongside the error:
//   - a short first read,
//   - a repeated salt,
//   - a request prefix mismatch,
//   - an unknown user PSK hash,
//   - a failure to open the fixed-length header.
func (s *TCPServer) Accept(rawRW zerocopy.DirectReadWriteCloser) (rw zerocopy.ReadWriter, targetAddr conn.Addr, payload []byte, username string, err error) {
	var identityHeaderLen int
	userCipherConfig := s.userCipherConfig
	saltLen := len(userCipherConfig.PSK)
	// With no user PSK configured, the salt length follows the identity PSK
	// and exactly one identity header is expected to select the user.
	if saltLen == 0 {
		saltLen = len(s.identityCipherConfig.IPSK)
		identityHeaderLen = IdentityHeaderLength
	}

	// Layout of b:
	// unsafe request stream prefix | salt | identity header (optional) |
	// fixed-length header | 16 bytes (presumably the AEAD tag).
	urspLen := len(s.unsafeRequestStreamPrefix)
	identityHeaderStart := urspLen + saltLen
	fixedLengthHeaderStart := identityHeaderStart + identityHeaderLen
	bufferLen := fixedLengthHeaderStart + TCPRequestFixedLengthHeaderLength + 16
	b := make([]byte, bufferLen)

	// Read unsafe request stream prefix, salt, identity header, fixed-length header.
	// A single Read is used; receiving fewer than bufferLen bytes is an error.
	// NOTE(review): a Read that returns n > 0 together with a non-nil err
	// discards those bytes here.
	n, err := rawRW.Read(b)
	if err != nil {
		return
	}
	if n < bufferLen {
		payload = b[:n]
		err = &HeaderError[int]{ErrFirstRead, bufferLen, n}
		return
	}

	ursp := b[:urspLen]
	salt := b[urspLen:identityHeaderStart]
	ciphertext := b[fixedLengthHeaderStart:]

	// The lock is held across the salt pool check, the ulm lookup and the
	// salt pool add. Every early return below must Unlock first.
	// NOTE(review): Lock is presumably promoted from the embedded CredStore.
	s.Lock()

	// Check but not add request salt to pool.
	// NOTE(review): string(salt) is passed to a method rather than used as a
	// direct map index, so it may allocate; confirm with -gcflags=-m.
	if !s.saltPool.Check(string(salt)) { // Is the compiler smart enough to not incur an allocation here?
		s.Unlock()
		payload = b[:n]
		err = ErrRepeatedSalt
		return
	}

	// Check unsafe request stream prefix.
	if !bytes.Equal(ursp, s.unsafeRequestStreamPrefix) {
		s.Unlock()
		payload = b[:n]
		err = &HeaderError[[]byte]{ErrUnsafeStreamPrefixMismatch, s.unsafeRequestStreamPrefix, ursp}
		return
	}

	// Process identity header.
	// Decrypting it with the block cipher derived from the salt yields the
	// user PSK hash, which selects the user's cipher config and name from s.ulm.
	if identityHeaderLen != 0 {
		var identityHeaderCipher cipher.Block
		identityHeaderCipher, err = s.identityCipherConfig.TCP(salt)
		if err != nil {
			s.Unlock()
			return
		}

		var uPSKHash [IdentityHeaderLength]byte
		identityHeader := b[identityHeaderStart:fixedLengthHeaderStart]
		identityHeaderCipher.Decrypt(uPSKHash[:], identityHeader)

		serverUserCipherConfig := s.ulm[uPSKHash]
		if serverUserCipherConfig == nil {
			s.Unlock()
			payload = b[:n]
			err = ErrIdentityHeaderUserPSKNotFound
			return
		}
		userCipherConfig = serverUserCipherConfig.UserCipherConfig
		username = serverUserCipherConfig.Name
	}

	// Derive key and create cipher.
	shadowStreamCipher, err := userCipherConfig.ShadowStreamCipher(salt)
	if err != nil {
		s.Unlock()
		return
	}

	// AEAD open.
	// The plaintext goes into a new buffer (dst nil) instead of being
	// decrypted in place, presumably so that b[:n] stays unmodified for the
	// payload returned on failure.
	plaintext, err := shadowStreamCipher.DecryptTo(nil, ciphertext)
	if err != nil {
		s.Unlock()
		payload = b[:n]
		return
	}

	// Parse fixed-length header.
	// vhlen is the plaintext length of the variable-length header.
	vhlen, err := ParseTCPRequestFixedLengthHeader(plaintext)
	if err != nil {
		s.Unlock()
		return
	}

	// Add request salt to pool.
	// This happens only after the fixed-length header has been opened and
	// parsed, so rejected requests never insert salts into the pool.
	s.saltPool.Add(string(salt))

	// Release the lock before the blocking read of the variable-length header.
	s.Unlock()

	// Room for the variable-length header plus 16 bytes (presumably its tag).
	// salt keeps referencing the first buffer.
	b = make([]byte, vhlen+16)

	// Read variable-length header.
	_, err = io.ReadFull(rawRW, b)
	if err != nil {
		return
	}

	// AEAD open.
	plaintext, err = shadowStreamCipher.DecryptInPlace(b)
	if err != nil {
		return
	}

	// Parse variable-length header.
	targetAddr, payload, err = ParseTCPRequestVariableLengthHeader(plaintext)
	if err != nil {
		return
	}

	r := ShadowStreamReader{
		reader: rawRW,
		ssc:    shadowStreamCipher,
	}
	rw = &ShadowStreamServerReadWriter{
		ShadowStreamReader:         &r,
		rawRW:                      rawRW,
		cipherConfig:               userCipherConfig,
		requestSalt:                salt,
		unsafeResponseStreamPrefix: s.unsafeResponseStreamPrefix,
	}
	return
}