github.com/database64128/shadowsocks-go@v1.10.2-0.20240315062903-143a773533f1/ss2022/header_test.go (about) 1 package ss2022 2 3 import ( 4 "bytes" 5 "crypto/rand" 6 "encoding/binary" 7 "errors" 8 "math" 9 mrand "math/rand/v2" 10 "net/netip" 11 "testing" 12 "time" 13 14 "github.com/database64128/shadowsocks-go/conn" 15 "github.com/database64128/shadowsocks-go/socks5" 16 ) 17 18 func TestHeaderErrorString(t *testing.T) { 19 const errMsg = "time diff is over 30 seconds: expected 1, got 2" 20 err := HeaderError[int]{ErrBadTimestamp, 1, 2} 21 if err.Error() != errMsg { 22 t.FailNow() 23 } 24 } 25 26 func TestWriteAndParseTCPRequestFixedLengthHeader(t *testing.T) { 27 b := make([]byte, TCPRequestFixedLengthHeaderLength) 28 length := int(mrand.Uint64() & math.MaxUint16) 29 30 // 1. Good header 31 WriteTCPRequestFixedLengthHeader(b, uint16(length)) 32 33 n, err := ParseTCPRequestFixedLengthHeader(b) 34 if err != nil { 35 t.Fatal(err) 36 } 37 if n != length { 38 t.Fatalf("Expected: %d\nGot: %d", length, n) 39 } 40 41 // 2. Bad timestamp (31s ago) 42 ts := time.Now().Add(-31 * time.Second) 43 binary.BigEndian.PutUint64(b[1:], uint64(ts.Unix())) 44 45 _, err = ParseTCPRequestFixedLengthHeader(b) 46 if !errors.Is(err, ErrBadTimestamp) { 47 t.Fatalf("Expected: %s\nGot: %s", ErrBadTimestamp, err) 48 } 49 50 // 3. Bad timestamp (31s later) 51 ts = time.Now().Add(31 * time.Second) 52 binary.BigEndian.PutUint64(b[1:], uint64(ts.Unix())) 53 54 _, err = ParseTCPRequestFixedLengthHeader(b) 55 if !errors.Is(err, ErrBadTimestamp) { 56 t.Fatalf("Expected: %s\nGot: %s", ErrBadTimestamp, err) 57 } 58 59 // 4. Bad type 60 b[0] = HeaderTypeServerStream 61 62 _, err = ParseTCPRequestFixedLengthHeader(b) 63 if !errors.Is(err, ErrTypeMismatch) { 64 t.Fatalf("Expected: %s\nGot: %s", ErrTypeMismatch, err) 65 } 66 } 67 68 func TestWriteAndParseTCPRequestVariableLengthHeader(t *testing.T) { 69 payloadLen := 1 + int(mrand.Uint64()&1023) 70 payload := make([]byte, payloadLen) 71 _, err := rand.Read(payload) 72 if err != nil { 73 t.Fatal(err) 74 } 75 targetAddr := conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Unspecified(), 443)) 76 targetAddrLen := socks5.LengthOfAddrFromConnAddr(targetAddr) 77 noPayloadLen := targetAddrLen + 2 + 1 + mrand.IntN(MaxPaddingLength) 78 noPaddingLen := targetAddrLen + 2 + payloadLen 79 bufLen := noPaddingLen + MaxPaddingLength 80 b := make([]byte, bufLen) 81 82 // 1. Good header (padding + initial payload) 83 WriteTCPRequestVariableLengthHeader(b, targetAddr, payload) 84 85 ta, p, err := ParseTCPRequestVariableLengthHeader(b) 86 if err != nil { 87 t.Fatal(err) 88 } 89 if !bytes.Equal(p, payload) { 90 t.Fatalf("Expected payload %v\nGot: %v", payload, p) 91 } 92 if !ta.Equals(targetAddr) { 93 t.Fatalf("Expected target address %s, got %s", targetAddr, ta) 94 } 95 96 // 2. Good header (initial payload) 97 b = b[:noPaddingLen] 98 WriteTCPRequestVariableLengthHeader(b, targetAddr, payload) 99 100 ta, p, err = ParseTCPRequestVariableLengthHeader(b) 101 if err != nil { 102 t.Fatal(err) 103 } 104 if !bytes.Equal(p, payload) { 105 t.Fatalf("Expected payload %v\nGot: %v", payload, p) 106 } 107 if !ta.Equals(targetAddr) { 108 t.Fatalf("Expected target address %s, got %s", targetAddr, ta) 109 } 110 111 // 3. Good header (padding) 112 b = b[:noPayloadLen] 113 WriteTCPRequestVariableLengthHeader(b, targetAddr, nil) 114 115 ta, p, err = ParseTCPRequestVariableLengthHeader(b) 116 if err != nil { 117 t.Fatal(err) 118 } 119 if len(p) > 0 { 120 t.Fatalf("Expected empty initial payload, got length %d", len(p)) 121 } 122 if !ta.Equals(targetAddr) { 123 t.Fatalf("Expected target address %s, got %s", targetAddr, ta) 124 } 125 126 // 4. Bad header (incomplete padding) 127 b = b[:noPayloadLen-1] 128 129 _, _, err = ParseTCPRequestVariableLengthHeader(b) 130 if !errors.Is(err, ErrPaddingExceedChunkBorder) { 131 t.Fatalf("Expected: %s\nGot: %s", ErrPaddingExceedChunkBorder, err) 132 } 133 134 // 5. Bad header (incomplete padding length) 135 b = b[:targetAddrLen+1] 136 137 _, _, err = ParseTCPRequestVariableLengthHeader(b) 138 if !errors.Is(err, ErrIncompleteHeaderInFirstChunk) { 139 t.Fatalf("Expected: %s\nGot: %s", ErrIncompleteHeaderInFirstChunk, err) 140 } 141 142 // 6. Bad header (incomplete SOCKS address) 143 b = b[:targetAddrLen-1] 144 145 _, _, err = ParseTCPRequestVariableLengthHeader(b) 146 if err == nil { 147 t.Fatal("Expected error, got nil") 148 } 149 } 150 151 func TestWriteAndParseTCPResponseHeader(t *testing.T) { 152 const ( 153 saltLen = 32 154 bufLen = 1 + 8 + saltLen + 2 155 ) 156 157 b := make([]byte, bufLen) 158 length := int(mrand.Uint64() & math.MaxUint16) 159 requestSalt := make([]byte, saltLen) 160 _, err := rand.Read(requestSalt) 161 if err != nil { 162 t.Fatal(err) 163 } 164 165 // 1. Good header 166 WriteTCPResponseHeader(b, requestSalt, uint16(length)) 167 168 n, err := ParseTCPResponseHeader(b, requestSalt) 169 if err != nil { 170 t.Fatal(err) 171 } 172 if n != length { 173 t.Fatalf("Expected: %d\nGot: %d", length, n) 174 } 175 176 // 2. Bad request salt 177 _, err = rand.Read(b[1+8 : 1+8+saltLen]) 178 if err != nil { 179 t.Fatal(err) 180 } 181 182 _, err = ParseTCPResponseHeader(b, requestSalt) 183 if !errors.Is(err, ErrClientSaltMismatch) { 184 t.Fatalf("Expected: %s\nGot: %s", ErrClientSaltMismatch, err) 185 } 186 187 // 3. Bad timestamp (31s ago) 188 ts := time.Now().Add(-31 * time.Second) 189 binary.BigEndian.PutUint64(b[1:], uint64(ts.Unix())) 190 191 _, err = ParseTCPResponseHeader(b, requestSalt) 192 if !errors.Is(err, ErrBadTimestamp) { 193 t.Fatalf("Expected: %s\nGot: %s", ErrBadTimestamp, err) 194 } 195 196 // 4. Bad timestamp (31s later) 197 ts = time.Now().Add(31 * time.Second) 198 binary.BigEndian.PutUint64(b[1:], uint64(ts.Unix())) 199 200 _, err = ParseTCPResponseHeader(b, requestSalt) 201 if !errors.Is(err, ErrBadTimestamp) { 202 t.Fatalf("Expected: %s\nGot: %s", ErrBadTimestamp, err) 203 } 204 205 // 5. Bad type 206 b[0] = HeaderTypeClientStream 207 208 _, err = ParseTCPResponseHeader(b, requestSalt) 209 if !errors.Is(err, ErrTypeMismatch) { 210 t.Fatalf("Expected: %s\nGot: %s", ErrTypeMismatch, err) 211 } 212 } 213 214 func TestWriteAndParseSessionIDAndPacketID(t *testing.T) { 215 sid := mrand.Uint64() 216 pid := mrand.Uint64() 217 b := make([]byte, 16) 218 219 WriteSessionIDAndPacketID(b, sid, pid) 220 psid, ppid := ParseSessionIDAndPacketID(b) 221 if psid != sid { 222 t.Fatalf("Expected session ID %d, got %d", sid, psid) 223 } 224 if ppid != pid { 225 t.Fatalf("Expected packet ID %d, got %d", pid, ppid) 226 } 227 } 228 229 func TestWriteAndParseUDPClientMessageHeader(t *testing.T) { 230 var cachedDomain string 231 targetAddr := conn.AddrFromIPPort(netip.AddrPortFrom(netip.IPv6Unspecified(), 53)) 232 targetAddrLen := socks5.LengthOfAddrFromConnAddr(targetAddr) 233 noPaddingLen := UDPClientMessageHeaderFixedLength + targetAddrLen 234 paddingLen := 1 + mrand.IntN(MaxPaddingLength) 235 headerLen := noPaddingLen + paddingLen 236 payloadLen := 1 + int(mrand.Uint64()&math.MaxUint16) 237 bufLen := headerLen + payloadLen 238 b := make([]byte, bufLen) 239 bNoPadding := b[paddingLen:] 240 headerBuf := b[:headerLen] 241 headerNoPaddingBuf := bNoPadding[:noPaddingLen] 242 payload := b[headerLen:] 243 _, err := rand.Read(payload) 244 if err != nil { 245 t.Fatal(err) 246 } 247 248 // 1. Good header (no padding) 249 WriteUDPClientMessageHeader(headerNoPaddingBuf, 0, targetAddr) 250 251 ta, cachedDomain, ps, pl, err := ParseUDPClientMessageHeader(bNoPadding, cachedDomain) 252 if err != nil { 253 t.Fatal(err) 254 } 255 ps += headerLen - noPaddingLen 256 if ps != headerLen { 257 t.Errorf("Expected payload start %d, got %d", headerLen, ps) 258 } 259 if pl != payloadLen { 260 t.Errorf("Expected payload length %d, got %d", payloadLen, pl) 261 } 262 if !ta.Equals(targetAddr) { 263 t.Errorf("Expected target address %s, got %s", targetAddr, ta) 264 } 265 266 // 2. Good header (padding) 267 WriteUDPClientMessageHeader(headerBuf, paddingLen, targetAddr) 268 269 ta, cachedDomain, ps, pl, err = ParseUDPClientMessageHeader(b, cachedDomain) 270 if err != nil { 271 t.Fatal(err) 272 } 273 if ps != headerLen { 274 t.Errorf("Expected payload start %d, got %d", headerLen, ps) 275 } 276 if pl != payloadLen { 277 t.Errorf("Expected payload length %d, got %d", payloadLen, pl) 278 } 279 if !ta.Equals(targetAddr) { 280 t.Errorf("Expected target address %s, got %s", targetAddr, ta) 281 } 282 283 // 3. Bad header (incomplete SOCKS address) 284 b = b[:headerLen-1] 285 286 _, cachedDomain, _, _, err = ParseUDPClientMessageHeader(b, cachedDomain) 287 if err == nil { 288 t.Error("Expected error, got nil") 289 } 290 291 // 4. Bad header (incomplete padding) 292 b = b[:len(b)-targetAddrLen] 293 294 _, cachedDomain, _, _, err = ParseUDPClientMessageHeader(b, cachedDomain) 295 if !errors.Is(err, ErrPacketIncompleteHeader) { 296 t.Errorf("Expected: %s\nGot: %s", ErrPacketIncompleteHeader, err) 297 } 298 299 // 5. Bad header (incomplete padding length) 300 b = b[:1+8+1] 301 302 _, cachedDomain, _, _, err = ParseUDPClientMessageHeader(b, cachedDomain) 303 if !errors.Is(err, ErrPacketIncompleteHeader) { 304 t.Errorf("Expected: %s\nGot: %s", ErrPacketIncompleteHeader, err) 305 } 306 307 // 6. Bad timestamp (31s ago) 308 b = b[:UDPClientMessageHeaderFixedLength] 309 310 ts := time.Now().Add(-31 * time.Second) 311 binary.BigEndian.PutUint64(b[1:], uint64(ts.Unix())) 312 313 _, cachedDomain, _, _, err = ParseUDPClientMessageHeader(b, cachedDomain) 314 if !errors.Is(err, ErrBadTimestamp) { 315 t.Errorf("Expected: %s\nGot: %s", ErrBadTimestamp, err) 316 } 317 318 // 7. Bad timestamp (31s later) 319 ts = time.Now().Add(31 * time.Second) 320 binary.BigEndian.PutUint64(b[1:], uint64(ts.Unix())) 321 322 _, cachedDomain, _, _, err = ParseUDPClientMessageHeader(b, cachedDomain) 323 if !errors.Is(err, ErrBadTimestamp) { 324 t.Errorf("Expected: %s\nGot: %s", ErrBadTimestamp, err) 325 } 326 327 // 8. Bad type 328 b[0] = HeaderTypeServerPacket 329 330 _, _, _, _, err = ParseUDPClientMessageHeader(b, cachedDomain) 331 if !errors.Is(err, ErrTypeMismatch) { 332 t.Errorf("Expected: %s\nGot: %s", ErrTypeMismatch, err) 333 } 334 } 335 336 func TestWriteAndParseUDPServerMessageHeader(t *testing.T) { 337 csid := mrand.Uint64() 338 sourceAddrPort := netip.AddrPortFrom(netip.IPv6Unspecified(), 53) 339 sourceAddrPortLen := socks5.LengthOfAddrFromAddrPort(sourceAddrPort) 340 noPaddingLen := UDPServerMessageHeaderFixedLength + sourceAddrPortLen 341 paddingLen := 1 + mrand.IntN(MaxPaddingLength) 342 headerLen := noPaddingLen + paddingLen 343 payloadLen := 1 + int(mrand.Uint64()&math.MaxUint16) 344 bufLen := headerLen + payloadLen 345 b := make([]byte, bufLen) 346 bNoPadding := b[paddingLen:] 347 headerBuf := b[:headerLen] 348 headerNoPaddingBuf := bNoPadding[:noPaddingLen] 349 payload := b[headerLen:] 350 _, err := rand.Read(payload) 351 if err != nil { 352 t.Fatal(err) 353 } 354 355 // 1. Good header (no padding) 356 WriteUDPServerMessageHeader(headerNoPaddingBuf, csid, 0, sourceAddrPort) 357 358 sa, ps, pl, err := ParseUDPServerMessageHeader(bNoPadding, csid) 359 if err != nil { 360 t.Fatal(err) 361 } 362 ps += headerLen - noPaddingLen 363 if ps != headerLen { 364 t.Errorf("Expected payload start %d, got %d", headerLen, ps) 365 } 366 if pl != payloadLen { 367 t.Errorf("Expected payload length %d, got %d", payloadLen, pl) 368 } 369 if sa != sourceAddrPort { 370 t.Errorf("Expected target address %s, got %s", sourceAddrPort, sa) 371 } 372 373 // 2. Good header (pad) 374 WriteUDPServerMessageHeader(headerBuf, csid, paddingLen, sourceAddrPort) 375 376 sa, ps, pl, err = ParseUDPServerMessageHeader(b, csid) 377 if err != nil { 378 t.Fatal(err) 379 } 380 if ps != headerLen { 381 t.Errorf("Expected payload start %d, got %d", headerLen, ps) 382 } 383 if pl != payloadLen { 384 t.Errorf("Expected payload length %d, got %d", payloadLen, pl) 385 } 386 if sa != sourceAddrPort { 387 t.Errorf("Expected target address %s, got %s", sourceAddrPort, sa) 388 } 389 390 // 3. Bad header (incomplete SOCKS address) 391 b = b[:headerLen-1] 392 393 _, _, _, err = ParseUDPServerMessageHeader(b, csid) 394 if err == nil { 395 t.Error("Expected error, got nil") 396 } 397 398 // 4. Bad header (incomplete padding) 399 b = b[:len(b)-sourceAddrPortLen] 400 401 _, _, _, err = ParseUDPServerMessageHeader(b, csid) 402 if !errors.Is(err, ErrPacketIncompleteHeader) { 403 t.Errorf("Expected: %s\nGot: %s", ErrPacketIncompleteHeader, err) 404 } 405 406 // 5. Bad header (incomplete padding length) 407 b = b[:1+8+8+1] 408 409 _, _, _, err = ParseUDPServerMessageHeader(b, csid) 410 if !errors.Is(err, ErrPacketIncompleteHeader) { 411 t.Errorf("Expected: %s\nGot: %s", ErrPacketIncompleteHeader, err) 412 } 413 414 // 6. Bad client session ID 415 b = b[:UDPServerMessageHeaderFixedLength] 416 badCsid := csid + 1 417 binary.BigEndian.PutUint64(b[1+8:], badCsid) 418 419 _, _, _, err = ParseUDPServerMessageHeader(b, csid) 420 if !errors.Is(err, ErrClientSessionIDMismatch) { 421 t.Errorf("Expected: %s\nGot: %s", ErrClientSessionIDMismatch, err) 422 } 423 424 // 7. Bad timestamp (31s ago) 425 ts := time.Now().Add(-31 * time.Second) 426 binary.BigEndian.PutUint64(b[1:], uint64(ts.Unix())) 427 428 _, _, _, err = ParseUDPServerMessageHeader(b, csid) 429 if !errors.Is(err, ErrBadTimestamp) { 430 t.Errorf("Expected: %s\nGot: %s", ErrBadTimestamp, err) 431 } 432 433 // 8. Bad timestamp (31s later) 434 ts = time.Now().Add(31 * time.Second) 435 binary.BigEndian.PutUint64(b[1:], uint64(ts.Unix())) 436 437 _, _, _, err = ParseUDPServerMessageHeader(b, csid) 438 if !errors.Is(err, ErrBadTimestamp) { 439 t.Errorf("Expected: %s\nGot: %s", ErrBadTimestamp, err) 440 } 441 442 // 9. Bad type 443 b[0] = HeaderTypeClientPacket 444 445 _, _, _, err = ParseUDPServerMessageHeader(b, csid) 446 if !errors.Is(err, ErrTypeMismatch) { 447 t.Errorf("Expected: %s\nGot: %s", ErrTypeMismatch, err) 448 } 449 }