github.com/MicalKarl/sing-shadowsocks@v0.0.5/shadowaead_2022/service_multi.go (about) 1 package shadowaead_2022 2 3 import ( 4 "context" 5 "crypto/aes" 6 "crypto/cipher" 7 "crypto/rand" 8 "encoding/base64" 9 "encoding/binary" 10 "math" 11 "net" 12 "os" 13 "time" 14 15 shadowsocks "github.com/MicalKarl/sing-shadowsocks" 16 "github.com/MicalKarl/sing-shadowsocks/shadowaead" 17 "github.com/MicalKarl/sing-shadowsocks/ssv" 18 "github.com/sagernet/sing/common" 19 "github.com/sagernet/sing/common/auth" 20 "github.com/sagernet/sing/common/buf" 21 E "github.com/sagernet/sing/common/exceptions" 22 M "github.com/sagernet/sing/common/metadata" 23 N "github.com/sagernet/sing/common/network" 24 "github.com/sagernet/sing/common/rw" 25 26 "lukechampine.com/blake3" 27 ) 28 29 var _ shadowsocks.MultiService[int] = (*MultiService[int])(nil) 30 31 type MultiService[U comparable] struct { 32 *Service 33 34 uPSK map[U][]byte 35 uPSKHash map[[aes.BlockSize]byte]U 36 uCipher map[U]cipher.Block 37 } 38 39 func NewMultiServiceWithPassword[U comparable](method string, password string, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (*MultiService[U], error) { 40 if password == "" { 41 return nil, ErrMissingPSK 42 } 43 iPSK, err := base64.StdEncoding.DecodeString(password) 44 if err != nil { 45 return nil, E.Cause(err, "decode psk") 46 } 47 return NewMultiService[U](method, iPSK, udpTimeout, handler, timeFunc) 48 } 49 50 func NewMultiService[U comparable](method string, iPSK []byte, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (*MultiService[U], error) { 51 switch method { 52 case "2022-blake3-aes-128-gcm": 53 case "2022-blake3-aes-256-gcm": 54 default: 55 return nil, os.ErrInvalid 56 } 57 58 ss, err := NewService(method, iPSK, udpTimeout, handler, timeFunc) 59 if err != nil { 60 return nil, err 61 } 62 63 s := &MultiService[U]{ 64 Service: ss.(*Service), 65 66 uPSK: make(map[U][]byte), 67 uPSKHash: make(map[[aes.BlockSize]byte]U), 68 } 69 return s, nil 70 } 71 72 func (s *MultiService[U]) UpdateUsers(userList []U, keyList [][]byte) error { 73 uPSK := make(map[U][]byte) 74 uPSKHash := make(map[[aes.BlockSize]byte]U) 75 uCipher := make(map[U]cipher.Block) 76 for i, user := range userList { 77 key := keyList[i] 78 if len(key) < s.keySaltLength { 79 return shadowsocks.ErrBadKey 80 } else if len(key) > s.keySaltLength { 81 key = Key(key, s.keySaltLength) 82 } 83 84 var hash [aes.BlockSize]byte 85 hash512 := blake3.Sum512(key) 86 copy(hash[:], hash512[:]) 87 88 uPSKHash[hash] = user 89 uPSK[user] = key 90 var err error 91 uCipher[user], err = s.blockConstructor(key) 92 if err != nil { 93 return err 94 } 95 } 96 97 s.uPSK = uPSK 98 s.uPSKHash = uPSKHash 99 s.uCipher = uCipher 100 return nil 101 } 102 103 func (s *MultiService[U]) UpdateUsersWithPasswords(userList []U, passwordList []string) error { 104 keyList := make([][]byte, 0, len(passwordList)) 105 for _, password := range passwordList { 106 if password == "" { 107 return shadowsocks.ErrMissingPassword 108 } 109 uPSK, err := base64.StdEncoding.DecodeString(password) 110 if err != nil { 111 return E.Cause(err, "decode psk") 112 } 113 keyList = append(keyList, uPSK) 114 } 115 return s.UpdateUsers(userList, keyList) 116 } 117 118 func (s *MultiService[U]) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { 119 err := s.newConnection(ctx, conn, metadata) 120 if err != nil { 121 err = &shadowsocks.ServerConnError{Conn: conn, Source: metadata.Source, Cause: err} 122 } 123 return err 124 } 125 126 func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { 127 requestHeader := make([]byte, s.keySaltLength+aes.BlockSize+shadowaead.Overhead+RequestHeaderFixedChunkLength) 128 n, err := conn.Read(requestHeader) 129 if err != nil { 130 return err 131 } else if n < len(requestHeader) { 132 return shadowaead.ErrBadHeader 133 } 134 requestSalt := requestHeader[:s.keySaltLength] 135 if !s.replayFilter.Check(requestSalt) { 136 return ErrSaltNotUnique 137 } 138 139 var _eiHeader [aes.BlockSize]byte 140 eiHeader := _eiHeader[:] 141 copy(eiHeader, requestHeader[s.keySaltLength:s.keySaltLength+aes.BlockSize]) 142 143 keyMaterial := make([]byte, s.keySaltLength*2) 144 copy(keyMaterial, s.psk) 145 copy(keyMaterial[s.keySaltLength:], requestSalt) 146 identitySubkey := buf.NewSize(s.keySaltLength) 147 identitySubkey.Extend(identitySubkey.FreeLen()) 148 blake3.DeriveKey(identitySubkey.Bytes(), "shadowsocks 2022 identity subkey", keyMaterial) 149 b, err := s.blockConstructor(identitySubkey.Bytes()) 150 identitySubkey.Release() 151 if err != nil { 152 return err 153 } 154 b.Decrypt(eiHeader, eiHeader) 155 156 var user U 157 var uPSK []byte 158 if u, loaded := s.uPSKHash[_eiHeader]; loaded { 159 user = u 160 uPSK = s.uPSK[u] 161 } else { 162 return E.New("invalid request") 163 } 164 165 requestKey := SessionKey(uPSK, requestSalt, s.keySaltLength) 166 readCipher, err := s.constructor(requestKey) 167 if err != nil { 168 return err 169 } 170 reader := shadowaead.NewReader( 171 conn, 172 readCipher, 173 MaxPacketSize, 174 ) 175 176 err = reader.ReadExternalChunk(requestHeader[s.keySaltLength+aes.BlockSize:]) 177 if err != nil { 178 return err 179 } 180 181 headerType, err := rw.ReadByte(reader) 182 if err != nil { 183 return E.Cause(err, "read header") 184 } 185 186 if headerType != HeaderTypeClient { 187 return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType) 188 } 189 190 var epoch uint64 191 err = binary.Read(reader, binary.BigEndian, &epoch) 192 if err != nil { 193 return E.Cause(err, "read timestamp") 194 } 195 diff := int(math.Abs(float64(s.time().Unix() - int64(epoch)))) 196 if diff > 30 { 197 return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") 198 } 199 var length uint16 200 err = binary.Read(reader, binary.BigEndian, &length) 201 if err != nil { 202 return E.Cause(err, "read length") 203 } 204 205 err = reader.ReadWithLength(length) 206 if err != nil { 207 return err 208 } 209 210 destination, err := ssv.FakeSockSerializer.ReadAddrPort(reader) 211 if err != nil { 212 return E.Cause(err, "read destination") 213 } 214 215 var paddingLen uint16 216 err = binary.Read(reader, binary.BigEndian, &paddingLen) 217 if err != nil { 218 return E.Cause(err, "read padding length") 219 } 220 221 if reader.Cached() < int(paddingLen) { 222 return ErrBadPadding 223 } else if paddingLen > 0 { 224 err = reader.Discard(int(paddingLen)) 225 if err != nil { 226 return E.Cause(err, "discard padding") 227 } 228 } else if reader.Cached() == 0 { 229 return ErrNoPadding 230 } 231 232 protocolConn := &serverConn{ 233 Service: s.Service, 234 Conn: conn, 235 uPSK: uPSK, 236 headerType: headerType, 237 requestSalt: requestSalt, 238 } 239 240 protocolConn.reader = reader 241 metadata.Protocol = "shadowsocks" 242 metadata.Destination = destination 243 return s.handler.NewConnection(auth.ContextWithUser(ctx, user), protocolConn, metadata) 244 } 245 246 func (s *MultiService[U]) WriteIsThreadUnsafe() { 247 } 248 249 func (s *MultiService[U]) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { 250 err := s.newPacket(ctx, conn, buffer, metadata) 251 if err != nil { 252 err = &shadowsocks.ServerPacketError{Source: metadata.Source, Cause: err} 253 } 254 return err 255 } 256 257 func (s *MultiService[U]) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { 258 if buffer.Len() < PacketMinimalHeaderSize { 259 return ErrPacketTooShort 260 } 261 262 packetHeader := buffer.To(aes.BlockSize) 263 s.udpBlockCipher.Decrypt(packetHeader, packetHeader) 264 265 var _eiHeader [aes.BlockSize]byte 266 eiHeader := _eiHeader[:] 267 s.udpBlockCipher.Decrypt(eiHeader, buffer.Range(aes.BlockSize, 2*aes.BlockSize)) 268 xorWords(eiHeader, eiHeader, packetHeader) 269 270 var user U 271 var uPSK []byte 272 if u, loaded := s.uPSKHash[_eiHeader]; loaded { 273 user = u 274 uPSK = s.uPSK[u] 275 } else { 276 return E.New("invalid request") 277 } 278 279 var sessionId, packetId uint64 280 err := binary.Read(buffer, binary.BigEndian, &sessionId) 281 if err != nil { 282 return err 283 } 284 err = binary.Read(buffer, binary.BigEndian, &packetId) 285 if err != nil { 286 return err 287 } 288 289 buffer.Advance(aes.BlockSize) 290 291 session, loaded := s.udpSessions.LoadOrStore(sessionId, func() *serverUDPSession { 292 return s.newUDPSession(uPSK) 293 }) 294 if !loaded { 295 session.remoteSessionId = sessionId 296 key := SessionKey(uPSK, packetHeader[:8], s.keySaltLength) 297 session.remoteCipher, err = s.constructor(key) 298 if err != nil { 299 return err 300 } 301 } 302 303 goto process 304 305 returnErr: 306 if !loaded { 307 s.udpSessions.Delete(sessionId) 308 } 309 return err 310 311 process: 312 if !session.window.Check(packetId) { 313 err = ErrPacketIdNotUnique 314 goto returnErr 315 } 316 317 if packetHeader != nil { 318 _, err = session.remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil) 319 if err != nil { 320 err = E.Cause(err, "decrypt packet") 321 goto returnErr 322 } 323 buffer.Truncate(buffer.Len() - shadowaead.Overhead) 324 } 325 326 session.window.Add(packetId) 327 328 var headerType byte 329 headerType, err = buffer.ReadByte() 330 if err != nil { 331 err = E.Cause(err, "decrypt packet") 332 goto returnErr 333 } 334 if headerType != HeaderTypeClient { 335 err = E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType) 336 goto returnErr 337 } 338 339 var epoch uint64 340 err = binary.Read(buffer, binary.BigEndian, &epoch) 341 if err != nil { 342 goto returnErr 343 } 344 diff := int(math.Abs(float64(s.time().Unix() - int64(epoch)))) 345 if diff > 30 { 346 err = E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s") 347 goto returnErr 348 } 349 350 var paddingLen uint16 351 err = binary.Read(buffer, binary.BigEndian, &paddingLen) 352 if err != nil { 353 err = E.Cause(err, "read padding length") 354 goto returnErr 355 } 356 buffer.Advance(int(paddingLen)) 357 358 destination, err := ssv.FakeSockSerializer.ReadAddrPort(buffer) 359 if err != nil { 360 goto returnErr 361 } 362 363 metadata.Protocol = "shadowsocks" 364 metadata.Destination = destination 365 s.udpNat.NewContextPacket(ctx, sessionId, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) { 366 return auth.ContextWithUser(ctx, user), &serverPacketWriter{s.Service, conn, natConn, session, s.uCipher[user]} 367 }) 368 return nil 369 } 370 371 func (s *MultiService[U]) newUDPSession(uPSK []byte) *serverUDPSession { 372 session := &serverUDPSession{} 373 if s.udpCipher != nil { 374 session.rng = Blake3KeyedHash(rand.Reader) 375 common.Must(binary.Read(session.rng, binary.BigEndian, &session.sessionId)) 376 } else { 377 common.Must(binary.Read(rand.Reader, binary.BigEndian, &session.sessionId)) 378 } 379 session.packetId-- 380 sessionId := make([]byte, 8) 381 binary.BigEndian.PutUint64(sessionId, session.sessionId) 382 key := SessionKey(uPSK, sessionId, s.keySaltLength) 383 var err error 384 session.cipher, err = s.constructor(key) 385 common.Must(err) 386 return session 387 }