github.com/MerlinKodo/sing-shadowsocks@v0.2.6/shadowaead_2022/service_multi.go (about)

     1  package shadowaead_2022
     2  
     3  import (
     4  	"context"
     5  	"crypto/aes"
     6  	"crypto/cipher"
     7  	"crypto/rand"
     8  	"encoding/base64"
     9  	"encoding/binary"
    10  	"math"
    11  	"net"
    12  	"os"
    13  	"time"
    14  
    15  	shadowsocks "github.com/MerlinKodo/sing-shadowsocks"
    16  	"github.com/MerlinKodo/sing-shadowsocks/shadowaead"
    17  	"github.com/sagernet/sing/common"
    18  	"github.com/sagernet/sing/common/auth"
    19  	"github.com/sagernet/sing/common/buf"
    20  	E "github.com/sagernet/sing/common/exceptions"
    21  	M "github.com/sagernet/sing/common/metadata"
    22  	N "github.com/sagernet/sing/common/network"
    23  	"github.com/sagernet/sing/common/rw"
    24  
    25  	"lukechampine.com/blake3"
    26  )
    27  
    28  var _ shadowsocks.MultiService[int] = (*MultiService[int])(nil)
    29  
    30  type MultiService[U comparable] struct {
    31  	*Service
    32  
    33  	uPSK     map[U][]byte
    34  	uPSKHash map[[aes.BlockSize]byte]U
    35  	uCipher  map[U]cipher.Block
    36  }
    37  
    38  func NewMultiServiceWithPassword[U comparable](method string, password string, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (*MultiService[U], error) {
    39  	if password == "" {
    40  		return nil, ErrMissingPSK
    41  	}
    42  	iPSK, err := base64.StdEncoding.DecodeString(password)
    43  	if err != nil {
    44  		return nil, E.Cause(err, "decode psk")
    45  	}
    46  	return NewMultiService[U](method, iPSK, udpTimeout, handler, timeFunc)
    47  }
    48  
    49  func NewMultiService[U comparable](method string, iPSK []byte, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (*MultiService[U], error) {
    50  	switch method {
    51  	case "2022-blake3-aes-128-gcm":
    52  	case "2022-blake3-aes-256-gcm":
    53  	default:
    54  		return nil, os.ErrInvalid
    55  	}
    56  
    57  	ss, err := NewService(method, iPSK, udpTimeout, handler, timeFunc)
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  
    62  	s := &MultiService[U]{
    63  		Service: ss.(*Service),
    64  
    65  		uPSK:     make(map[U][]byte),
    66  		uPSKHash: make(map[[aes.BlockSize]byte]U),
    67  	}
    68  	return s, nil
    69  }
    70  
    71  func (s *MultiService[U]) UpdateUsers(userList []U, keyList [][]byte) error {
    72  	uPSK := make(map[U][]byte)
    73  	uPSKHash := make(map[[aes.BlockSize]byte]U)
    74  	uCipher := make(map[U]cipher.Block)
    75  	for i, user := range userList {
    76  		key := keyList[i]
    77  		if len(key) < s.keySaltLength {
    78  			return shadowsocks.ErrBadKey
    79  		} else if len(key) > s.keySaltLength {
    80  			key = Key(key, s.keySaltLength)
    81  		}
    82  
    83  		var hash [aes.BlockSize]byte
    84  		hash512 := blake3.Sum512(key)
    85  		copy(hash[:], hash512[:])
    86  
    87  		uPSKHash[hash] = user
    88  		uPSK[user] = key
    89  		var err error
    90  		uCipher[user], err = s.blockConstructor(key)
    91  		if err != nil {
    92  			return err
    93  		}
    94  	}
    95  
    96  	s.uPSK = uPSK
    97  	s.uPSKHash = uPSKHash
    98  	s.uCipher = uCipher
    99  	return nil
   100  }
   101  
   102  func (s *MultiService[U]) UpdateUsersWithPasswords(userList []U, passwordList []string) error {
   103  	keyList := make([][]byte, 0, len(passwordList))
   104  	for _, password := range passwordList {
   105  		if password == "" {
   106  			return shadowsocks.ErrMissingPassword
   107  		}
   108  		uPSK, err := base64.StdEncoding.DecodeString(password)
   109  		if err != nil {
   110  			return E.Cause(err, "decode psk")
   111  		}
   112  		keyList = append(keyList, uPSK)
   113  	}
   114  	return s.UpdateUsers(userList, keyList)
   115  }
   116  
   117  func (s *MultiService[U]) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
   118  	err := s.newConnection(ctx, conn, metadata)
   119  	if err != nil {
   120  		err = &shadowsocks.ServerConnError{Conn: conn, Source: metadata.Source, Cause: err}
   121  	}
   122  	return err
   123  }
   124  
   125  func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
   126  	requestHeader := make([]byte, s.keySaltLength+aes.BlockSize+shadowaead.Overhead+RequestHeaderFixedChunkLength)
   127  	n, err := conn.Read(requestHeader)
   128  	if err != nil {
   129  		return err
   130  	} else if n < len(requestHeader) {
   131  		return shadowaead.ErrBadHeader
   132  	}
   133  	requestSalt := requestHeader[:s.keySaltLength]
   134  	if !s.replayFilter.Check(requestSalt) {
   135  		return ErrSaltNotUnique
   136  	}
   137  
   138  	var _eiHeader [aes.BlockSize]byte
   139  	eiHeader := _eiHeader[:]
   140  	copy(eiHeader, requestHeader[s.keySaltLength:s.keySaltLength+aes.BlockSize])
   141  
   142  	keyMaterial := make([]byte, s.keySaltLength*2)
   143  	copy(keyMaterial, s.psk)
   144  	copy(keyMaterial[s.keySaltLength:], requestSalt)
   145  	identitySubkey := buf.NewSize(s.keySaltLength)
   146  	identitySubkey.Extend(identitySubkey.FreeLen())
   147  	blake3.DeriveKey(identitySubkey.Bytes(), "shadowsocks 2022 identity subkey", keyMaterial)
   148  	b, err := s.blockConstructor(identitySubkey.Bytes())
   149  	identitySubkey.Release()
   150  	if err != nil {
   151  		return err
   152  	}
   153  	b.Decrypt(eiHeader, eiHeader)
   154  
   155  	var user U
   156  	var uPSK []byte
   157  	if u, loaded := s.uPSKHash[_eiHeader]; loaded {
   158  		user = u
   159  		uPSK = s.uPSK[u]
   160  	} else {
   161  		return E.New("invalid request")
   162  	}
   163  
   164  	requestKey := SessionKey(uPSK, requestSalt, s.keySaltLength)
   165  	readCipher, err := s.constructor(requestKey)
   166  	if err != nil {
   167  		return err
   168  	}
   169  	reader := shadowaead.NewReader(
   170  		conn,
   171  		readCipher,
   172  		MaxPacketSize,
   173  	)
   174  
   175  	err = reader.ReadExternalChunk(requestHeader[s.keySaltLength+aes.BlockSize:])
   176  	if err != nil {
   177  		return err
   178  	}
   179  
   180  	headerType, err := rw.ReadByte(reader)
   181  	if err != nil {
   182  		return E.Cause(err, "read header")
   183  	}
   184  
   185  	if headerType != HeaderTypeClient {
   186  		return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType)
   187  	}
   188  
   189  	var epoch uint64
   190  	err = binary.Read(reader, binary.BigEndian, &epoch)
   191  	if err != nil {
   192  		return E.Cause(err, "read timestamp")
   193  	}
   194  	diff := int(math.Abs(float64(s.time().Unix() - int64(epoch))))
   195  	if diff > 30 {
   196  		return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
   197  	}
   198  	var length uint16
   199  	err = binary.Read(reader, binary.BigEndian, &length)
   200  	if err != nil {
   201  		return E.Cause(err, "read length")
   202  	}
   203  
   204  	err = reader.ReadWithLength(length)
   205  	if err != nil {
   206  		return err
   207  	}
   208  
   209  	destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
   210  	if err != nil {
   211  		return E.Cause(err, "read destination")
   212  	}
   213  
   214  	var paddingLen uint16
   215  	err = binary.Read(reader, binary.BigEndian, &paddingLen)
   216  	if err != nil {
   217  		return E.Cause(err, "read padding length")
   218  	}
   219  
   220  	if reader.Cached() < int(paddingLen) {
   221  		return ErrBadPadding
   222  	} else if paddingLen > 0 {
   223  		err = reader.Discard(int(paddingLen))
   224  		if err != nil {
   225  			return E.Cause(err, "discard padding")
   226  		}
   227  	} else if reader.Cached() == 0 {
   228  		return ErrNoPadding
   229  	}
   230  
   231  	protocolConn := &serverConn{
   232  		Service:     s.Service,
   233  		Conn:        conn,
   234  		uPSK:        uPSK,
   235  		headerType:  headerType,
   236  		requestSalt: requestSalt,
   237  	}
   238  
   239  	protocolConn.reader = reader
   240  	metadata.Protocol = "shadowsocks"
   241  	metadata.Destination = destination
   242  	return s.handler.NewConnection(auth.ContextWithUser(ctx, user), protocolConn, metadata)
   243  }
   244  
   245  func (s *MultiService[U]) WriteIsThreadUnsafe() {
   246  }
   247  
   248  func (s *MultiService[U]) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
   249  	err := s.newPacket(ctx, conn, buffer, metadata)
   250  	if err != nil {
   251  		err = &shadowsocks.ServerPacketError{Source: metadata.Source, Cause: err}
   252  	}
   253  	return err
   254  }
   255  
   256  func (s *MultiService[U]) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
   257  	if buffer.Len() < PacketMinimalHeaderSize {
   258  		return ErrPacketTooShort
   259  	}
   260  
   261  	packetHeader := buffer.To(aes.BlockSize)
   262  	s.udpBlockCipher.Decrypt(packetHeader, packetHeader)
   263  
   264  	var _eiHeader [aes.BlockSize]byte
   265  	eiHeader := _eiHeader[:]
   266  	s.udpBlockCipher.Decrypt(eiHeader, buffer.Range(aes.BlockSize, 2*aes.BlockSize))
   267  	xorWords(eiHeader, eiHeader, packetHeader)
   268  
   269  	var user U
   270  	var uPSK []byte
   271  	if u, loaded := s.uPSKHash[_eiHeader]; loaded {
   272  		user = u
   273  		uPSK = s.uPSK[u]
   274  	} else {
   275  		return E.New("invalid request")
   276  	}
   277  
   278  	var sessionId, packetId uint64
   279  	err := binary.Read(buffer, binary.BigEndian, &sessionId)
   280  	if err != nil {
   281  		return err
   282  	}
   283  	err = binary.Read(buffer, binary.BigEndian, &packetId)
   284  	if err != nil {
   285  		return err
   286  	}
   287  
   288  	buffer.Advance(aes.BlockSize)
   289  
   290  	session, loaded := s.udpSessions.LoadOrStore(sessionId, func() *serverUDPSession {
   291  		return s.newUDPSession(uPSK)
   292  	})
   293  	if !loaded {
   294  		session.remoteSessionId = sessionId
   295  		key := SessionKey(uPSK, packetHeader[:8], s.keySaltLength)
   296  		session.remoteCipher, err = s.constructor(key)
   297  		if err != nil {
   298  			return err
   299  		}
   300  	}
   301  
   302  	goto process
   303  
   304  returnErr:
   305  	if !loaded {
   306  		s.udpSessions.Delete(sessionId)
   307  	}
   308  	return err
   309  
   310  process:
   311  	if !session.window.Check(packetId) {
   312  		err = ErrPacketIdNotUnique
   313  		goto returnErr
   314  	}
   315  
   316  	if packetHeader != nil {
   317  		_, err = session.remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil)
   318  		if err != nil {
   319  			err = E.Cause(err, "decrypt packet")
   320  			goto returnErr
   321  		}
   322  		buffer.Truncate(buffer.Len() - shadowaead.Overhead)
   323  	}
   324  
   325  	session.window.Add(packetId)
   326  
   327  	var headerType byte
   328  	headerType, err = buffer.ReadByte()
   329  	if err != nil {
   330  		err = E.Cause(err, "decrypt packet")
   331  		goto returnErr
   332  	}
   333  	if headerType != HeaderTypeClient {
   334  		err = E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType)
   335  		goto returnErr
   336  	}
   337  
   338  	var epoch uint64
   339  	err = binary.Read(buffer, binary.BigEndian, &epoch)
   340  	if err != nil {
   341  		goto returnErr
   342  	}
   343  	diff := int(math.Abs(float64(s.time().Unix() - int64(epoch))))
   344  	if diff > 30 {
   345  		err = E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
   346  		goto returnErr
   347  	}
   348  
   349  	var paddingLen uint16
   350  	err = binary.Read(buffer, binary.BigEndian, &paddingLen)
   351  	if err != nil {
   352  		err = E.Cause(err, "read padding length")
   353  		goto returnErr
   354  	}
   355  	buffer.Advance(int(paddingLen))
   356  
   357  	destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
   358  	if err != nil {
   359  		goto returnErr
   360  	}
   361  
   362  	metadata.Protocol = "shadowsocks"
   363  	metadata.Destination = destination
   364  	s.udpNat.NewContextPacket(ctx, sessionId, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) {
   365  		return auth.ContextWithUser(ctx, user), &serverPacketWriter{s.Service, conn, natConn, session, s.uCipher[user]}
   366  	})
   367  	return nil
   368  }
   369  
   370  func (s *MultiService[U]) newUDPSession(uPSK []byte) *serverUDPSession {
   371  	session := &serverUDPSession{}
   372  	if s.udpCipher != nil {
   373  		session.rng = Blake3KeyedHash(rand.Reader)
   374  		common.Must(binary.Read(session.rng, binary.BigEndian, &session.sessionId))
   375  	} else {
   376  		common.Must(binary.Read(rand.Reader, binary.BigEndian, &session.sessionId))
   377  	}
   378  	session.packetId--
   379  	sessionId := make([]byte, 8)
   380  	binary.BigEndian.PutUint64(sessionId, session.sessionId)
   381  	key := SessionKey(uPSK, sessionId, s.keySaltLength)
   382  	var err error
   383  	session.cipher, err = s.constructor(key)
   384  	common.Must(err)
   385  	return session
   386  }