github.com/MerlinKodo/sing-shadowsocks@v0.2.6/shadowaead_2022/service.go (about) 1 package shadowaead_2022 2 3 import ( 4 "context" 5 "crypto/aes" 6 "crypto/cipher" 7 "crypto/rand" 8 "encoding/base64" 9 "encoding/binary" 10 "io" 11 "math" 12 mRand "math/rand" 13 "net" 14 "os" 15 "sync" 16 "sync/atomic" 17 "time" 18 19 shadowsocks "github.com/MerlinKodo/sing-shadowsocks" 20 "github.com/MerlinKodo/sing-shadowsocks/shadowaead" 21 "github.com/sagernet/sing/common" 22 "github.com/sagernet/sing/common/buf" 23 "github.com/sagernet/sing/common/cache" 24 E "github.com/sagernet/sing/common/exceptions" 25 M "github.com/sagernet/sing/common/metadata" 26 N "github.com/sagernet/sing/common/network" 27 "github.com/sagernet/sing/common/replay" 28 "github.com/sagernet/sing/common/udpnat" 29 30 "golang.org/x/crypto/chacha20poly1305" 31 ) 32 33 var ( 34 ErrNoPadding = E.New("bad request: missing payload or padding") 35 ErrBadPadding = E.New("bad request: damaged padding") 36 ) 37 38 var _ shadowsocks.Service = (*Service)(nil) 39 40 type Service struct { 41 name string 42 keySaltLength int 43 handler shadowsocks.Handler 44 timeFunc func() time.Time 45 46 constructor func(key []byte) (cipher.AEAD, error) 47 blockConstructor func(key []byte) (cipher.Block, error) 48 udpCipher cipher.AEAD 49 udpBlockCipher cipher.Block 50 psk []byte 51 52 replayFilter replay.Filter 53 udpNat *udpnat.Service[uint64] 54 udpSessions *cache.LruCache[uint64, *serverUDPSession] 55 } 56 57 func NewServiceWithPassword(method string, password string, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (shadowsocks.Service, error) { 58 if password == "" { 59 return nil, ErrMissingPSK 60 } 61 psk, err := base64.StdEncoding.DecodeString(password) 62 if err != nil { 63 return nil, E.Cause(err, "decode psk") 64 } 65 return NewService(method, psk, udpTimeout, handler, timeFunc) 66 } 67 68 func NewService(method string, psk []byte, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (shadowsocks.Service, error) { 69 s := &Service{ 70 name: method, 71 handler: handler, 72 timeFunc: timeFunc, 73 74 replayFilter: replay.NewSimple(60 * time.Second), 75 udpNat: udpnat.New[uint64](udpTimeout, handler), 76 udpSessions: cache.New[uint64, *serverUDPSession]( 77 cache.WithAge[uint64, *serverUDPSession](udpTimeout), 78 cache.WithUpdateAgeOnGet[uint64, *serverUDPSession](), 79 ), 80 } 81 82 switch method { 83 case "2022-blake3-aes-128-gcm": 84 s.keySaltLength = 16 85 s.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) 86 s.blockConstructor = aes.NewCipher 87 case "2022-blake3-aes-256-gcm": 88 s.keySaltLength = 32 89 s.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) 90 s.blockConstructor = aes.NewCipher 91 case "2022-blake3-chacha20-poly1305": 92 s.keySaltLength = 32 93 s.constructor = chacha20poly1305.New 94 default: 95 return nil, os.ErrInvalid 96 } 97 98 if len(psk) != s.keySaltLength { 99 if len(psk) < s.keySaltLength { 100 return nil, shadowsocks.ErrBadKey 101 } else if len(psk) > s.keySaltLength { 102 psk = Key(psk, s.keySaltLength) 103 } else { 104 return nil, ErrMissingPSK 105 } 106 } 107 108 var err error 109 switch method { 110 case "2022-blake3-aes-128-gcm", "2022-blake3-aes-256-gcm": 111 s.udpBlockCipher, err = aes.NewCipher(psk) 112 case "2022-blake3-chacha20-poly1305": 113 s.udpCipher, err = chacha20poly1305.NewX(psk) 114 } 115 if err != nil { 116 return nil, err 117 } 118 119 s.psk = psk 120 return s, nil 121 } 122 123 func (s *Service) Name() string { 124 return s.name 125 } 126 127 func (s *Service) Password() string { 128 return base64.StdEncoding.EncodeToString(s.psk) 129 } 130 131 func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { 132 err := s.newConnection(ctx, conn, metadata) 133 if err != nil { 134 err = &shadowsocks.ServerConnError{Conn: conn, Source: metadata.Source, Cause: err} 135 } 136 return err 137 } 138 139 func (s *Service) time() time.Time { 140 if s.timeFunc != nil { 141 return s.timeFunc() 142 } else { 143 return time.Now() 144 } 145 } 146 147 func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { 148 header := make([]byte, s.keySaltLength+shadowaead.Overhead+RequestHeaderFixedChunkLength) 149 150 n, err := conn.Read(header) 151 if err != nil { 152 return E.Cause(err, "read header") 153 } else if n < len(header) { 154 return shadowaead.ErrBadHeader 155 } 156 157 requestSalt := header[:s.keySaltLength] 158 159 if !s.replayFilter.Check(requestSalt) { 160 return ErrSaltNotUnique 161 } 162 163 requestKey := SessionKey(s.psk, requestSalt, s.keySaltLength) 164 readCipher, err := s.constructor(requestKey) 165 if err != nil { 166 return err 167 } 168 reader := shadowaead.NewReader( 169 conn, 170 readCipher, 171 MaxPacketSize, 172 ) 173 174 err = reader.ReadExternalChunk(header[s.keySaltLength:]) 175 if err != nil { 176 return err 177 } 178 179 headerType, err := reader.ReadByte() 180 if err != nil { 181 return E.Cause(err, "read header") 182 } 183 184 if headerType != HeaderTypeClient { 185 return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType) 186 } 187 188 var epoch uint64 189 err = binary.Read(reader, binary.BigEndian, &epoch) 190 if err != nil { 191 return err 192 } 193 194 diff := int(math.Abs(float64(s.time().Unix() - int64(epoch)))) 195 if diff > 30 { 196 return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") 197 } 198 199 var length uint16 200 err = binary.Read(reader, binary.BigEndian, &length) 201 if err != nil { 202 return err 203 } 204 205 err = reader.ReadWithLength(length) 206 if err != nil { 207 return err 208 } 209 210 destination, err := M.SocksaddrSerializer.ReadAddrPort(reader) 211 if err != nil { 212 return err 213 } 214 215 var paddingLen uint16 216 err = binary.Read(reader, binary.BigEndian, &paddingLen) 217 if err != nil { 218 return err 219 } 220 221 if uint16(reader.Cached()) < paddingLen { 222 return ErrNoPadding 223 } 224 225 if paddingLen > 0 { 226 err = reader.Discard(int(paddingLen)) 227 if err != nil { 228 return E.Cause(err, "discard padding") 229 } 230 } else if reader.Cached() == 0 { 231 return ErrNoPadding 232 } 233 234 protocolConn := &serverConn{ 235 Service: s, 236 Conn: conn, 237 uPSK: s.psk, 238 headerType: headerType, 239 requestSalt: requestSalt, 240 } 241 242 protocolConn.reader = reader 243 244 metadata.Protocol = "shadowsocks" 245 metadata.Destination = destination 246 return s.handler.NewConnection(ctx, protocolConn, metadata) 247 } 248 249 type serverConn struct { 250 *Service 251 net.Conn 252 uPSK []byte 253 access sync.Mutex 254 headerType byte 255 reader *shadowaead.Reader 256 writer *shadowaead.Writer 257 requestSalt []byte 258 } 259 260 func (c *serverConn) writeResponse(payload []byte) (n int, err error) { 261 salt := buf.NewSize(c.keySaltLength) 262 salt.WriteRandom(salt.FreeLen()) 263 264 key := SessionKey(c.uPSK, salt.Bytes(), c.keySaltLength) 265 writeCipher, err := c.constructor(key) 266 if err != nil { 267 salt.Release() 268 return 269 } 270 writer := shadowaead.NewWriter( 271 c.Conn, 272 writeCipher, 273 MaxPacketSize, 274 ) 275 header := writer.Buffer() 276 header.Write(salt.Bytes()) 277 278 salt.Release() 279 280 headerType := byte(HeaderTypeServer) 281 payloadLen := len(payload) 282 283 headerFixedChunk := buf.NewSize(1 + 8 + c.keySaltLength + 2) 284 common.Must(headerFixedChunk.WriteByte(headerType)) 285 common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint64(c.time().Unix()))) 286 common.Must1(headerFixedChunk.Write(c.requestSalt)) 287 common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint16(payloadLen))) 288 289 writer.WriteChunk(header, headerFixedChunk.Slice()) 290 headerFixedChunk.Release() 291 c.requestSalt = nil 292 293 if payloadLen > 0 { 294 writer.WriteChunk(header, payload[:payloadLen]) 295 } 296 297 err = writer.BufferedWriter(header.Len()).Flush() 298 if err != nil { 299 return 300 } 301 302 switch headerType { 303 case HeaderTypeServer: 304 c.writer = writer 305 // case HeaderTypeServerEncrypted: 306 // encryptedWriter := NewTLSEncryptedStreamWriter(writer) 307 // if payloadLen < len(payload) { 308 // _, err = encryptedWriter.Write(payload[payloadLen:]) 309 // if err != nil { 310 // return 311 // } 312 // } 313 // c.writer = encryptedWriter 314 } 315 316 n = len(payload) 317 return 318 } 319 320 func (c *serverConn) Read(b []byte) (n int, err error) { 321 return c.reader.Read(b) 322 } 323 324 func (c *serverConn) Write(p []byte) (n int, err error) { 325 if c.writer != nil { 326 return c.writer.Write(p) 327 } 328 c.access.Lock() 329 if c.writer != nil { 330 c.access.Unlock() 331 return c.writer.Write(p) 332 } 333 defer c.access.Unlock() 334 return c.writeResponse(p) 335 } 336 337 func (c *serverConn) WriteVectorised(buffers []*buf.Buffer) error { 338 if c.writer != nil { 339 return c.writer.WriteVectorised(buffers) 340 } 341 c.access.Lock() 342 if c.writer != nil { 343 c.access.Unlock() 344 return c.writer.WriteVectorised(buffers) 345 } 346 defer c.access.Unlock() 347 _, err := c.writeResponse(buffers[0].Bytes()) 348 if err != nil { 349 buf.ReleaseMulti(buffers) 350 return err 351 } 352 buffers[0].Release() 353 return c.writer.WriteVectorised(buffers[1:]) 354 } 355 356 func (c *serverConn) Close() error { 357 return common.Close( 358 c.Conn, 359 common.PtrOrNil(c.reader), 360 common.PtrOrNil(c.writer), 361 ) 362 } 363 364 func (c *serverConn) NeedAdditionalReadDeadline() bool { 365 return true 366 } 367 368 func (c *serverConn) Upstream() any { 369 return c.Conn 370 } 371 372 func (s *Service) WriteIsThreadUnsafe() { 373 } 374 375 func (s *Service) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { 376 err := s.newPacket(ctx, conn, buffer, metadata) 377 if err != nil { 378 err = &shadowsocks.ServerPacketError{Source: metadata.Source, Cause: err} 379 } 380 return err 381 } 382 383 func (s *Service) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { 384 var packetHeader []byte 385 if s.udpCipher != nil { 386 if buffer.Len() < PacketNonceSize+PacketMinimalHeaderSize { 387 return ErrPacketTooShort 388 } 389 _, err := s.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil) 390 if err != nil { 391 return E.Cause(err, "decrypt packet header") 392 } 393 buffer.Advance(PacketNonceSize) 394 buffer.Truncate(buffer.Len() - shadowaead.Overhead) 395 } else { 396 if buffer.Len() < PacketMinimalHeaderSize { 397 return ErrPacketTooShort 398 } 399 packetHeader = buffer.To(aes.BlockSize) 400 s.udpBlockCipher.Decrypt(packetHeader, packetHeader) 401 } 402 403 var sessionId, packetId uint64 404 err := binary.Read(buffer, binary.BigEndian, &sessionId) 405 if err != nil { 406 return err 407 } 408 err = binary.Read(buffer, binary.BigEndian, &packetId) 409 if err != nil { 410 return err 411 } 412 413 session, loaded := s.udpSessions.LoadOrStore(sessionId, s.newUDPSession) 414 if !loaded { 415 session.remoteSessionId = sessionId 416 if packetHeader != nil { 417 key := SessionKey(s.psk, packetHeader[:8], s.keySaltLength) 418 session.remoteCipher, err = s.constructor(key) 419 if err != nil { 420 return err 421 } 422 } 423 } 424 goto process 425 426 returnErr: 427 if !loaded { 428 s.udpSessions.Delete(sessionId) 429 } 430 return err 431 432 process: 433 if !session.window.Check(packetId) { 434 err = ErrPacketIdNotUnique 435 goto returnErr 436 } 437 438 if packetHeader != nil { 439 _, err = session.remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil) 440 if err != nil { 441 err = E.Cause(err, "decrypt packet") 442 goto returnErr 443 } 444 buffer.Truncate(buffer.Len() - shadowaead.Overhead) 445 } 446 447 session.window.Add(packetId) 448 449 var headerType byte 450 headerType, err = buffer.ReadByte() 451 if err != nil { 452 err = E.Cause(err, "decrypt packet") 453 goto returnErr 454 } 455 if headerType != HeaderTypeClient { 456 err = E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType) 457 goto returnErr 458 } 459 460 var epoch uint64 461 err = binary.Read(buffer, binary.BigEndian, &epoch) 462 if err != nil { 463 goto returnErr 464 } 465 diff := int(math.Abs(float64(s.time().Unix() - int64(epoch)))) 466 if diff > 30 { 467 err = E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") 468 goto returnErr 469 } 470 471 var paddingLen uint16 472 err = binary.Read(buffer, binary.BigEndian, &paddingLen) 473 if err != nil { 474 err = E.Cause(err, "read padding length") 475 goto returnErr 476 } 477 buffer.Advance(int(paddingLen)) 478 479 destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer) 480 if err != nil { 481 goto returnErr 482 } 483 metadata.Protocol = "shadowsocks" 484 metadata.Destination = destination 485 s.udpNat.NewPacket(ctx, sessionId, buffer, metadata, func(natConn N.PacketConn) N.PacketWriter { 486 return &serverPacketWriter{s, conn, natConn, session, s.udpBlockCipher} 487 }) 488 return nil 489 } 490 491 func (s *Service) NewError(ctx context.Context, err error) { 492 s.handler.NewError(ctx, err) 493 } 494 495 type serverPacketWriter struct { 496 *Service 497 source N.PacketConn 498 nat N.PacketConn 499 session *serverUDPSession 500 udpBlockCipher cipher.Block 501 } 502 503 func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 504 var hdrLen int 505 if w.udpCipher != nil { 506 hdrLen = PacketNonceSize 507 } 508 509 var paddingLen int 510 if destination.Port == 53 && buffer.Len() < MaxPaddingLength { 511 paddingLen = mRand.Intn(MaxPaddingLength-buffer.Len()) + 1 512 } 513 514 hdrLen += 16 // packet header 515 hdrLen += 1 // header type 516 hdrLen += 8 // timestamp 517 hdrLen += 8 // remote session id 518 hdrLen += 2 // padding length 519 hdrLen += paddingLen 520 hdrLen += M.SocksaddrSerializer.AddrPortLen(destination) 521 header := buf.With(buffer.ExtendHeader(hdrLen)) 522 523 var dataIndex int 524 if w.udpCipher != nil { 525 common.Must1(header.ReadFullFrom(w.session.rng, PacketNonceSize)) 526 dataIndex = PacketNonceSize 527 } else { 528 dataIndex = aes.BlockSize 529 } 530 531 common.Must( 532 binary.Write(header, binary.BigEndian, w.session.sessionId), 533 binary.Write(header, binary.BigEndian, w.session.nextPacketId()), 534 header.WriteByte(HeaderTypeServer), 535 binary.Write(header, binary.BigEndian, uint64(w.time().Unix())), 536 binary.Write(header, binary.BigEndian, w.session.remoteSessionId), 537 binary.Write(header, binary.BigEndian, uint16(paddingLen)), // padding length 538 ) 539 540 if paddingLen > 0 { 541 header.Extend(paddingLen) 542 } 543 544 err := M.SocksaddrSerializer.WriteAddrPort(header, destination) 545 if err != nil { 546 buffer.Release() 547 return err 548 } 549 550 if w.udpCipher != nil { 551 w.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil) 552 buffer.Extend(shadowaead.Overhead) 553 } else { 554 packetHeader := buffer.To(aes.BlockSize) 555 w.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil) 556 buffer.Extend(shadowaead.Overhead) 557 w.udpBlockCipher.Encrypt(packetHeader, packetHeader) 558 } 559 return w.source.WritePacket(buffer, M.SocksaddrFromNet(w.nat.LocalAddr())) 560 } 561 562 func (w *serverPacketWriter) FrontHeadroom() int { 563 var hdrLen int 564 if w.udpCipher != nil { 565 hdrLen = PacketNonceSize 566 } 567 hdrLen += 16 // packet header 568 hdrLen += 1 // header type 569 hdrLen += 8 // timestamp 570 hdrLen += 8 // remote session id 571 hdrLen += 2 // padding length 572 hdrLen += MaxPaddingLength 573 hdrLen += M.MaxSocksaddrLength 574 return hdrLen 575 } 576 577 func (w *serverPacketWriter) RearHeadroom() int { 578 return shadowaead.Overhead 579 } 580 581 func (w *serverPacketWriter) Upstream() any { 582 return w.source 583 } 584 585 type serverUDPSession struct { 586 sessionId uint64 587 remoteSessionId uint64 588 packetId uint64 589 cipher cipher.AEAD 590 remoteCipher cipher.AEAD 591 window SlidingWindow 592 rng io.Reader 593 } 594 595 func (s *serverUDPSession) nextPacketId() uint64 { 596 return atomic.AddUint64(&s.packetId, 1) 597 } 598 599 func (s *Service) newUDPSession() *serverUDPSession { 600 session := &serverUDPSession{} 601 if s.udpCipher != nil { 602 session.rng = Blake3KeyedHash(rand.Reader) 603 common.Must(binary.Read(session.rng, binary.BigEndian, &session.sessionId)) 604 } else { 605 common.Must(binary.Read(rand.Reader, binary.BigEndian, &session.sessionId)) 606 } 607 session.packetId-- 608 if s.udpCipher == nil { 609 sessionId := make([]byte, 8) 610 binary.BigEndian.PutUint64(sessionId, session.sessionId) 611 key := SessionKey(s.psk, sessionId, s.keySaltLength) 612 var err error 613 session.cipher, err = s.constructor(key) 614 common.Must(err) 615 } 616 return session 617 }