github.com/MicalKarl/sing-shadowsocks@v0.0.5/shadowaead_2022/service_multi.go (about)

     1  package shadowaead_2022
     2  
     3  import (
     4  	"context"
     5  	"crypto/aes"
     6  	"crypto/cipher"
     7  	"crypto/rand"
     8  	"encoding/base64"
     9  	"encoding/binary"
    10  	"math"
    11  	"net"
    12  	"os"
    13  	"time"
    14  
    15  	shadowsocks "github.com/MicalKarl/sing-shadowsocks"
    16  	"github.com/MicalKarl/sing-shadowsocks/shadowaead"
    17  	"github.com/MicalKarl/sing-shadowsocks/ssv"
    18  	"github.com/sagernet/sing/common"
    19  	"github.com/sagernet/sing/common/auth"
    20  	"github.com/sagernet/sing/common/buf"
    21  	E "github.com/sagernet/sing/common/exceptions"
    22  	M "github.com/sagernet/sing/common/metadata"
    23  	N "github.com/sagernet/sing/common/network"
    24  	"github.com/sagernet/sing/common/rw"
    25  
    26  	"lukechampine.com/blake3"
    27  )
    28  
// Compile-time check that *MultiService satisfies shadowsocks.MultiService.
var _ shadowsocks.MultiService[int] = (*MultiService[int])(nil)

// MultiService is a multi-user Shadowsocks 2022 server for the AES-GCM
// methods accepted by NewMultiService. It embeds *Service for the shared
// server state and adds per-user key tables indexed by the user identifier
// type U. UpdateUsers replaces all three tables at once.
//
// NOTE(review): UpdateUsers reassigns these map fields without any locking
// while newConnection/newPacket read them; confirm that callers serialize
// user updates with traffic handling.
type MultiService[U comparable] struct {
	*Service

	// uPSK maps a user to its PSK (shortened with Key when it was longer
	// than keySaltLength).
	uPSK map[U][]byte
	// uPSKHash maps the first aes.BlockSize bytes of blake3.Sum512(uPSK)
	// back to the user; the decrypted identity header is looked up here.
	uPSKHash map[[aes.BlockSize]byte]U
	// uCipher holds the block cipher built from each user's PSK; it is
	// handed to serverPacketWriter for that user's UDP replies.
	uCipher map[U]cipher.Block
}
    38  
    39  func NewMultiServiceWithPassword[U comparable](method string, password string, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (*MultiService[U], error) {
    40  	if password == "" {
    41  		return nil, ErrMissingPSK
    42  	}
    43  	iPSK, err := base64.StdEncoding.DecodeString(password)
    44  	if err != nil {
    45  		return nil, E.Cause(err, "decode psk")
    46  	}
    47  	return NewMultiService[U](method, iPSK, udpTimeout, handler, timeFunc)
    48  }
    49  
    50  func NewMultiService[U comparable](method string, iPSK []byte, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (*MultiService[U], error) {
    51  	switch method {
    52  	case "2022-blake3-aes-128-gcm":
    53  	case "2022-blake3-aes-256-gcm":
    54  	default:
    55  		return nil, os.ErrInvalid
    56  	}
    57  
    58  	ss, err := NewService(method, iPSK, udpTimeout, handler, timeFunc)
    59  	if err != nil {
    60  		return nil, err
    61  	}
    62  
    63  	s := &MultiService[U]{
    64  		Service: ss.(*Service),
    65  
    66  		uPSK:     make(map[U][]byte),
    67  		uPSKHash: make(map[[aes.BlockSize]byte]U),
    68  	}
    69  	return s, nil
    70  }
    71  
    72  func (s *MultiService[U]) UpdateUsers(userList []U, keyList [][]byte) error {
    73  	uPSK := make(map[U][]byte)
    74  	uPSKHash := make(map[[aes.BlockSize]byte]U)
    75  	uCipher := make(map[U]cipher.Block)
    76  	for i, user := range userList {
    77  		key := keyList[i]
    78  		if len(key) < s.keySaltLength {
    79  			return shadowsocks.ErrBadKey
    80  		} else if len(key) > s.keySaltLength {
    81  			key = Key(key, s.keySaltLength)
    82  		}
    83  
    84  		var hash [aes.BlockSize]byte
    85  		hash512 := blake3.Sum512(key)
    86  		copy(hash[:], hash512[:])
    87  
    88  		uPSKHash[hash] = user
    89  		uPSK[user] = key
    90  		var err error
    91  		uCipher[user], err = s.blockConstructor(key)
    92  		if err != nil {
    93  			return err
    94  		}
    95  	}
    96  
    97  	s.uPSK = uPSK
    98  	s.uPSKHash = uPSKHash
    99  	s.uCipher = uCipher
   100  	return nil
   101  }
   102  
   103  func (s *MultiService[U]) UpdateUsersWithPasswords(userList []U, passwordList []string) error {
   104  	keyList := make([][]byte, 0, len(passwordList))
   105  	for _, password := range passwordList {
   106  		if password == "" {
   107  			return shadowsocks.ErrMissingPassword
   108  		}
   109  		uPSK, err := base64.StdEncoding.DecodeString(password)
   110  		if err != nil {
   111  			return E.Cause(err, "decode psk")
   112  		}
   113  		keyList = append(keyList, uPSK)
   114  	}
   115  	return s.UpdateUsers(userList, keyList)
   116  }
   117  
   118  func (s *MultiService[U]) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
   119  	err := s.newConnection(ctx, conn, metadata)
   120  	if err != nil {
   121  		err = &shadowsocks.ServerConnError{Conn: conn, Source: metadata.Source, Cause: err}
   122  	}
   123  	return err
   124  }
   125  
   126  func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
   127  	requestHeader := make([]byte, s.keySaltLength+aes.BlockSize+shadowaead.Overhead+RequestHeaderFixedChunkLength)
   128  	n, err := conn.Read(requestHeader)
   129  	if err != nil {
   130  		return err
   131  	} else if n < len(requestHeader) {
   132  		return shadowaead.ErrBadHeader
   133  	}
   134  	requestSalt := requestHeader[:s.keySaltLength]
   135  	if !s.replayFilter.Check(requestSalt) {
   136  		return ErrSaltNotUnique
   137  	}
   138  
   139  	var _eiHeader [aes.BlockSize]byte
   140  	eiHeader := _eiHeader[:]
   141  	copy(eiHeader, requestHeader[s.keySaltLength:s.keySaltLength+aes.BlockSize])
   142  
   143  	keyMaterial := make([]byte, s.keySaltLength*2)
   144  	copy(keyMaterial, s.psk)
   145  	copy(keyMaterial[s.keySaltLength:], requestSalt)
   146  	identitySubkey := buf.NewSize(s.keySaltLength)
   147  	identitySubkey.Extend(identitySubkey.FreeLen())
   148  	blake3.DeriveKey(identitySubkey.Bytes(), "shadowsocks 2022 identity subkey", keyMaterial)
   149  	b, err := s.blockConstructor(identitySubkey.Bytes())
   150  	identitySubkey.Release()
   151  	if err != nil {
   152  		return err
   153  	}
   154  	b.Decrypt(eiHeader, eiHeader)
   155  
   156  	var user U
   157  	var uPSK []byte
   158  	if u, loaded := s.uPSKHash[_eiHeader]; loaded {
   159  		user = u
   160  		uPSK = s.uPSK[u]
   161  	} else {
   162  		return E.New("invalid request")
   163  	}
   164  
   165  	requestKey := SessionKey(uPSK, requestSalt, s.keySaltLength)
   166  	readCipher, err := s.constructor(requestKey)
   167  	if err != nil {
   168  		return err
   169  	}
   170  	reader := shadowaead.NewReader(
   171  		conn,
   172  		readCipher,
   173  		MaxPacketSize,
   174  	)
   175  
   176  	err = reader.ReadExternalChunk(requestHeader[s.keySaltLength+aes.BlockSize:])
   177  	if err != nil {
   178  		return err
   179  	}
   180  
   181  	headerType, err := rw.ReadByte(reader)
   182  	if err != nil {
   183  		return E.Cause(err, "read header")
   184  	}
   185  
   186  	if headerType != HeaderTypeClient {
   187  		return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType)
   188  	}
   189  
   190  	var epoch uint64
   191  	err = binary.Read(reader, binary.BigEndian, &epoch)
   192  	if err != nil {
   193  		return E.Cause(err, "read timestamp")
   194  	}
   195  	diff := int(math.Abs(float64(s.time().Unix() - int64(epoch))))
   196  	if diff > 30 {
   197  		return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
   198  	}
   199  	var length uint16
   200  	err = binary.Read(reader, binary.BigEndian, &length)
   201  	if err != nil {
   202  		return E.Cause(err, "read length")
   203  	}
   204  
   205  	err = reader.ReadWithLength(length)
   206  	if err != nil {
   207  		return err
   208  	}
   209  
   210  	destination, err := ssv.FakeSockSerializer.ReadAddrPort(reader)
   211  	if err != nil {
   212  		return E.Cause(err, "read destination")
   213  	}
   214  
   215  	var paddingLen uint16
   216  	err = binary.Read(reader, binary.BigEndian, &paddingLen)
   217  	if err != nil {
   218  		return E.Cause(err, "read padding length")
   219  	}
   220  
   221  	if reader.Cached() < int(paddingLen) {
   222  		return ErrBadPadding
   223  	} else if paddingLen > 0 {
   224  		err = reader.Discard(int(paddingLen))
   225  		if err != nil {
   226  			return E.Cause(err, "discard padding")
   227  		}
   228  	} else if reader.Cached() == 0 {
   229  		return ErrNoPadding
   230  	}
   231  
   232  	protocolConn := &serverConn{
   233  		Service:     s.Service,
   234  		Conn:        conn,
   235  		uPSK:        uPSK,
   236  		headerType:  headerType,
   237  		requestSalt: requestSalt,
   238  	}
   239  
   240  	protocolConn.reader = reader
   241  	metadata.Protocol = "shadowsocks"
   242  	metadata.Destination = destination
   243  	return s.handler.NewConnection(auth.ContextWithUser(ctx, user), protocolConn, metadata)
   244  }
   245  
   246  func (s *MultiService[U]) WriteIsThreadUnsafe() {
   247  }
   248  
   249  func (s *MultiService[U]) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
   250  	err := s.newPacket(ctx, conn, buffer, metadata)
   251  	if err != nil {
   252  		err = &shadowsocks.ServerPacketError{Source: metadata.Source, Cause: err}
   253  	}
   254  	return err
   255  }
   256  
   257  func (s *MultiService[U]) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
   258  	if buffer.Len() < PacketMinimalHeaderSize {
   259  		return ErrPacketTooShort
   260  	}
   261  
   262  	packetHeader := buffer.To(aes.BlockSize)
   263  	s.udpBlockCipher.Decrypt(packetHeader, packetHeader)
   264  
   265  	var _eiHeader [aes.BlockSize]byte
   266  	eiHeader := _eiHeader[:]
   267  	s.udpBlockCipher.Decrypt(eiHeader, buffer.Range(aes.BlockSize, 2*aes.BlockSize))
   268  	xorWords(eiHeader, eiHeader, packetHeader)
   269  
   270  	var user U
   271  	var uPSK []byte
   272  	if u, loaded := s.uPSKHash[_eiHeader]; loaded {
   273  		user = u
   274  		uPSK = s.uPSK[u]
   275  	} else {
   276  		return E.New("invalid request")
   277  	}
   278  
   279  	var sessionId, packetId uint64
   280  	err := binary.Read(buffer, binary.BigEndian, &sessionId)
   281  	if err != nil {
   282  		return err
   283  	}
   284  	err = binary.Read(buffer, binary.BigEndian, &packetId)
   285  	if err != nil {
   286  		return err
   287  	}
   288  
   289  	buffer.Advance(aes.BlockSize)
   290  
   291  	session, loaded := s.udpSessions.LoadOrStore(sessionId, func() *serverUDPSession {
   292  		return s.newUDPSession(uPSK)
   293  	})
   294  	if !loaded {
   295  		session.remoteSessionId = sessionId
   296  		key := SessionKey(uPSK, packetHeader[:8], s.keySaltLength)
   297  		session.remoteCipher, err = s.constructor(key)
   298  		if err != nil {
   299  			return err
   300  		}
   301  	}
   302  
   303  	goto process
   304  
   305  returnErr:
   306  	if !loaded {
   307  		s.udpSessions.Delete(sessionId)
   308  	}
   309  	return err
   310  
   311  process:
   312  	if !session.window.Check(packetId) {
   313  		err = ErrPacketIdNotUnique
   314  		goto returnErr
   315  	}
   316  
   317  	if packetHeader != nil {
   318  		_, err = session.remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil)
   319  		if err != nil {
   320  			err = E.Cause(err, "decrypt packet")
   321  			goto returnErr
   322  		}
   323  		buffer.Truncate(buffer.Len() - shadowaead.Overhead)
   324  	}
   325  
   326  	session.window.Add(packetId)
   327  
   328  	var headerType byte
   329  	headerType, err = buffer.ReadByte()
   330  	if err != nil {
   331  		err = E.Cause(err, "decrypt packet")
   332  		goto returnErr
   333  	}
   334  	if headerType != HeaderTypeClient {
   335  		err = E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType)
   336  		goto returnErr
   337  	}
   338  
   339  	var epoch uint64
   340  	err = binary.Read(buffer, binary.BigEndian, &epoch)
   341  	if err != nil {
   342  		goto returnErr
   343  	}
   344  	diff := int(math.Abs(float64(s.time().Unix() - int64(epoch))))
   345  	if diff > 30 {
   346  		err = E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
   347  		goto returnErr
   348  	}
   349  
   350  	var paddingLen uint16
   351  	err = binary.Read(buffer, binary.BigEndian, &paddingLen)
   352  	if err != nil {
   353  		err = E.Cause(err, "read padding length")
   354  		goto returnErr
   355  	}
   356  	buffer.Advance(int(paddingLen))
   357  
   358  	destination, err := ssv.FakeSockSerializer.ReadAddrPort(buffer)
   359  	if err != nil {
   360  		goto returnErr
   361  	}
   362  
   363  	metadata.Protocol = "shadowsocks"
   364  	metadata.Destination = destination
   365  	s.udpNat.NewContextPacket(ctx, sessionId, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) {
   366  		return auth.ContextWithUser(ctx, user), &serverPacketWriter{s.Service, conn, natConn, session, s.uCipher[user]}
   367  	})
   368  	return nil
   369  }
   370  
   371  func (s *MultiService[U]) newUDPSession(uPSK []byte) *serverUDPSession {
   372  	session := &serverUDPSession{}
   373  	if s.udpCipher != nil {
   374  		session.rng = Blake3KeyedHash(rand.Reader)
   375  		common.Must(binary.Read(session.rng, binary.BigEndian, &session.sessionId))
   376  	} else {
   377  		common.Must(binary.Read(rand.Reader, binary.BigEndian, &session.sessionId))
   378  	}
   379  	session.packetId--
   380  	sessionId := make([]byte, 8)
   381  	binary.BigEndian.PutUint64(sessionId, session.sessionId)
   382  	key := SessionKey(uPSK, sessionId, s.keySaltLength)
   383  	var err error
   384  	session.cipher, err = s.constructor(key)
   385  	common.Must(err)
   386  	return session
   387  }