github.com/MicalKarl/sing-shadowsocks@v0.0.5/shadowaead_2022/service.go (about) 1 package shadowaead_2022 2 3 import ( 4 "context" 5 "crypto/aes" 6 "crypto/cipher" 7 "crypto/rand" 8 "encoding/base64" 9 "encoding/binary" 10 "io" 11 "math" 12 mRand "math/rand" 13 "net" 14 "os" 15 "sync" 16 "sync/atomic" 17 "time" 18 19 shadowsocks "github.com/MicalKarl/sing-shadowsocks" 20 "github.com/MicalKarl/sing-shadowsocks/shadowaead" 21 "github.com/MicalKarl/sing-shadowsocks/ssv" 22 "github.com/sagernet/sing/common" 23 "github.com/sagernet/sing/common/buf" 24 "github.com/sagernet/sing/common/cache" 25 E "github.com/sagernet/sing/common/exceptions" 26 M "github.com/sagernet/sing/common/metadata" 27 N "github.com/sagernet/sing/common/network" 28 "github.com/sagernet/sing/common/replay" 29 "github.com/sagernet/sing/common/udpnat" 30 31 "golang.org/x/crypto/chacha20poly1305" 32 ) 33 34 var ( 35 ErrNoPadding = E.New("bad request: missing payload or padding") 36 ErrBadPadding = E.New("bad request: damaged padding") 37 ) 38 39 var _ shadowsocks.Service = (*Service)(nil) 40 41 type Service struct { 42 name string 43 keySaltLength int 44 handler shadowsocks.Handler 45 timeFunc func() time.Time 46 47 constructor func(key []byte) (cipher.AEAD, error) 48 blockConstructor func(key []byte) (cipher.Block, error) 49 udpCipher cipher.AEAD 50 udpBlockCipher cipher.Block 51 psk []byte 52 53 replayFilter replay.Filter 54 udpNat *udpnat.Service[uint64] 55 udpSessions *cache.LruCache[uint64, *serverUDPSession] 56 } 57 58 func NewServiceWithPassword(method string, password string, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (shadowsocks.Service, error) { 59 if password == "" { 60 return nil, ErrMissingPSK 61 } 62 psk, err := base64.StdEncoding.DecodeString(password) 63 if err != nil { 64 return nil, E.Cause(err, "decode psk") 65 } 66 return NewService(method, psk, udpTimeout, handler, timeFunc) 67 } 68 69 func NewService(method string, psk []byte, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (shadowsocks.Service, error) { 70 s := &Service{ 71 name: method, 72 handler: handler, 73 timeFunc: timeFunc, 74 75 replayFilter: replay.NewSimple(60 * time.Second), 76 udpNat: udpnat.New[uint64](udpTimeout, handler), 77 udpSessions: cache.New[uint64, *serverUDPSession]( 78 cache.WithAge[uint64, *serverUDPSession](udpTimeout), 79 cache.WithUpdateAgeOnGet[uint64, *serverUDPSession](), 80 ), 81 } 82 83 switch method { 84 case "2022-blake3-aes-128-gcm": 85 s.keySaltLength = 16 86 s.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) 87 s.blockConstructor = aes.NewCipher 88 case "2022-blake3-aes-256-gcm": 89 s.keySaltLength = 32 90 s.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) 91 s.blockConstructor = aes.NewCipher 92 case "2022-blake3-chacha20-poly1305": 93 s.keySaltLength = 32 94 s.constructor = chacha20poly1305.New 95 default: 96 return nil, os.ErrInvalid 97 } 98 99 if len(psk) != s.keySaltLength { 100 if len(psk) < s.keySaltLength { 101 return nil, shadowsocks.ErrBadKey 102 } else if len(psk) > s.keySaltLength { 103 psk = Key(psk, s.keySaltLength) 104 } else { 105 return nil, ErrMissingPSK 106 } 107 } 108 109 var err error 110 switch method { 111 case "2022-blake3-aes-128-gcm", "2022-blake3-aes-256-gcm": 112 s.udpBlockCipher, err = aes.NewCipher(psk) 113 case "2022-blake3-chacha20-poly1305": 114 s.udpCipher, err = chacha20poly1305.NewX(psk) 115 } 116 if err != nil { 117 return nil, err 118 } 119 120 s.psk = psk 121 return s, nil 122 } 123 124 func (s *Service) Name() string { 125 return s.name 126 } 127 128 func (s *Service) Password() string { 129 return base64.StdEncoding.EncodeToString(s.psk) 130 } 131 132 func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { 133 err := s.newConnection(ctx, conn, metadata) 134 if err != nil { 135 err = &shadowsocks.ServerConnError{Conn: conn, Source: metadata.Source, Cause: err} 136 } 137 return err 138 } 139 140 func (s *Service) time() time.Time { 141 if s.timeFunc != nil { 142 return s.timeFunc() 143 } else { 144 return time.Now() 145 } 146 } 147 148 func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { 149 header := make([]byte, s.keySaltLength+shadowaead.Overhead+RequestHeaderFixedChunkLength) 150 151 n, err := conn.Read(header) 152 if err != nil { 153 return E.Cause(err, "read header") 154 } else if n < len(header) { 155 return shadowaead.ErrBadHeader 156 } 157 158 requestSalt := header[:s.keySaltLength] 159 160 if !s.replayFilter.Check(requestSalt) { 161 return ErrSaltNotUnique 162 } 163 164 requestKey := SessionKey(s.psk, requestSalt, s.keySaltLength) 165 readCipher, err := s.constructor(requestKey) 166 if err != nil { 167 return err 168 } 169 reader := shadowaead.NewReader( 170 conn, 171 readCipher, 172 MaxPacketSize, 173 ) 174 175 err = reader.ReadExternalChunk(header[s.keySaltLength:]) 176 if err != nil { 177 return err 178 } 179 180 headerType, err := reader.ReadByte() 181 if err != nil { 182 return E.Cause(err, "read header") 183 } 184 185 if headerType != HeaderTypeClient { 186 return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType) 187 } 188 189 var epoch uint64 190 err = binary.Read(reader, binary.BigEndian, &epoch) 191 if err != nil { 192 return err 193 } 194 195 diff := int(math.Abs(float64(s.time().Unix() - int64(epoch)))) 196 if diff > 30 { 197 return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") 198 } 199 200 var length uint16 201 err = binary.Read(reader, binary.BigEndian, &length) 202 if err != nil { 203 return err 204 } 205 206 err = reader.ReadWithLength(length) 207 if err != nil { 208 return err 209 } 210 211 destination, err := ssv.FakeSockSerializer.ReadAddrPort(reader) 212 if err != nil { 213 return err 214 } 215 216 var paddingLen uint16 217 err = binary.Read(reader, binary.BigEndian, &paddingLen) 218 if err != nil { 219 return err 220 } 221 222 if uint16(reader.Cached()) < paddingLen { 223 return ErrNoPadding 224 } 225 226 if paddingLen > 0 { 227 err = reader.Discard(int(paddingLen)) 228 if err != nil { 229 return E.Cause(err, "discard padding") 230 } 231 } else if reader.Cached() == 0 { 232 return ErrNoPadding 233 } 234 235 protocolConn := &serverConn{ 236 Service: s, 237 Conn: conn, 238 uPSK: s.psk, 239 headerType: headerType, 240 requestSalt: requestSalt, 241 } 242 243 protocolConn.reader = reader 244 245 metadata.Protocol = "shadowsocks" 246 metadata.Destination = destination 247 return s.handler.NewConnection(ctx, protocolConn, metadata) 248 } 249 250 type serverConn struct { 251 *Service 252 net.Conn 253 uPSK []byte 254 access sync.Mutex 255 headerType byte 256 reader *shadowaead.Reader 257 writer *shadowaead.Writer 258 requestSalt []byte 259 } 260 261 func (c *serverConn) writeResponse(payload []byte) (n int, err error) { 262 salt := buf.NewSize(c.keySaltLength) 263 salt.WriteRandom(salt.FreeLen()) 264 265 key := SessionKey(c.uPSK, salt.Bytes(), c.keySaltLength) 266 writeCipher, err := c.constructor(key) 267 if err != nil { 268 salt.Release() 269 return 270 } 271 writer := shadowaead.NewWriter( 272 c.Conn, 273 writeCipher, 274 MaxPacketSize, 275 ) 276 header := writer.Buffer() 277 header.Write(salt.Bytes()) 278 279 salt.Release() 280 281 headerType := byte(HeaderTypeServer) 282 payloadLen := len(payload) 283 284 headerFixedChunk := buf.NewSize(1 + 8 + c.keySaltLength + 2) 285 common.Must(headerFixedChunk.WriteByte(headerType)) 286 common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint64(c.time().Unix()))) 287 common.Must1(headerFixedChunk.Write(c.requestSalt)) 288 common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint16(payloadLen))) 289 290 writer.WriteChunk(header, headerFixedChunk.Slice()) 291 headerFixedChunk.Release() 292 c.requestSalt = nil 293 294 if payloadLen > 0 { 295 writer.WriteChunk(header, payload[:payloadLen]) 296 } 297 298 err = writer.BufferedWriter(header.Len()).Flush() 299 if err != nil { 300 return 301 } 302 303 switch headerType { 304 case HeaderTypeServer: 305 c.writer = writer 306 // case HeaderTypeServerEncrypted: 307 // encryptedWriter := NewTLSEncryptedStreamWriter(writer) 308 // if payloadLen < len(payload) { 309 // _, err = encryptedWriter.Write(payload[payloadLen:]) 310 // if err != nil { 311 // return 312 // } 313 // } 314 // c.writer = encryptedWriter 315 } 316 317 n = len(payload) 318 return 319 } 320 321 func (c *serverConn) Read(b []byte) (n int, err error) { 322 return c.reader.Read(b) 323 } 324 325 func (c *serverConn) Write(p []byte) (n int, err error) { 326 if c.writer != nil { 327 return c.writer.Write(p) 328 } 329 c.access.Lock() 330 if c.writer != nil { 331 c.access.Unlock() 332 return c.writer.Write(p) 333 } 334 defer c.access.Unlock() 335 return c.writeResponse(p) 336 } 337 338 func (c *serverConn) WriteVectorised(buffers []*buf.Buffer) error { 339 if c.writer != nil { 340 return c.writer.WriteVectorised(buffers) 341 } 342 c.access.Lock() 343 if c.writer != nil { 344 c.access.Unlock() 345 return c.writer.WriteVectorised(buffers) 346 } 347 defer c.access.Unlock() 348 _, err := c.writeResponse(buffers[0].Bytes()) 349 if err != nil { 350 buf.ReleaseMulti(buffers) 351 return err 352 } 353 buffers[0].Release() 354 return c.writer.WriteVectorised(buffers[1:]) 355 } 356 357 func (c *serverConn) Close() error { 358 return common.Close( 359 c.Conn, 360 common.PtrOrNil(c.reader), 361 common.PtrOrNil(c.writer), 362 ) 363 } 364 365 func (c *serverConn) NeedAdditionalReadDeadline() bool { 366 return true 367 } 368 369 func (c *serverConn) Upstream() any { 370 return c.Conn 371 } 372 373 func (s *Service) WriteIsThreadUnsafe() { 374 } 375 376 func (s *Service) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { 377 err := s.newPacket(ctx, conn, buffer, metadata) 378 if err != nil { 379 err = &shadowsocks.ServerPacketError{Source: metadata.Source, Cause: err} 380 } 381 return err 382 } 383 384 func (s *Service) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { 385 var packetHeader []byte 386 if s.udpCipher != nil { 387 if buffer.Len() < PacketNonceSize+PacketMinimalHeaderSize { 388 return ErrPacketTooShort 389 } 390 _, err := s.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil) 391 if err != nil { 392 return E.Cause(err, "decrypt packet header") 393 } 394 buffer.Advance(PacketNonceSize) 395 buffer.Truncate(buffer.Len() - shadowaead.Overhead) 396 } else { 397 if buffer.Len() < PacketMinimalHeaderSize { 398 return ErrPacketTooShort 399 } 400 packetHeader = buffer.To(aes.BlockSize) 401 s.udpBlockCipher.Decrypt(packetHeader, packetHeader) 402 } 403 404 var sessionId, packetId uint64 405 err := binary.Read(buffer, binary.BigEndian, &sessionId) 406 if err != nil { 407 return err 408 } 409 err = binary.Read(buffer, binary.BigEndian, &packetId) 410 if err != nil { 411 return err 412 } 413 414 session, loaded := s.udpSessions.LoadOrStore(sessionId, s.newUDPSession) 415 if !loaded { 416 session.remoteSessionId = sessionId 417 if packetHeader != nil { 418 key := SessionKey(s.psk, packetHeader[:8], s.keySaltLength) 419 session.remoteCipher, err = s.constructor(key) 420 if err != nil { 421 return err 422 } 423 } 424 } 425 goto process 426 427 returnErr: 428 if !loaded { 429 s.udpSessions.Delete(sessionId) 430 } 431 return err 432 433 process: 434 if !session.window.Check(packetId) { 435 err = ErrPacketIdNotUnique 436 goto returnErr 437 } 438 439 if packetHeader != nil { 440 _, err = session.remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil) 441 if err != nil { 442 err = E.Cause(err, "decrypt packet") 443 goto returnErr 444 } 445 buffer.Truncate(buffer.Len() - shadowaead.Overhead) 446 } 447 448 session.window.Add(packetId) 449 450 var headerType byte 451 headerType, err = buffer.ReadByte() 452 if err != nil { 453 err = E.Cause(err, "decrypt packet") 454 goto returnErr 455 } 456 if headerType != HeaderTypeClient { 457 err = E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType) 458 goto returnErr 459 } 460 461 var epoch uint64 462 err = binary.Read(buffer, binary.BigEndian, &epoch) 463 if err != nil { 464 goto returnErr 465 } 466 diff := int(math.Abs(float64(s.time().Unix() - int64(epoch)))) 467 if diff > 30 { 468 err = E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") 469 goto returnErr 470 } 471 472 var paddingLen uint16 473 err = binary.Read(buffer, binary.BigEndian, &paddingLen) 474 if err != nil { 475 err = E.Cause(err, "read padding length") 476 goto returnErr 477 } 478 buffer.Advance(int(paddingLen)) 479 480 destination, err := ssv.FakeSockSerializer.ReadAddrPort(buffer) 481 if err != nil { 482 goto returnErr 483 } 484 metadata.Protocol = "shadowsocks" 485 metadata.Destination = destination 486 s.udpNat.NewPacket(ctx, sessionId, buffer, metadata, func(natConn N.PacketConn) N.PacketWriter { 487 return &serverPacketWriter{s, conn, natConn, session, s.udpBlockCipher} 488 }) 489 return nil 490 } 491 492 func (s *Service) NewError(ctx context.Context, err error) { 493 s.handler.NewError(ctx, err) 494 } 495 496 type serverPacketWriter struct { 497 *Service 498 source N.PacketConn 499 nat N.PacketConn 500 session *serverUDPSession 501 udpBlockCipher cipher.Block 502 } 503 504 func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 505 var hdrLen int 506 if w.udpCipher != nil { 507 hdrLen = PacketNonceSize 508 } 509 510 var paddingLen int 511 if destination.Port == 53 && buffer.Len() < MaxPaddingLength { 512 paddingLen = mRand.Intn(MaxPaddingLength-buffer.Len()) + 1 513 } 514 515 hdrLen += 16 // packet header 516 hdrLen += 1 // header type 517 hdrLen += 8 // timestamp 518 hdrLen += 8 // remote session id 519 hdrLen += 2 // padding length 520 hdrLen += paddingLen 521 hdrLen += M.SocksaddrSerializer.AddrPortLen(destination) 522 header := buf.With(buffer.ExtendHeader(hdrLen)) 523 524 var dataIndex int 525 if w.udpCipher != nil { 526 common.Must1(header.ReadFullFrom(w.session.rng, PacketNonceSize)) 527 dataIndex = PacketNonceSize 528 } else { 529 dataIndex = aes.BlockSize 530 } 531 532 common.Must( 533 binary.Write(header, binary.BigEndian, w.session.sessionId), 534 binary.Write(header, binary.BigEndian, w.session.nextPacketId()), 535 header.WriteByte(HeaderTypeServer), 536 binary.Write(header, binary.BigEndian, uint64(w.time().Unix())), 537 binary.Write(header, binary.BigEndian, w.session.remoteSessionId), 538 binary.Write(header, binary.BigEndian, uint16(paddingLen)), // padding length 539 ) 540 541 if paddingLen > 0 { 542 header.Extend(paddingLen) 543 } 544 545 err := M.SocksaddrSerializer.WriteAddrPort(header, destination) 546 if err != nil { 547 buffer.Release() 548 return err 549 } 550 551 if w.udpCipher != nil { 552 w.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil) 553 buffer.Extend(shadowaead.Overhead) 554 } else { 555 packetHeader := buffer.To(aes.BlockSize) 556 w.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil) 557 buffer.Extend(shadowaead.Overhead) 558 w.udpBlockCipher.Encrypt(packetHeader, packetHeader) 559 } 560 return w.source.WritePacket(buffer, M.SocksaddrFromNet(w.nat.LocalAddr())) 561 } 562 563 func (w *serverPacketWriter) FrontHeadroom() int { 564 var hdrLen int 565 if w.udpCipher != nil { 566 hdrLen = PacketNonceSize 567 } 568 hdrLen += 16 // packet header 569 hdrLen += 1 // header type 570 hdrLen += 8 // timestamp 571 hdrLen += 8 // remote session id 572 hdrLen += 2 // padding length 573 hdrLen += MaxPaddingLength 574 hdrLen += M.MaxSocksaddrLength 575 return hdrLen 576 } 577 578 func (w *serverPacketWriter) RearHeadroom() int { 579 return shadowaead.Overhead 580 } 581 582 func (w *serverPacketWriter) Upstream() any { 583 return w.source 584 } 585 586 type serverUDPSession struct { 587 sessionId uint64 588 remoteSessionId uint64 589 packetId uint64 590 cipher cipher.AEAD 591 remoteCipher cipher.AEAD 592 window SlidingWindow 593 rng io.Reader 594 } 595 596 func (s *serverUDPSession) nextPacketId() uint64 { 597 return atomic.AddUint64(&s.packetId, 1) 598 } 599 600 func (s *Service) newUDPSession() *serverUDPSession { 601 session := &serverUDPSession{} 602 if s.udpCipher != nil { 603 session.rng = Blake3KeyedHash(rand.Reader) 604 common.Must(binary.Read(session.rng, binary.BigEndian, &session.sessionId)) 605 } else { 606 common.Must(binary.Read(rand.Reader, binary.BigEndian, &session.sessionId)) 607 } 608 session.packetId-- 609 if s.udpCipher == nil { 610 sessionId := make([]byte, 8) 611 binary.BigEndian.PutUint64(sessionId, session.sessionId) 612 key := SessionKey(s.psk, sessionId, s.keySaltLength) 613 var err error 614 session.cipher, err = s.constructor(key) 615 common.Must(err) 616 } 617 return session 618 }