github.com/sagernet/sing-shadowsocks@v0.2.6/shadowaead_2022/service_multi.go (about) 1 package shadowaead_2022 2 3 import ( 4 "context" 5 "crypto/aes" 6 "crypto/cipher" 7 "crypto/rand" 8 "encoding/base64" 9 "encoding/binary" 10 "io" 11 "math" 12 "net" 13 "os" 14 "time" 15 16 "github.com/sagernet/sing-shadowsocks" 17 "github.com/sagernet/sing-shadowsocks/shadowaead" 18 "github.com/sagernet/sing/common" 19 "github.com/sagernet/sing/common/auth" 20 "github.com/sagernet/sing/common/buf" 21 E "github.com/sagernet/sing/common/exceptions" 22 M "github.com/sagernet/sing/common/metadata" 23 N "github.com/sagernet/sing/common/network" 24 "github.com/sagernet/sing/common/rw" 25 26 "lukechampine.com/blake3" 27 ) 28 29 var _ shadowsocks.MultiService[int] = (*MultiService[int])(nil) 30 31 type MultiService[U comparable] struct { 32 *Service 33 34 uPSK map[U][]byte 35 uPSKHash map[[aes.BlockSize]byte]U 36 uCipher map[U]cipher.Block 37 } 38 39 func NewMultiServiceWithPassword[U comparable](method string, password string, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (*MultiService[U], error) { 40 if password == "" { 41 return nil, ErrMissingPSK 42 } 43 iPSK, err := base64.StdEncoding.DecodeString(password) 44 if err != nil { 45 return nil, E.Cause(err, "decode psk") 46 } 47 return NewMultiService[U](method, iPSK, udpTimeout, handler, timeFunc) 48 } 49 50 func NewMultiService[U comparable](method string, iPSK []byte, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (*MultiService[U], error) { 51 switch method { 52 case "2022-blake3-aes-128-gcm": 53 case "2022-blake3-aes-256-gcm": 54 default: 55 return nil, os.ErrInvalid 56 } 57 58 ss, err := NewService(method, iPSK, udpTimeout, handler, timeFunc) 59 if err != nil { 60 return nil, err 61 } 62 63 s := &MultiService[U]{ 64 Service: ss.(*Service), 65 66 uPSK: make(map[U][]byte), 67 uPSKHash: make(map[[aes.BlockSize]byte]U), 68 } 69 return s, nil 70 } 71 72 func (s *MultiService[U]) UpdateUsers(userList []U, keyList [][]byte) error { 73 uPSK := make(map[U][]byte) 74 uPSKHash := make(map[[aes.BlockSize]byte]U) 75 uCipher := make(map[U]cipher.Block) 76 for i, user := range userList { 77 key := keyList[i] 78 if len(key) < s.keySaltLength { 79 return shadowsocks.ErrBadKey 80 } else if len(key) > s.keySaltLength { 81 key = Key(key, s.keySaltLength) 82 } 83 84 var hash [aes.BlockSize]byte 85 hash512 := blake3.Sum512(key) 86 copy(hash[:], hash512[:]) 87 88 uPSKHash[hash] = user 89 uPSK[user] = key 90 var err error 91 uCipher[user], err = s.blockConstructor(key) 92 if err != nil { 93 return err 94 } 95 } 96 97 s.uPSK = uPSK 98 s.uPSKHash = uPSKHash 99 s.uCipher = uCipher 100 return nil 101 } 102 103 func (s *MultiService[U]) UpdateUsersWithPasswords(userList []U, passwordList []string) error { 104 keyList := make([][]byte, 0, len(passwordList)) 105 for _, password := range passwordList { 106 if password == "" { 107 return shadowsocks.ErrMissingPassword 108 } 109 uPSK, err := base64.StdEncoding.DecodeString(password) 110 if err != nil { 111 return E.Cause(err, "decode psk") 112 } 113 keyList = append(keyList, uPSK) 114 } 115 return s.UpdateUsers(userList, keyList) 116 } 117 118 func (s *MultiService[U]) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { 119 err := s.NewConnection0(ctx, conn, metadata, conn, nil) 120 if err != nil { 121 err = &shadowsocks.ServerConnError{Conn: conn, Source: metadata.Source, Cause: err} 122 } 123 return err 124 } 125 126 func (s *MultiService[U]) NewConnection0(ctx context.Context, conn net.Conn, metadata M.Metadata, handshakeReader io.Reader, handshakeSuccess func()) error { 127 requestHeader := make([]byte, s.keySaltLength+aes.BlockSize+shadowaead.Overhead+RequestHeaderFixedChunkLength) 128 var ( 129 n int 130 err error 131 ) 132 if handshakeSuccess != nil { 133 n, err = io.ReadFull(handshakeReader, requestHeader) 134 } else { 135 n, err = handshakeReader.Read(requestHeader) 136 } 137 if err != nil { 138 return err 139 } else if n < len(requestHeader) { 140 return shadowaead.ErrBadHeader 141 } 142 requestSalt := requestHeader[:s.keySaltLength] 143 if !s.replayFilter.Check(requestSalt) { 144 return ErrSaltNotUnique 145 } 146 147 var _eiHeader [aes.BlockSize]byte 148 eiHeader := _eiHeader[:] 149 copy(eiHeader, requestHeader[s.keySaltLength:s.keySaltLength+aes.BlockSize]) 150 151 keyMaterial := make([]byte, s.keySaltLength*2) 152 copy(keyMaterial, s.psk) 153 copy(keyMaterial[s.keySaltLength:], requestSalt) 154 identitySubkey := buf.NewSize(s.keySaltLength) 155 identitySubkey.Extend(identitySubkey.FreeLen()) 156 blake3.DeriveKey(identitySubkey.Bytes(), "shadowsocks 2022 identity subkey", keyMaterial) 157 b, err := s.blockConstructor(identitySubkey.Bytes()) 158 identitySubkey.Release() 159 if err != nil { 160 return err 161 } 162 b.Decrypt(eiHeader, eiHeader) 163 164 var user U 165 var uPSK []byte 166 if u, loaded := s.uPSKHash[_eiHeader]; loaded { 167 user = u 168 uPSK = s.uPSK[u] 169 } else { 170 return ErrInvalidRequest 171 } 172 173 if handshakeSuccess != nil { 174 handshakeSuccess() 175 } 176 177 requestKey := SessionKey(uPSK, requestSalt, s.keySaltLength) 178 readCipher, err := s.constructor(requestKey) 179 if err != nil { 180 return err 181 } 182 reader := shadowaead.NewReader( 183 conn, 184 readCipher, 185 MaxPacketSize, 186 ) 187 188 err = reader.ReadExternalChunk(requestHeader[s.keySaltLength+aes.BlockSize:]) 189 if err != nil { 190 return err 191 } 192 193 headerType, err := rw.ReadByte(reader) 194 if err != nil { 195 return E.Cause(err, "read header") 196 } 197 198 if headerType != HeaderTypeClient { 199 return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType) 200 } 201 202 var epoch uint64 203 err = binary.Read(reader, binary.BigEndian, &epoch) 204 if err != nil { 205 return E.Cause(err, "read timestamp") 206 } 207 diff := int(math.Abs(float64(s.time().Unix() - int64(epoch)))) 208 if diff > 30 { 209 return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") 210 } 211 var length uint16 212 err = binary.Read(reader, binary.BigEndian, &length) 213 if err != nil { 214 return E.Cause(err, "read length") 215 } 216 217 err = reader.ReadWithLength(length) 218 if err != nil { 219 return err 220 } 221 222 destination, err := M.SocksaddrSerializer.ReadAddrPort(reader) 223 if err != nil { 224 return E.Cause(err, "read destination") 225 } 226 227 var paddingLen uint16 228 err = binary.Read(reader, binary.BigEndian, &paddingLen) 229 if err != nil { 230 return E.Cause(err, "read padding length") 231 } 232 233 if reader.Cached() < int(paddingLen) { 234 return ErrBadPadding 235 } else if paddingLen > 0 { 236 err = reader.Discard(int(paddingLen)) 237 if err != nil { 238 return E.Cause(err, "discard padding") 239 } 240 } else if reader.Cached() == 0 { 241 return ErrNoPadding 242 } 243 244 protocolConn := &serverConn{ 245 Service: s.Service, 246 Conn: conn, 247 uPSK: uPSK, 248 headerType: headerType, 249 requestSalt: requestSalt, 250 } 251 252 protocolConn.reader = reader 253 metadata.Protocol = "shadowsocks" 254 metadata.Destination = destination 255 return s.handler.NewConnection(auth.ContextWithUser(ctx, user), protocolConn, metadata) 256 } 257 258 func (s *MultiService[U]) WriteIsThreadUnsafe() { 259 } 260 261 func (s *MultiService[U]) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { 262 err := s.newPacket(ctx, conn, buffer, metadata) 263 if err != nil { 264 err = &shadowsocks.ServerPacketError{Source: metadata.Source, Cause: err} 265 } 266 return err 267 } 268 269 func (s *MultiService[U]) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { 270 if buffer.Len() < PacketMinimalHeaderSize { 271 return ErrPacketTooShort 272 } 273 274 packetHeader := buffer.To(aes.BlockSize) 275 s.udpBlockCipher.Decrypt(packetHeader, packetHeader) 276 277 var _eiHeader [aes.BlockSize]byte 278 eiHeader := _eiHeader[:] 279 s.udpBlockCipher.Decrypt(eiHeader, buffer.Range(aes.BlockSize, 2*aes.BlockSize)) 280 xorWords(eiHeader, eiHeader, packetHeader) 281 282 var user U 283 var uPSK []byte 284 if u, loaded := s.uPSKHash[_eiHeader]; loaded { 285 user = u 286 uPSK = s.uPSK[u] 287 } else { 288 return E.New("invalid request") 289 } 290 291 var sessionId, packetId uint64 292 err := binary.Read(buffer, binary.BigEndian, &sessionId) 293 if err != nil { 294 return err 295 } 296 err = binary.Read(buffer, binary.BigEndian, &packetId) 297 if err != nil { 298 return err 299 } 300 301 buffer.Advance(aes.BlockSize) 302 303 session, loaded := s.udpSessions.LoadOrStore(sessionId, func() *serverUDPSession { 304 return s.newUDPSession(uPSK) 305 }) 306 if !loaded { 307 session.remoteSessionId = sessionId 308 key := SessionKey(uPSK, packetHeader[:8], s.keySaltLength) 309 session.remoteCipher, err = s.constructor(key) 310 if err != nil { 311 return err 312 } 313 } 314 315 goto process 316 317 returnErr: 318 if !loaded { 319 s.udpSessions.Delete(sessionId) 320 } 321 return err 322 323 process: 324 if !session.window.Check(packetId) { 325 err = ErrPacketIdNotUnique 326 goto returnErr 327 } 328 329 if packetHeader != nil { 330 _, err = session.remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil) 331 if err != nil { 332 err = E.Cause(err, "decrypt packet") 333 goto returnErr 334 } 335 buffer.Truncate(buffer.Len() - shadowaead.Overhead) 336 } 337 338 session.window.Add(packetId) 339 340 var headerType byte 341 headerType, err = buffer.ReadByte() 342 if err != nil { 343 err = E.Cause(err, "decrypt packet") 344 goto returnErr 345 } 346 if headerType != HeaderTypeClient { 347 err = E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType) 348 goto returnErr 349 } 350 351 var epoch uint64 352 err = binary.Read(buffer, binary.BigEndian, &epoch) 353 if err != nil { 354 goto returnErr 355 } 356 diff := int(math.Abs(float64(s.time().Unix() - int64(epoch)))) 357 if diff > 30 { 358 err = E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") 359 goto returnErr 360 } 361 362 var paddingLen uint16 363 err = binary.Read(buffer, binary.BigEndian, &paddingLen) 364 if err != nil { 365 err = E.Cause(err, "read padding length") 366 goto returnErr 367 } 368 buffer.Advance(int(paddingLen)) 369 370 destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer) 371 if err != nil { 372 goto returnErr 373 } 374 375 metadata.Protocol = "shadowsocks" 376 metadata.Destination = destination 377 s.udpNat.NewContextPacket(ctx, sessionId, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) { 378 return auth.ContextWithUser(ctx, user), &serverPacketWriter{s.Service, conn, natConn, session, s.uCipher[user]} 379 }) 380 return nil 381 } 382 383 func (s *MultiService[U]) newUDPSession(uPSK []byte) *serverUDPSession { 384 session := &serverUDPSession{} 385 if s.udpCipher != nil { 386 session.rng = Blake3KeyedHash(rand.Reader) 387 common.Must(binary.Read(session.rng, binary.BigEndian, &session.sessionId)) 388 } else { 389 common.Must(binary.Read(rand.Reader, binary.BigEndian, &session.sessionId)) 390 } 391 session.packetId-- 392 sessionId := make([]byte, 8) 393 binary.BigEndian.PutUint64(sessionId, session.sessionId) 394 key := SessionKey(uPSK, sessionId, s.keySaltLength) 395 var err error 396 session.cipher, err = s.constructor(key) 397 common.Must(err) 398 return session 399 }