github.com/database64128/shadowsocks-go@v1.10.2-0.20240315062903-143a773533f1/ss2022/tcp.go (about) 1 package ss2022 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/cipher" 7 "crypto/rand" 8 "io" 9 mrand "math/rand/v2" 10 11 "github.com/database64128/shadowsocks-go/conn" 12 "github.com/database64128/shadowsocks-go/socks5" 13 "github.com/database64128/shadowsocks-go/zerocopy" 14 ) 15 16 // TCPClient implements the zerocopy TCPClient interface. 17 type TCPClient struct { 18 name string 19 rwo zerocopy.DirectReadWriteCloserOpener 20 readOnceOrFull func(io.Reader, []byte) (int, error) 21 cipherConfig *ClientCipherConfig 22 unsafeRequestStreamPrefix []byte 23 unsafeResponseStreamPrefix []byte 24 } 25 26 func NewTCPClient(name, network, address string, dialer conn.Dialer, allowSegmentedFixedLengthHeader bool, cipherConfig *ClientCipherConfig, unsafeRequestStreamPrefix, unsafeResponseStreamPrefix []byte) *TCPClient { 27 return &TCPClient{ 28 name: name, 29 rwo: zerocopy.NewTCPConnOpener(dialer, network, address), 30 readOnceOrFull: readOnceOrFullFunc(allowSegmentedFixedLengthHeader), 31 cipherConfig: cipherConfig, 32 unsafeRequestStreamPrefix: unsafeRequestStreamPrefix, 33 unsafeResponseStreamPrefix: unsafeResponseStreamPrefix, 34 } 35 } 36 37 // Info implements the zerocopy.TCPClient Info method. 38 func (c *TCPClient) Info() zerocopy.TCPClientInfo { 39 return zerocopy.TCPClientInfo{ 40 Name: c.name, 41 NativeInitialPayload: true, 42 } 43 } 44 45 // Dial implements the zerocopy.TCPClient Dial method. 46 func (c *TCPClient) Dial(ctx context.Context, targetAddr conn.Addr, payload []byte) (rawRW zerocopy.DirectReadWriteCloser, rw zerocopy.ReadWriter, err error) { 47 var ( 48 paddingPayloadLen int 49 excessPayload []byte 50 ) 51 52 targetAddrLen := socks5.LengthOfAddrFromConnAddr(targetAddr) 53 payloadLen := len(payload) 54 roomForPayload := MaxPayloadSize - targetAddrLen - 2 55 56 switch { 57 case payloadLen > roomForPayload: 58 paddingPayloadLen = roomForPayload 59 excessPayload = payload[roomForPayload:] 60 payload = payload[:roomForPayload] 61 case payloadLen >= MaxPaddingLength: 62 paddingPayloadLen = payloadLen 63 case payloadLen > 0: 64 paddingPayloadLen = payloadLen + mrand.IntN(MaxPaddingLength-payloadLen+1) 65 default: 66 paddingPayloadLen = 1 + mrand.IntN(MaxPaddingLength) 67 } 68 69 urspLen := len(c.unsafeRequestStreamPrefix) 70 saltLen := len(c.cipherConfig.PSK) 71 eihPSKHashes := c.cipherConfig.EIHPSKHashes() 72 identityHeadersLen := IdentityHeaderLength * len(eihPSKHashes) 73 identityHeadersStart := urspLen + saltLen 74 fixedLengthHeaderStart := identityHeadersStart + identityHeadersLen 75 fixedLengthHeaderEnd := fixedLengthHeaderStart + TCPRequestFixedLengthHeaderLength 76 variableLengthHeaderStart := fixedLengthHeaderEnd + 16 77 variableLengthHeaderLen := targetAddrLen + 2 + paddingPayloadLen 78 variableLengthHeaderEnd := variableLengthHeaderStart + variableLengthHeaderLen 79 bufferLen := variableLengthHeaderEnd + 16 80 b := make([]byte, bufferLen) 81 ursp := b[:urspLen] 82 salt := b[urspLen:identityHeadersStart] 83 identityHeaders := b[identityHeadersStart:fixedLengthHeaderStart] 84 fixedLengthHeaderPlaintext := b[fixedLengthHeaderStart:fixedLengthHeaderEnd] 85 variableLengthHeaderPlaintext := b[variableLengthHeaderStart:variableLengthHeaderEnd] 86 87 // Write unsafe request stream prefix. 88 copy(ursp, c.unsafeRequestStreamPrefix) 89 90 // Random salt. 91 _, err = rand.Read(salt) 92 if err != nil { 93 return 94 } 95 96 // Write and encrypt identity headers. 97 eihCiphers, err := c.cipherConfig.TCPIdentityHeaderCiphers(salt) 98 if err != nil { 99 return 100 } 101 102 for i := range eihPSKHashes { 103 identityHeader := identityHeaders[i*IdentityHeaderLength : (i+1)*IdentityHeaderLength] 104 eihCiphers[i].Encrypt(identityHeader, eihPSKHashes[i][:]) 105 } 106 107 // Write variable-length header. 108 WriteTCPRequestVariableLengthHeader(variableLengthHeaderPlaintext, targetAddr, payload) 109 110 // Write fixed-length header. 111 WriteTCPRequestFixedLengthHeader(fixedLengthHeaderPlaintext, uint16(variableLengthHeaderLen)) 112 113 // Create AEAD cipher. 114 shadowStreamCipher, err := c.cipherConfig.ShadowStreamCipher(salt) 115 if err != nil { 116 return 117 } 118 119 // Seal fixed-length header. 120 shadowStreamCipher.EncryptInPlace(fixedLengthHeaderPlaintext) 121 122 // Seal variable-length header. 123 shadowStreamCipher.EncryptInPlace(variableLengthHeaderPlaintext) 124 125 // Write out. 126 rawRW, err = c.rwo.Open(ctx, b) 127 if err != nil { 128 return 129 } 130 131 w := ShadowStreamWriter{ 132 writer: rawRW, 133 ssc: shadowStreamCipher, 134 } 135 136 // Write excess payload, reusing the variable-length header buffer. 137 for len(excessPayload) > 0 { 138 n := copy(variableLengthHeaderPlaintext, excessPayload) 139 excessPayload = excessPayload[n:] 140 if _, err = w.WriteZeroCopy(b, variableLengthHeaderStart, n); err != nil { 141 rawRW.Close() 142 return 143 } 144 } 145 146 rw = &ShadowStreamClientReadWriter{ 147 ShadowStreamWriter: &w, 148 rawRW: rawRW, 149 readOnceOrFull: c.readOnceOrFull, 150 cipherConfig: c.cipherConfig, 151 requestSalt: salt, 152 unsafeResponseStreamPrefix: c.unsafeResponseStreamPrefix, 153 } 154 155 return 156 } 157 158 // TCPServer implements the zerocopy TCPServer interface. 159 type TCPServer struct { 160 CredStore 161 saltPool *SaltPool[string] 162 readOnceOrFull func(io.Reader, []byte) (int, error) 163 userCipherConfig UserCipherConfig 164 identityCipherConfig ServerIdentityCipherConfig 165 unsafeRequestStreamPrefix []byte 166 unsafeResponseStreamPrefix []byte 167 } 168 169 func NewTCPServer(allowSegmentedFixedLengthHeader bool, userCipherConfig UserCipherConfig, identityCipherConfig ServerIdentityCipherConfig, unsafeRequestStreamPrefix, unsafeResponseStreamPrefix []byte) *TCPServer { 170 return &TCPServer{ 171 saltPool: NewSaltPool[string](ReplayWindowDuration), 172 readOnceOrFull: readOnceOrFullFunc(allowSegmentedFixedLengthHeader), 173 userCipherConfig: userCipherConfig, 174 identityCipherConfig: identityCipherConfig, 175 unsafeRequestStreamPrefix: unsafeRequestStreamPrefix, 176 unsafeResponseStreamPrefix: unsafeResponseStreamPrefix, 177 } 178 } 179 180 // Info implements the zerocopy.TCPServer Info method. 181 func (s *TCPServer) Info() zerocopy.TCPServerInfo { 182 return zerocopy.TCPServerInfo{ 183 NativeInitialPayload: true, 184 DefaultTCPConnCloser: zerocopy.ForceReset, 185 } 186 } 187 188 // Accept implements the zerocopy.TCPServer Accept method. 189 func (s *TCPServer) Accept(rawRW zerocopy.DirectReadWriteCloser) (rw zerocopy.ReadWriter, targetAddr conn.Addr, payload []byte, username string, err error) { 190 var identityHeaderLen int 191 userCipherConfig := s.userCipherConfig 192 saltLen := len(userCipherConfig.PSK) 193 if saltLen == 0 { 194 saltLen = len(s.identityCipherConfig.IPSK) 195 identityHeaderLen = IdentityHeaderLength 196 } 197 198 urspLen := len(s.unsafeRequestStreamPrefix) 199 identityHeaderStart := urspLen + saltLen 200 fixedLengthHeaderStart := identityHeaderStart + identityHeaderLen 201 bufferLen := fixedLengthHeaderStart + TCPRequestFixedLengthHeaderLength + 16 202 b := make([]byte, bufferLen) 203 204 // Read unsafe request stream prefix, salt, identity header, fixed-length header. 205 n, err := s.readOnceOrFull(rawRW, b) 206 if err != nil { 207 payload = b[:n] 208 return 209 } 210 211 ursp := b[:urspLen] 212 salt := b[urspLen:identityHeaderStart] 213 ciphertext := b[fixedLengthHeaderStart:] 214 215 s.Lock() 216 217 // Check but not add request salt to pool. 218 if !s.saltPool.Check(string(salt)) { // Is the compiler smart enough to not incur an allocation here? 219 s.Unlock() 220 payload = b[:n] 221 err = ErrRepeatedSalt 222 return 223 } 224 225 // Check unsafe request stream prefix. 226 if !bytes.Equal(ursp, s.unsafeRequestStreamPrefix) { 227 s.Unlock() 228 payload = b[:n] 229 err = &HeaderError[[]byte]{ErrUnsafeStreamPrefixMismatch, s.unsafeRequestStreamPrefix, ursp} 230 return 231 } 232 233 // Process identity header. 234 if identityHeaderLen != 0 { 235 var identityHeaderCipher cipher.Block 236 identityHeaderCipher, err = s.identityCipherConfig.TCP(salt) 237 if err != nil { 238 s.Unlock() 239 return 240 } 241 242 var uPSKHash [IdentityHeaderLength]byte 243 identityHeader := b[identityHeaderStart:fixedLengthHeaderStart] 244 identityHeaderCipher.Decrypt(uPSKHash[:], identityHeader) 245 246 serverUserCipherConfig := s.ulm[uPSKHash] 247 if serverUserCipherConfig == nil { 248 s.Unlock() 249 payload = b[:n] 250 err = ErrIdentityHeaderUserPSKNotFound 251 return 252 } 253 userCipherConfig = serverUserCipherConfig.UserCipherConfig 254 username = serverUserCipherConfig.Name 255 } 256 257 // Derive key and create cipher. 258 shadowStreamCipher, err := userCipherConfig.ShadowStreamCipher(salt) 259 if err != nil { 260 s.Unlock() 261 return 262 } 263 264 // AEAD open. 265 plaintext, err := shadowStreamCipher.DecryptTo(nil, ciphertext) 266 if err != nil { 267 s.Unlock() 268 payload = b[:n] 269 return 270 } 271 272 // Parse fixed-length header. 273 vhlen, err := ParseTCPRequestFixedLengthHeader(plaintext) 274 if err != nil { 275 s.Unlock() 276 return 277 } 278 279 // Add request salt to pool. 280 s.saltPool.Add(string(salt)) 281 282 s.Unlock() 283 284 b = make([]byte, vhlen+16) 285 286 // Read variable-length header. 287 _, err = io.ReadFull(rawRW, b) 288 if err != nil { 289 return 290 } 291 292 // AEAD open. 293 plaintext, err = shadowStreamCipher.DecryptInPlace(b) 294 if err != nil { 295 return 296 } 297 298 // Parse variable-length header. 299 targetAddr, payload, err = ParseTCPRequestVariableLengthHeader(plaintext) 300 if err != nil { 301 return 302 } 303 304 r := ShadowStreamReader{ 305 reader: rawRW, 306 ssc: shadowStreamCipher, 307 } 308 rw = &ShadowStreamServerReadWriter{ 309 ShadowStreamReader: &r, 310 rawRW: rawRW, 311 cipherConfig: userCipherConfig, 312 requestSalt: salt, 313 unsafeResponseStreamPrefix: s.unsafeResponseStreamPrefix, 314 } 315 return 316 }