github.com/MerlinKodo/sing-shadowsocks@v0.2.6/shadowaead_2022/service_multi.go (about) 1 package shadowaead_2022 2 3 import ( 4 "context" 5 "crypto/aes" 6 "crypto/cipher" 7 "crypto/rand" 8 "encoding/base64" 9 "encoding/binary" 10 "math" 11 "net" 12 "os" 13 "time" 14 15 shadowsocks "github.com/MerlinKodo/sing-shadowsocks" 16 "github.com/MerlinKodo/sing-shadowsocks/shadowaead" 17 "github.com/sagernet/sing/common" 18 "github.com/sagernet/sing/common/auth" 19 "github.com/sagernet/sing/common/buf" 20 E "github.com/sagernet/sing/common/exceptions" 21 M "github.com/sagernet/sing/common/metadata" 22 N "github.com/sagernet/sing/common/network" 23 "github.com/sagernet/sing/common/rw" 24 25 "lukechampine.com/blake3" 26 ) 27 28 var _ shadowsocks.MultiService[int] = (*MultiService[int])(nil) 29 30 type MultiService[U comparable] struct { 31 *Service 32 33 uPSK map[U][]byte 34 uPSKHash map[[aes.BlockSize]byte]U 35 uCipher map[U]cipher.Block 36 } 37 38 func NewMultiServiceWithPassword[U comparable](method string, password string, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (*MultiService[U], error) { 39 if password == "" { 40 return nil, ErrMissingPSK 41 } 42 iPSK, err := base64.StdEncoding.DecodeString(password) 43 if err != nil { 44 return nil, E.Cause(err, "decode psk") 45 } 46 return NewMultiService[U](method, iPSK, udpTimeout, handler, timeFunc) 47 } 48 49 func NewMultiService[U comparable](method string, iPSK []byte, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (*MultiService[U], error) { 50 switch method { 51 case "2022-blake3-aes-128-gcm": 52 case "2022-blake3-aes-256-gcm": 53 default: 54 return nil, os.ErrInvalid 55 } 56 57 ss, err := NewService(method, iPSK, udpTimeout, handler, timeFunc) 58 if err != nil { 59 return nil, err 60 } 61 62 s := &MultiService[U]{ 63 Service: ss.(*Service), 64 65 uPSK: make(map[U][]byte), 66 uPSKHash: make(map[[aes.BlockSize]byte]U), 67 } 68 return s, nil 69 } 70 71 func (s *MultiService[U]) UpdateUsers(userList []U, keyList [][]byte) error { 72 uPSK := make(map[U][]byte) 73 uPSKHash := make(map[[aes.BlockSize]byte]U) 74 uCipher := make(map[U]cipher.Block) 75 for i, user := range userList { 76 key := keyList[i] 77 if len(key) < s.keySaltLength { 78 return shadowsocks.ErrBadKey 79 } else if len(key) > s.keySaltLength { 80 key = Key(key, s.keySaltLength) 81 } 82 83 var hash [aes.BlockSize]byte 84 hash512 := blake3.Sum512(key) 85 copy(hash[:], hash512[:]) 86 87 uPSKHash[hash] = user 88 uPSK[user] = key 89 var err error 90 uCipher[user], err = s.blockConstructor(key) 91 if err != nil { 92 return err 93 } 94 } 95 96 s.uPSK = uPSK 97 s.uPSKHash = uPSKHash 98 s.uCipher = uCipher 99 return nil 100 } 101 102 func (s *MultiService[U]) UpdateUsersWithPasswords(userList []U, passwordList []string) error { 103 keyList := make([][]byte, 0, len(passwordList)) 104 for _, password := range passwordList { 105 if password == "" { 106 return shadowsocks.ErrMissingPassword 107 } 108 uPSK, err := base64.StdEncoding.DecodeString(password) 109 if err != nil { 110 return E.Cause(err, "decode psk") 111 } 112 keyList = append(keyList, uPSK) 113 } 114 return s.UpdateUsers(userList, keyList) 115 } 116 117 func (s *MultiService[U]) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { 118 err := s.newConnection(ctx, conn, metadata) 119 if err != nil { 120 err = &shadowsocks.ServerConnError{Conn: conn, Source: metadata.Source, Cause: err} 121 } 122 return err 123 } 124 125 func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { 126 requestHeader := make([]byte, s.keySaltLength+aes.BlockSize+shadowaead.Overhead+RequestHeaderFixedChunkLength) 127 n, err := conn.Read(requestHeader) 128 if err != nil { 129 return err 130 } else if n < len(requestHeader) { 131 return shadowaead.ErrBadHeader 132 } 133 requestSalt := requestHeader[:s.keySaltLength] 134 if !s.replayFilter.Check(requestSalt) { 135 return ErrSaltNotUnique 136 } 137 138 var _eiHeader [aes.BlockSize]byte 139 eiHeader := _eiHeader[:] 140 copy(eiHeader, requestHeader[s.keySaltLength:s.keySaltLength+aes.BlockSize]) 141 142 keyMaterial := make([]byte, s.keySaltLength*2) 143 copy(keyMaterial, s.psk) 144 copy(keyMaterial[s.keySaltLength:], requestSalt) 145 identitySubkey := buf.NewSize(s.keySaltLength) 146 identitySubkey.Extend(identitySubkey.FreeLen()) 147 blake3.DeriveKey(identitySubkey.Bytes(), "shadowsocks 2022 identity subkey", keyMaterial) 148 b, err := s.blockConstructor(identitySubkey.Bytes()) 149 identitySubkey.Release() 150 if err != nil { 151 return err 152 } 153 b.Decrypt(eiHeader, eiHeader) 154 155 var user U 156 var uPSK []byte 157 if u, loaded := s.uPSKHash[_eiHeader]; loaded { 158 user = u 159 uPSK = s.uPSK[u] 160 } else { 161 return E.New("invalid request") 162 } 163 164 requestKey := SessionKey(uPSK, requestSalt, s.keySaltLength) 165 readCipher, err := s.constructor(requestKey) 166 if err != nil { 167 return err 168 } 169 reader := shadowaead.NewReader( 170 conn, 171 readCipher, 172 MaxPacketSize, 173 ) 174 175 err = reader.ReadExternalChunk(requestHeader[s.keySaltLength+aes.BlockSize:]) 176 if err != nil { 177 return err 178 } 179 180 headerType, err := rw.ReadByte(reader) 181 if err != nil { 182 return E.Cause(err, "read header") 183 } 184 185 if headerType != HeaderTypeClient { 186 return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType) 187 } 188 189 var epoch uint64 190 err = binary.Read(reader, binary.BigEndian, &epoch) 191 if err != nil { 192 return E.Cause(err, "read timestamp") 193 } 194 diff := int(math.Abs(float64(s.time().Unix() - int64(epoch)))) 195 if diff > 30 { 196 return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") 197 } 198 var length uint16 199 err = binary.Read(reader, binary.BigEndian, &length) 200 if err != nil { 201 return E.Cause(err, "read length") 202 } 203 204 err = reader.ReadWithLength(length) 205 if err != nil { 206 return err 207 } 208 209 destination, err := M.SocksaddrSerializer.ReadAddrPort(reader) 210 if err != nil { 211 return E.Cause(err, "read destination") 212 } 213 214 var paddingLen uint16 215 err = binary.Read(reader, binary.BigEndian, &paddingLen) 216 if err != nil { 217 return E.Cause(err, "read padding length") 218 } 219 220 if reader.Cached() < int(paddingLen) { 221 return ErrBadPadding 222 } else if paddingLen > 0 { 223 err = reader.Discard(int(paddingLen)) 224 if err != nil { 225 return E.Cause(err, "discard padding") 226 } 227 } else if reader.Cached() == 0 { 228 return ErrNoPadding 229 } 230 231 protocolConn := &serverConn{ 232 Service: s.Service, 233 Conn: conn, 234 uPSK: uPSK, 235 headerType: headerType, 236 requestSalt: requestSalt, 237 } 238 239 protocolConn.reader = reader 240 metadata.Protocol = "shadowsocks" 241 metadata.Destination = destination 242 return s.handler.NewConnection(auth.ContextWithUser(ctx, user), protocolConn, metadata) 243 } 244 245 func (s *MultiService[U]) WriteIsThreadUnsafe() { 246 } 247 248 func (s *MultiService[U]) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { 249 err := s.newPacket(ctx, conn, buffer, metadata) 250 if err != nil { 251 err = &shadowsocks.ServerPacketError{Source: metadata.Source, Cause: err} 252 } 253 return err 254 } 255 256 func (s *MultiService[U]) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { 257 if buffer.Len() < PacketMinimalHeaderSize { 258 return ErrPacketTooShort 259 } 260 261 packetHeader := buffer.To(aes.BlockSize) 262 s.udpBlockCipher.Decrypt(packetHeader, packetHeader) 263 264 var _eiHeader [aes.BlockSize]byte 265 eiHeader := _eiHeader[:] 266 s.udpBlockCipher.Decrypt(eiHeader, buffer.Range(aes.BlockSize, 2*aes.BlockSize)) 267 xorWords(eiHeader, eiHeader, packetHeader) 268 269 var user U 270 var uPSK []byte 271 if u, loaded := s.uPSKHash[_eiHeader]; loaded { 272 user = u 273 uPSK = s.uPSK[u] 274 } else { 275 return E.New("invalid request") 276 } 277 278 var sessionId, packetId uint64 279 err := binary.Read(buffer, binary.BigEndian, &sessionId) 280 if err != nil { 281 return err 282 } 283 err = binary.Read(buffer, binary.BigEndian, &packetId) 284 if err != nil { 285 return err 286 } 287 288 buffer.Advance(aes.BlockSize) 289 290 session, loaded := s.udpSessions.LoadOrStore(sessionId, func() *serverUDPSession { 291 return s.newUDPSession(uPSK) 292 }) 293 if !loaded { 294 session.remoteSessionId = sessionId 295 key := SessionKey(uPSK, packetHeader[:8], s.keySaltLength) 296 session.remoteCipher, err = s.constructor(key) 297 if err != nil { 298 return err 299 } 300 } 301 302 goto process 303 304 returnErr: 305 if !loaded { 306 s.udpSessions.Delete(sessionId) 307 } 308 return err 309 310 process: 311 if !session.window.Check(packetId) { 312 err = ErrPacketIdNotUnique 313 goto returnErr 314 } 315 316 if packetHeader != nil { 317 _, err = session.remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil) 318 if err != nil { 319 err = E.Cause(err, "decrypt packet") 320 goto returnErr 321 } 322 buffer.Truncate(buffer.Len() - shadowaead.Overhead) 323 } 324 325 session.window.Add(packetId) 326 327 var headerType byte 328 headerType, err = buffer.ReadByte() 329 if err != nil { 330 err = E.Cause(err, "decrypt packet") 331 goto returnErr 332 } 333 if headerType != HeaderTypeClient { 334 err = E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType) 335 goto returnErr 336 } 337 338 var epoch uint64 339 err = binary.Read(buffer, binary.BigEndian, &epoch) 340 if err != nil { 341 goto returnErr 342 } 343 diff := int(math.Abs(float64(s.time().Unix() - int64(epoch)))) 344 if diff > 30 { 345 err = E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") 346 goto returnErr 347 } 348 349 var paddingLen uint16 350 err = binary.Read(buffer, binary.BigEndian, &paddingLen) 351 if err != nil { 352 err = E.Cause(err, "read padding length") 353 goto returnErr 354 } 355 buffer.Advance(int(paddingLen)) 356 357 destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer) 358 if err != nil { 359 goto returnErr 360 } 361 362 metadata.Protocol = "shadowsocks" 363 metadata.Destination = destination 364 s.udpNat.NewContextPacket(ctx, sessionId, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) { 365 return auth.ContextWithUser(ctx, user), &serverPacketWriter{s.Service, conn, natConn, session, s.uCipher[user]} 366 }) 367 return nil 368 } 369 370 func (s *MultiService[U]) newUDPSession(uPSK []byte) *serverUDPSession { 371 session := &serverUDPSession{} 372 if s.udpCipher != nil { 373 session.rng = Blake3KeyedHash(rand.Reader) 374 common.Must(binary.Read(session.rng, binary.BigEndian, &session.sessionId)) 375 } else { 376 common.Must(binary.Read(rand.Reader, binary.BigEndian, &session.sessionId)) 377 } 378 session.packetId-- 379 sessionId := make([]byte, 8) 380 binary.BigEndian.PutUint64(sessionId, session.sessionId) 381 key := SessionKey(uPSK, sessionId, s.keySaltLength) 382 var err error 383 session.cipher, err = s.constructor(key) 384 common.Must(err) 385 return session 386 }