github.com/sagernet/sing-shadowsocks@v0.2.6/shadowaead_2022/service_multi.go (about)

     1  package shadowaead_2022
     2  
     3  import (
     4  	"context"
     5  	"crypto/aes"
     6  	"crypto/cipher"
     7  	"crypto/rand"
     8  	"encoding/base64"
     9  	"encoding/binary"
    10  	"io"
    11  	"math"
    12  	"net"
    13  	"os"
    14  	"time"
    15  
    16  	"github.com/sagernet/sing-shadowsocks"
    17  	"github.com/sagernet/sing-shadowsocks/shadowaead"
    18  	"github.com/sagernet/sing/common"
    19  	"github.com/sagernet/sing/common/auth"
    20  	"github.com/sagernet/sing/common/buf"
    21  	E "github.com/sagernet/sing/common/exceptions"
    22  	M "github.com/sagernet/sing/common/metadata"
    23  	N "github.com/sagernet/sing/common/network"
    24  	"github.com/sagernet/sing/common/rw"
    25  
    26  	"lukechampine.com/blake3"
    27  )
    28  
    29  var _ shadowsocks.MultiService[int] = (*MultiService[int])(nil)
    30  
    31  type MultiService[U comparable] struct {
    32  	*Service
    33  
    34  	uPSK     map[U][]byte
    35  	uPSKHash map[[aes.BlockSize]byte]U
    36  	uCipher  map[U]cipher.Block
    37  }
    38  
    39  func NewMultiServiceWithPassword[U comparable](method string, password string, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (*MultiService[U], error) {
    40  	if password == "" {
    41  		return nil, ErrMissingPSK
    42  	}
    43  	iPSK, err := base64.StdEncoding.DecodeString(password)
    44  	if err != nil {
    45  		return nil, E.Cause(err, "decode psk")
    46  	}
    47  	return NewMultiService[U](method, iPSK, udpTimeout, handler, timeFunc)
    48  }
    49  
    50  func NewMultiService[U comparable](method string, iPSK []byte, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (*MultiService[U], error) {
    51  	switch method {
    52  	case "2022-blake3-aes-128-gcm":
    53  	case "2022-blake3-aes-256-gcm":
    54  	default:
    55  		return nil, os.ErrInvalid
    56  	}
    57  
    58  	ss, err := NewService(method, iPSK, udpTimeout, handler, timeFunc)
    59  	if err != nil {
    60  		return nil, err
    61  	}
    62  
    63  	s := &MultiService[U]{
    64  		Service: ss.(*Service),
    65  
    66  		uPSK:     make(map[U][]byte),
    67  		uPSKHash: make(map[[aes.BlockSize]byte]U),
    68  	}
    69  	return s, nil
    70  }
    71  
    72  func (s *MultiService[U]) UpdateUsers(userList []U, keyList [][]byte) error {
    73  	uPSK := make(map[U][]byte)
    74  	uPSKHash := make(map[[aes.BlockSize]byte]U)
    75  	uCipher := make(map[U]cipher.Block)
    76  	for i, user := range userList {
    77  		key := keyList[i]
    78  		if len(key) < s.keySaltLength {
    79  			return shadowsocks.ErrBadKey
    80  		} else if len(key) > s.keySaltLength {
    81  			key = Key(key, s.keySaltLength)
    82  		}
    83  
    84  		var hash [aes.BlockSize]byte
    85  		hash512 := blake3.Sum512(key)
    86  		copy(hash[:], hash512[:])
    87  
    88  		uPSKHash[hash] = user
    89  		uPSK[user] = key
    90  		var err error
    91  		uCipher[user], err = s.blockConstructor(key)
    92  		if err != nil {
    93  			return err
    94  		}
    95  	}
    96  
    97  	s.uPSK = uPSK
    98  	s.uPSKHash = uPSKHash
    99  	s.uCipher = uCipher
   100  	return nil
   101  }
   102  
   103  func (s *MultiService[U]) UpdateUsersWithPasswords(userList []U, passwordList []string) error {
   104  	keyList := make([][]byte, 0, len(passwordList))
   105  	for _, password := range passwordList {
   106  		if password == "" {
   107  			return shadowsocks.ErrMissingPassword
   108  		}
   109  		uPSK, err := base64.StdEncoding.DecodeString(password)
   110  		if err != nil {
   111  			return E.Cause(err, "decode psk")
   112  		}
   113  		keyList = append(keyList, uPSK)
   114  	}
   115  	return s.UpdateUsers(userList, keyList)
   116  }
   117  
   118  func (s *MultiService[U]) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
   119  	err := s.NewConnection0(ctx, conn, metadata, conn, nil)
   120  	if err != nil {
   121  		err = &shadowsocks.ServerConnError{Conn: conn, Source: metadata.Source, Cause: err}
   122  	}
   123  	return err
   124  }
   125  
   126  func (s *MultiService[U]) NewConnection0(ctx context.Context, conn net.Conn, metadata M.Metadata, handshakeReader io.Reader, handshakeSuccess func()) error {
   127  	requestHeader := make([]byte, s.keySaltLength+aes.BlockSize+shadowaead.Overhead+RequestHeaderFixedChunkLength)
   128  	var (
   129  		n   int
   130  		err error
   131  	)
   132  	if handshakeSuccess != nil {
   133  		n, err = io.ReadFull(handshakeReader, requestHeader)
   134  	} else {
   135  		n, err = handshakeReader.Read(requestHeader)
   136  	}
   137  	if err != nil {
   138  		return err
   139  	} else if n < len(requestHeader) {
   140  		return shadowaead.ErrBadHeader
   141  	}
   142  	requestSalt := requestHeader[:s.keySaltLength]
   143  	if !s.replayFilter.Check(requestSalt) {
   144  		return ErrSaltNotUnique
   145  	}
   146  
   147  	var _eiHeader [aes.BlockSize]byte
   148  	eiHeader := _eiHeader[:]
   149  	copy(eiHeader, requestHeader[s.keySaltLength:s.keySaltLength+aes.BlockSize])
   150  
   151  	keyMaterial := make([]byte, s.keySaltLength*2)
   152  	copy(keyMaterial, s.psk)
   153  	copy(keyMaterial[s.keySaltLength:], requestSalt)
   154  	identitySubkey := buf.NewSize(s.keySaltLength)
   155  	identitySubkey.Extend(identitySubkey.FreeLen())
   156  	blake3.DeriveKey(identitySubkey.Bytes(), "shadowsocks 2022 identity subkey", keyMaterial)
   157  	b, err := s.blockConstructor(identitySubkey.Bytes())
   158  	identitySubkey.Release()
   159  	if err != nil {
   160  		return err
   161  	}
   162  	b.Decrypt(eiHeader, eiHeader)
   163  
   164  	var user U
   165  	var uPSK []byte
   166  	if u, loaded := s.uPSKHash[_eiHeader]; loaded {
   167  		user = u
   168  		uPSK = s.uPSK[u]
   169  	} else {
   170  		return ErrInvalidRequest
   171  	}
   172  
   173  	if handshakeSuccess != nil {
   174  		handshakeSuccess()
   175  	}
   176  
   177  	requestKey := SessionKey(uPSK, requestSalt, s.keySaltLength)
   178  	readCipher, err := s.constructor(requestKey)
   179  	if err != nil {
   180  		return err
   181  	}
   182  	reader := shadowaead.NewReader(
   183  		conn,
   184  		readCipher,
   185  		MaxPacketSize,
   186  	)
   187  
   188  	err = reader.ReadExternalChunk(requestHeader[s.keySaltLength+aes.BlockSize:])
   189  	if err != nil {
   190  		return err
   191  	}
   192  
   193  	headerType, err := rw.ReadByte(reader)
   194  	if err != nil {
   195  		return E.Cause(err, "read header")
   196  	}
   197  
   198  	if headerType != HeaderTypeClient {
   199  		return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType)
   200  	}
   201  
   202  	var epoch uint64
   203  	err = binary.Read(reader, binary.BigEndian, &epoch)
   204  	if err != nil {
   205  		return E.Cause(err, "read timestamp")
   206  	}
   207  	diff := int(math.Abs(float64(s.time().Unix() - int64(epoch))))
   208  	if diff > 30 {
   209  		return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
   210  	}
   211  	var length uint16
   212  	err = binary.Read(reader, binary.BigEndian, &length)
   213  	if err != nil {
   214  		return E.Cause(err, "read length")
   215  	}
   216  
   217  	err = reader.ReadWithLength(length)
   218  	if err != nil {
   219  		return err
   220  	}
   221  
   222  	destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
   223  	if err != nil {
   224  		return E.Cause(err, "read destination")
   225  	}
   226  
   227  	var paddingLen uint16
   228  	err = binary.Read(reader, binary.BigEndian, &paddingLen)
   229  	if err != nil {
   230  		return E.Cause(err, "read padding length")
   231  	}
   232  
   233  	if reader.Cached() < int(paddingLen) {
   234  		return ErrBadPadding
   235  	} else if paddingLen > 0 {
   236  		err = reader.Discard(int(paddingLen))
   237  		if err != nil {
   238  			return E.Cause(err, "discard padding")
   239  		}
   240  	} else if reader.Cached() == 0 {
   241  		return ErrNoPadding
   242  	}
   243  
   244  	protocolConn := &serverConn{
   245  		Service:     s.Service,
   246  		Conn:        conn,
   247  		uPSK:        uPSK,
   248  		headerType:  headerType,
   249  		requestSalt: requestSalt,
   250  	}
   251  
   252  	protocolConn.reader = reader
   253  	metadata.Protocol = "shadowsocks"
   254  	metadata.Destination = destination
   255  	return s.handler.NewConnection(auth.ContextWithUser(ctx, user), protocolConn, metadata)
   256  }
   257  
   258  func (s *MultiService[U]) WriteIsThreadUnsafe() {
   259  }
   260  
   261  func (s *MultiService[U]) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
   262  	err := s.newPacket(ctx, conn, buffer, metadata)
   263  	if err != nil {
   264  		err = &shadowsocks.ServerPacketError{Source: metadata.Source, Cause: err}
   265  	}
   266  	return err
   267  }
   268  
   269  func (s *MultiService[U]) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
   270  	if buffer.Len() < PacketMinimalHeaderSize {
   271  		return ErrPacketTooShort
   272  	}
   273  
   274  	packetHeader := buffer.To(aes.BlockSize)
   275  	s.udpBlockCipher.Decrypt(packetHeader, packetHeader)
   276  
   277  	var _eiHeader [aes.BlockSize]byte
   278  	eiHeader := _eiHeader[:]
   279  	s.udpBlockCipher.Decrypt(eiHeader, buffer.Range(aes.BlockSize, 2*aes.BlockSize))
   280  	xorWords(eiHeader, eiHeader, packetHeader)
   281  
   282  	var user U
   283  	var uPSK []byte
   284  	if u, loaded := s.uPSKHash[_eiHeader]; loaded {
   285  		user = u
   286  		uPSK = s.uPSK[u]
   287  	} else {
   288  		return E.New("invalid request")
   289  	}
   290  
   291  	var sessionId, packetId uint64
   292  	err := binary.Read(buffer, binary.BigEndian, &sessionId)
   293  	if err != nil {
   294  		return err
   295  	}
   296  	err = binary.Read(buffer, binary.BigEndian, &packetId)
   297  	if err != nil {
   298  		return err
   299  	}
   300  
   301  	buffer.Advance(aes.BlockSize)
   302  
   303  	session, loaded := s.udpSessions.LoadOrStore(sessionId, func() *serverUDPSession {
   304  		return s.newUDPSession(uPSK)
   305  	})
   306  	if !loaded {
   307  		session.remoteSessionId = sessionId
   308  		key := SessionKey(uPSK, packetHeader[:8], s.keySaltLength)
   309  		session.remoteCipher, err = s.constructor(key)
   310  		if err != nil {
   311  			return err
   312  		}
   313  	}
   314  
   315  	goto process
   316  
   317  returnErr:
   318  	if !loaded {
   319  		s.udpSessions.Delete(sessionId)
   320  	}
   321  	return err
   322  
   323  process:
   324  	if !session.window.Check(packetId) {
   325  		err = ErrPacketIdNotUnique
   326  		goto returnErr
   327  	}
   328  
   329  	if packetHeader != nil {
   330  		_, err = session.remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil)
   331  		if err != nil {
   332  			err = E.Cause(err, "decrypt packet")
   333  			goto returnErr
   334  		}
   335  		buffer.Truncate(buffer.Len() - shadowaead.Overhead)
   336  	}
   337  
   338  	session.window.Add(packetId)
   339  
   340  	var headerType byte
   341  	headerType, err = buffer.ReadByte()
   342  	if err != nil {
   343  		err = E.Cause(err, "decrypt packet")
   344  		goto returnErr
   345  	}
   346  	if headerType != HeaderTypeClient {
   347  		err = E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType)
   348  		goto returnErr
   349  	}
   350  
   351  	var epoch uint64
   352  	err = binary.Read(buffer, binary.BigEndian, &epoch)
   353  	if err != nil {
   354  		goto returnErr
   355  	}
   356  	diff := int(math.Abs(float64(s.time().Unix() - int64(epoch))))
   357  	if diff > 30 {
   358  		err = E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
   359  		goto returnErr
   360  	}
   361  
   362  	var paddingLen uint16
   363  	err = binary.Read(buffer, binary.BigEndian, &paddingLen)
   364  	if err != nil {
   365  		err = E.Cause(err, "read padding length")
   366  		goto returnErr
   367  	}
   368  	buffer.Advance(int(paddingLen))
   369  
   370  	destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
   371  	if err != nil {
   372  		goto returnErr
   373  	}
   374  
   375  	metadata.Protocol = "shadowsocks"
   376  	metadata.Destination = destination
   377  	s.udpNat.NewContextPacket(ctx, sessionId, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) {
   378  		return auth.ContextWithUser(ctx, user), &serverPacketWriter{s.Service, conn, natConn, session, s.uCipher[user]}
   379  	})
   380  	return nil
   381  }
   382  
   383  func (s *MultiService[U]) newUDPSession(uPSK []byte) *serverUDPSession {
   384  	session := &serverUDPSession{}
   385  	if s.udpCipher != nil {
   386  		session.rng = Blake3KeyedHash(rand.Reader)
   387  		common.Must(binary.Read(session.rng, binary.BigEndian, &session.sessionId))
   388  	} else {
   389  		common.Must(binary.Read(rand.Reader, binary.BigEndian, &session.sessionId))
   390  	}
   391  	session.packetId--
   392  	sessionId := make([]byte, 8)
   393  	binary.BigEndian.PutUint64(sessionId, session.sessionId)
   394  	key := SessionKey(uPSK, sessionId, s.keySaltLength)
   395  	var err error
   396  	session.cipher, err = s.constructor(key)
   397  	common.Must(err)
   398  	return session
   399  }