github.com/database64128/shadowsocks-go@v1.7.0/ss2022/tcp.go (about) 1 package ss2022 2 3 import ( 4 "bytes" 5 "crypto/cipher" 6 "crypto/rand" 7 "io" 8 9 "github.com/database64128/shadowsocks-go/conn" 10 "github.com/database64128/shadowsocks-go/magic" 11 "github.com/database64128/shadowsocks-go/socks5" 12 "github.com/database64128/shadowsocks-go/zerocopy" 13 "github.com/database64128/tfo-go/v2" 14 ) 15 16 // TCPClient implements the zerocopy TCPClient interface. 17 type TCPClient struct { 18 name string 19 rwo zerocopy.DirectReadWriteCloserOpener 20 cipherConfig *ClientCipherConfig 21 unsafeRequestStreamPrefix []byte 22 unsafeResponseStreamPrefix []byte 23 } 24 25 func NewTCPClient(name, address string, dialer tfo.Dialer, cipherConfig *ClientCipherConfig, unsafeRequestStreamPrefix, unsafeResponseStreamPrefix []byte) *TCPClient { 26 return &TCPClient{ 27 name: name, 28 rwo: zerocopy.NewTCPConnOpener(dialer, "tcp", address), 29 cipherConfig: cipherConfig, 30 unsafeRequestStreamPrefix: unsafeRequestStreamPrefix, 31 unsafeResponseStreamPrefix: unsafeResponseStreamPrefix, 32 } 33 } 34 35 // Info implements the zerocopy.TCPClient Info method. 36 func (c *TCPClient) Info() zerocopy.TCPClientInfo { 37 return zerocopy.TCPClientInfo{ 38 Name: c.name, 39 NativeInitialPayload: true, 40 } 41 } 42 43 // Dial implements the zerocopy.TCPClient Dial method. 44 func (c *TCPClient) Dial(targetAddr conn.Addr, payload []byte) (rawRW zerocopy.DirectReadWriteCloser, rw zerocopy.ReadWriter, err error) { 45 var ( 46 paddingPayloadLen int 47 excessPayload []byte 48 ) 49 50 targetAddrLen := socks5.LengthOfAddrFromConnAddr(targetAddr) 51 payloadLen := len(payload) 52 roomForPayload := MaxPayloadSize - targetAddrLen - 2 53 54 switch { 55 case payloadLen > roomForPayload: 56 paddingPayloadLen = roomForPayload 57 excessPayload = payload[roomForPayload:] 58 payload = payload[:roomForPayload] 59 case payloadLen >= MaxPaddingLength: 60 paddingPayloadLen = payloadLen 61 case payloadLen > 0: 62 paddingPayloadLen = payloadLen + int(magic.Fastrandn(MaxPaddingLength-uint32(payloadLen)+1)) 63 default: 64 paddingPayloadLen = 1 + int(magic.Fastrandn(MaxPaddingLength)) 65 } 66 67 urspLen := len(c.unsafeRequestStreamPrefix) 68 saltLen := len(c.cipherConfig.PSK) 69 eihPSKHashes := c.cipherConfig.EIHPSKHashes() 70 identityHeadersLen := IdentityHeaderLength * len(eihPSKHashes) 71 identityHeadersStart := urspLen + saltLen 72 fixedLengthHeaderStart := identityHeadersStart + identityHeadersLen 73 fixedLengthHeaderEnd := fixedLengthHeaderStart + TCPRequestFixedLengthHeaderLength 74 variableLengthHeaderStart := fixedLengthHeaderEnd + 16 75 variableLengthHeaderLen := targetAddrLen + 2 + paddingPayloadLen 76 variableLengthHeaderEnd := variableLengthHeaderStart + variableLengthHeaderLen 77 bufferLen := variableLengthHeaderEnd + 16 78 b := make([]byte, bufferLen) 79 ursp := b[:urspLen] 80 salt := b[urspLen:identityHeadersStart] 81 identityHeaders := b[identityHeadersStart:fixedLengthHeaderStart] 82 fixedLengthHeaderPlaintext := b[fixedLengthHeaderStart:fixedLengthHeaderEnd] 83 variableLengthHeaderPlaintext := b[variableLengthHeaderStart:variableLengthHeaderEnd] 84 85 // Write unsafe request stream prefix. 86 copy(ursp, c.unsafeRequestStreamPrefix) 87 88 // Random salt. 89 _, err = rand.Read(salt) 90 if err != nil { 91 return 92 } 93 94 // Write and encrypt identity headers. 95 eihCiphers, err := c.cipherConfig.TCPIdentityHeaderCiphers(salt) 96 if err != nil { 97 return 98 } 99 100 for i := range eihPSKHashes { 101 identityHeader := identityHeaders[i*IdentityHeaderLength : (i+1)*IdentityHeaderLength] 102 eihCiphers[i].Encrypt(identityHeader, eihPSKHashes[i][:]) 103 } 104 105 // Write variable-length header. 106 WriteTCPRequestVariableLengthHeader(variableLengthHeaderPlaintext, targetAddr, payload) 107 108 // Write fixed-length header. 109 WriteTCPRequestFixedLengthHeader(fixedLengthHeaderPlaintext, uint16(variableLengthHeaderLen)) 110 111 // Create AEAD cipher. 112 shadowStreamCipher, err := c.cipherConfig.ShadowStreamCipher(salt) 113 if err != nil { 114 return 115 } 116 117 // Seal fixed-length header. 118 shadowStreamCipher.EncryptInPlace(fixedLengthHeaderPlaintext) 119 120 // Seal variable-length header. 121 shadowStreamCipher.EncryptInPlace(variableLengthHeaderPlaintext) 122 123 // Write out. 124 rawRW, err = c.rwo.Open(b) 125 if err != nil { 126 return 127 } 128 129 w := ShadowStreamWriter{ 130 writer: rawRW, 131 ssc: shadowStreamCipher, 132 } 133 134 // Write excess payload, reusing the variable-length header buffer. 135 for len(excessPayload) > 0 { 136 n := copy(variableLengthHeaderPlaintext, excessPayload) 137 excessPayload = excessPayload[n:] 138 if _, err = w.WriteZeroCopy(b, variableLengthHeaderStart, n); err != nil { 139 rawRW.Close() 140 return 141 } 142 } 143 144 rw = &ShadowStreamClientReadWriter{ 145 ShadowStreamWriter: &w, 146 rawRW: rawRW, 147 cipherConfig: c.cipherConfig, 148 requestSalt: salt, 149 unsafeResponseStreamPrefix: c.unsafeResponseStreamPrefix, 150 } 151 152 return 153 } 154 155 // TCPServer implements the zerocopy TCPServer interface. 156 type TCPServer struct { 157 CredStore 158 saltPool *SaltPool[string] 159 userCipherConfig UserCipherConfig 160 identityCipherConfig ServerIdentityCipherConfig 161 unsafeRequestStreamPrefix []byte 162 unsafeResponseStreamPrefix []byte 163 } 164 165 func NewTCPServer(userCipherConfig UserCipherConfig, identityCipherConfig ServerIdentityCipherConfig, unsafeRequestStreamPrefix, unsafeResponseStreamPrefix []byte) *TCPServer { 166 return &TCPServer{ 167 saltPool: NewSaltPool[string](ReplayWindowDuration), 168 userCipherConfig: userCipherConfig, 169 identityCipherConfig: identityCipherConfig, 170 unsafeRequestStreamPrefix: unsafeRequestStreamPrefix, 171 unsafeResponseStreamPrefix: unsafeResponseStreamPrefix, 172 } 173 } 174 175 // Info implements the zerocopy.TCPServer Info method. 176 func (s *TCPServer) Info() zerocopy.TCPServerInfo { 177 return zerocopy.TCPServerInfo{ 178 NativeInitialPayload: true, 179 DefaultTCPConnCloser: zerocopy.ForceReset, 180 } 181 } 182 183 // Accept implements the zerocopy.TCPServer Accept method. 184 func (s *TCPServer) Accept(rawRW zerocopy.DirectReadWriteCloser) (rw zerocopy.ReadWriter, targetAddr conn.Addr, payload []byte, username string, err error) { 185 var identityHeaderLen int 186 userCipherConfig := s.userCipherConfig 187 saltLen := len(userCipherConfig.PSK) 188 if saltLen == 0 { 189 saltLen = len(s.identityCipherConfig.IPSK) 190 identityHeaderLen = IdentityHeaderLength 191 } 192 193 urspLen := len(s.unsafeRequestStreamPrefix) 194 identityHeaderStart := urspLen + saltLen 195 fixedLengthHeaderStart := identityHeaderStart + identityHeaderLen 196 bufferLen := fixedLengthHeaderStart + TCPRequestFixedLengthHeaderLength + 16 197 b := make([]byte, bufferLen) 198 199 // Read unsafe request stream prefix, salt, identity header, fixed-length header. 200 n, err := rawRW.Read(b) 201 if err != nil { 202 return 203 } 204 if n < bufferLen { 205 payload = b[:n] 206 err = &HeaderError[int]{ErrFirstRead, bufferLen, n} 207 return 208 } 209 210 ursp := b[:urspLen] 211 salt := b[urspLen:identityHeaderStart] 212 ciphertext := b[fixedLengthHeaderStart:] 213 214 s.Lock() 215 216 // Check but not add request salt to pool. 217 if !s.saltPool.Check(string(salt)) { // Is the compiler smart enough to not incur an allocation here? 218 s.Unlock() 219 payload = b[:n] 220 err = ErrRepeatedSalt 221 return 222 } 223 224 // Check unsafe request stream prefix. 225 if !bytes.Equal(ursp, s.unsafeRequestStreamPrefix) { 226 s.Unlock() 227 payload = b[:n] 228 err = &HeaderError[[]byte]{ErrUnsafeStreamPrefixMismatch, s.unsafeRequestStreamPrefix, ursp} 229 return 230 } 231 232 // Process identity header. 233 if identityHeaderLen != 0 { 234 var identityHeaderCipher cipher.Block 235 identityHeaderCipher, err = s.identityCipherConfig.TCP(salt) 236 if err != nil { 237 s.Unlock() 238 return 239 } 240 241 var uPSKHash [IdentityHeaderLength]byte 242 identityHeader := b[identityHeaderStart:fixedLengthHeaderStart] 243 identityHeaderCipher.Decrypt(uPSKHash[:], identityHeader) 244 245 serverUserCipherConfig := s.ulm[uPSKHash] 246 if serverUserCipherConfig == nil { 247 s.Unlock() 248 payload = b[:n] 249 err = ErrIdentityHeaderUserPSKNotFound 250 return 251 } 252 userCipherConfig = serverUserCipherConfig.UserCipherConfig 253 username = serverUserCipherConfig.Name 254 } 255 256 // Derive key and create cipher. 257 shadowStreamCipher, err := userCipherConfig.ShadowStreamCipher(salt) 258 if err != nil { 259 s.Unlock() 260 return 261 } 262 263 // AEAD open. 264 plaintext, err := shadowStreamCipher.DecryptTo(nil, ciphertext) 265 if err != nil { 266 s.Unlock() 267 payload = b[:n] 268 return 269 } 270 271 // Parse fixed-length header. 272 vhlen, err := ParseTCPRequestFixedLengthHeader(plaintext) 273 if err != nil { 274 s.Unlock() 275 return 276 } 277 278 // Add request salt to pool. 279 s.saltPool.Add(string(salt)) 280 281 s.Unlock() 282 283 b = make([]byte, vhlen+16) 284 285 // Read variable-length header. 286 _, err = io.ReadFull(rawRW, b) 287 if err != nil { 288 return 289 } 290 291 // AEAD open. 292 plaintext, err = shadowStreamCipher.DecryptInPlace(b) 293 if err != nil { 294 return 295 } 296 297 // Parse variable-length header. 298 targetAddr, payload, err = ParseTCPRequestVariableLengthHeader(plaintext) 299 if err != nil { 300 return 301 } 302 303 r := ShadowStreamReader{ 304 reader: rawRW, 305 ssc: shadowStreamCipher, 306 } 307 rw = &ShadowStreamServerReadWriter{ 308 ShadowStreamReader: &r, 309 rawRW: rawRW, 310 cipherConfig: userCipherConfig, 311 requestSalt: salt, 312 unsafeResponseStreamPrefix: s.unsafeResponseStreamPrefix, 313 } 314 return 315 }