github.com/MicalKarl/sing-shadowsocks@v0.0.5/shadowaead_2022/protocol.go (about)

     1  package shadowaead_2022
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/aes"
     6  	"crypto/cipher"
     7  	"crypto/rand"
     8  	"crypto/sha256"
     9  	"encoding/base64"
    10  	"encoding/binary"
    11  	"io"
    12  	"math"
    13  	mRand "math/rand"
    14  	"net"
    15  	"os"
    16  	"strings"
    17  	"sync/atomic"
    18  	"time"
    19  
    20  	shadowsocks "github.com/MicalKarl/sing-shadowsocks"
    21  	"github.com/MicalKarl/sing-shadowsocks/shadowaead"
    22  	"github.com/MicalKarl/sing-shadowsocks/ssv"
    23  	"github.com/sagernet/sing/common"
    24  	"github.com/sagernet/sing/common/buf"
    25  	"github.com/sagernet/sing/common/bufio"
    26  	E "github.com/sagernet/sing/common/exceptions"
    27  	M "github.com/sagernet/sing/common/metadata"
    28  	N "github.com/sagernet/sing/common/network"
    29  	"github.com/sagernet/sing/common/random"
    30  	"github.com/sagernet/sing/common/rw"
    31  
    32  	"golang.org/x/crypto/chacha20poly1305"
    33  	"lukechampine.com/blake3"
    34  )
    35  
    36  const (
    37  	HeaderTypeClient              = 0
    38  	HeaderTypeServer              = 1
    39  	MaxPaddingLength              = 900
    40  	PacketNonceSize               = 24
    41  	MaxPacketSize                 = 65535
    42  	RequestHeaderFixedChunkLength = 1 + 8 + 2
    43  	PacketMinimalHeaderSize       = 30
    44  )
    45  
    46  var (
    47  	ErrMissingPSK            = E.New("missing psk")
    48  	ErrBadHeaderType         = E.New("bad header type")
    49  	ErrBadTimestamp          = E.New("bad timestamp")
    50  	ErrBadRequestSalt        = E.New("bad request salt")
    51  	ErrSaltNotUnique         = E.New("salt not unique")
    52  	ErrBadClientSessionId    = E.New("bad client session id")
    53  	ErrPacketIdNotUnique     = E.New("packet id not unique")
    54  	ErrTooManyServerSessions = E.New("server session changed more than once during the last minute")
    55  	ErrPacketTooShort        = E.New("packet too short")
    56  )
    57  
    58  var List = []string{
    59  	"2022-blake3-aes-128-gcm",
    60  	"2022-blake3-aes-256-gcm",
    61  	"2022-blake3-chacha20-poly1305",
    62  }
    63  
    64  func init() {
    65  	random.InitializeSeed()
    66  }
    67  
    68  func NewWithPassword(method string, password string, timeFunc func() time.Time) (shadowsocks.Method, error) {
    69  	var pskList [][]byte
    70  	if password == "" {
    71  		return nil, ErrMissingPSK
    72  	}
    73  	keyStrList := strings.Split(password, ":")
    74  	pskList = make([][]byte, len(keyStrList))
    75  	for i, keyStr := range keyStrList {
    76  		kb, err := base64.StdEncoding.DecodeString(keyStr)
    77  		if err != nil {
    78  			return nil, E.Cause(err, "decode key")
    79  		}
    80  		pskList[i] = kb
    81  	}
    82  	return New(method, pskList, timeFunc)
    83  }
    84  
    85  func New(method string, pskList [][]byte, timeFunc func() time.Time) (shadowsocks.Method, error) {
    86  	m := &Method{
    87  		name:     method,
    88  		timeFunc: timeFunc,
    89  	}
    90  
    91  	switch method {
    92  	case "2022-blake3-aes-128-gcm":
    93  		m.keySaltLength = 16
    94  		m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM)
    95  		m.blockConstructor = aes.NewCipher
    96  	case "2022-blake3-aes-256-gcm":
    97  		m.keySaltLength = 32
    98  		m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM)
    99  		m.blockConstructor = aes.NewCipher
   100  	case "2022-blake3-chacha20-poly1305":
   101  		if len(pskList) > 1 {
   102  			return nil, os.ErrInvalid
   103  		}
   104  		m.keySaltLength = 32
   105  		m.constructor = chacha20poly1305.New
   106  	}
   107  
   108  	if len(pskList) == 0 {
   109  		return nil, ErrMissingPSK
   110  	}
   111  
   112  	for i, psk := range pskList {
   113  		if len(psk) < m.keySaltLength {
   114  			return nil, shadowsocks.ErrBadKey
   115  		} else if len(psk) > m.keySaltLength {
   116  			pskList[i] = Key(psk, m.keySaltLength)
   117  		}
   118  	}
   119  
   120  	if len(pskList) > 1 {
   121  		pskHash := make([]byte, (len(pskList)-1)*aes.BlockSize)
   122  		for i, psk := range pskList {
   123  			if i == 0 {
   124  				continue
   125  			}
   126  			hash := blake3.Sum512(psk)
   127  			copy(pskHash[aes.BlockSize*(i-1):aes.BlockSize*i], hash[:aes.BlockSize])
   128  		}
   129  		m.pskHash = pskHash
   130  	}
   131  
   132  	var err error
   133  	switch method {
   134  	case "2022-blake3-aes-128-gcm", "2022-blake3-aes-256-gcm":
   135  		m.udpBlockEncryptCipher, err = aes.NewCipher(pskList[0])
   136  		if err != nil {
   137  			return nil, err
   138  		}
   139  		m.udpBlockDecryptCipher, err = aes.NewCipher(pskList[len(pskList)-1])
   140  		if err != nil {
   141  			return nil, err
   142  		}
   143  	case "2022-blake3-chacha20-poly1305":
   144  		m.udpCipher, err = chacha20poly1305.NewX(pskList[0])
   145  		if err != nil {
   146  			return nil, err
   147  		}
   148  	}
   149  
   150  	m.pskList = pskList
   151  	return m, nil
   152  }
   153  
   154  func Key(key []byte, keyLength int) []byte {
   155  	psk := sha256.Sum256(key)
   156  	return psk[:keyLength]
   157  }
   158  
   159  func SessionKey(psk []byte, salt []byte, keyLength int) []byte {
   160  	sessionKey := make([]byte, len(psk)+len(salt))
   161  	copy(sessionKey, psk)
   162  	copy(sessionKey[len(psk):], salt)
   163  	outKey := make([]byte, keyLength)
   164  	blake3.DeriveKey(outKey, "shadowsocks 2022 session subkey", sessionKey)
   165  	return outKey
   166  }
   167  
   168  func aeadCipher(block func(key []byte) (cipher.Block, error), aead func(block cipher.Block) (cipher.AEAD, error)) func(key []byte) (cipher.AEAD, error) {
   169  	return func(key []byte) (cipher.AEAD, error) {
   170  		b, err := block(key)
   171  		if err != nil {
   172  			return nil, err
   173  		}
   174  		return aead(b)
   175  	}
   176  }
   177  
   178  type Method struct {
   179  	name          string
   180  	keySaltLength int
   181  	timeFunc      func() time.Time
   182  
   183  	constructor           func(key []byte) (cipher.AEAD, error)
   184  	blockConstructor      func(key []byte) (cipher.Block, error)
   185  	udpCipher             cipher.AEAD
   186  	udpBlockEncryptCipher cipher.Block
   187  	udpBlockDecryptCipher cipher.Block
   188  	pskList               [][]byte
   189  	pskHash               []byte
   190  }
   191  
   192  func (m *Method) Name() string {
   193  	return m.name
   194  }
   195  
   196  func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) {
   197  	shadowsocksConn := &clientConn{
   198  		Method:      m,
   199  		Conn:        conn,
   200  		destination: destination,
   201  	}
   202  	return shadowsocksConn, shadowsocksConn.writeRequest(nil)
   203  }
   204  
   205  func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn {
   206  	return &clientConn{
   207  		Method:      m,
   208  		Conn:        conn,
   209  		destination: destination,
   210  	}
   211  }
   212  
   213  func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn {
   214  	return &clientPacketConn{m, conn, m.newUDPSession()}
   215  }
   216  
   217  type clientConn struct {
   218  	*Method
   219  	net.Conn
   220  	destination M.Socksaddr
   221  	requestSalt []byte
   222  	reader      *shadowaead.Reader
   223  	writer      *shadowaead.Writer
   224  }
   225  
   226  func (m *Method) time() time.Time {
   227  	if m.timeFunc != nil {
   228  		return m.timeFunc()
   229  	} else {
   230  		return time.Now()
   231  	}
   232  }
   233  
   234  func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) error {
   235  	pskLen := len(m.pskList)
   236  	if pskLen < 2 {
   237  		return nil
   238  	}
   239  	for i, psk := range m.pskList {
   240  		keyMaterial := make([]byte, m.keySaltLength*2)
   241  		copy(keyMaterial, psk)
   242  		copy(keyMaterial[m.keySaltLength:], salt)
   243  		identitySubkey := buf.NewSize(m.keySaltLength)
   244  		identitySubkey.Extend(identitySubkey.FreeLen())
   245  		blake3.DeriveKey(identitySubkey.Bytes(), "shadowsocks 2022 identity subkey", keyMaterial)
   246  
   247  		pskHash := m.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
   248  
   249  		header := request.Extend(16)
   250  		b, err := m.blockConstructor(identitySubkey.Bytes())
   251  		if err != nil {
   252  			return err
   253  		}
   254  		b.Encrypt(header, pskHash)
   255  		identitySubkey.Release()
   256  		if i == pskLen-2 {
   257  			break
   258  		}
   259  	}
   260  	return nil
   261  }
   262  
   263  func (c *clientConn) writeRequest(payload []byte) error {
   264  	salt := make([]byte, c.keySaltLength)
   265  	common.Must1(io.ReadFull(rand.Reader, salt))
   266  
   267  	key := SessionKey(c.pskList[len(c.pskList)-1], salt, c.keySaltLength)
   268  	writeCipher, err := c.constructor(key)
   269  	if err != nil {
   270  		return err
   271  	}
   272  	writer := shadowaead.NewWriter(
   273  		c.Conn,
   274  		writeCipher,
   275  		MaxPacketSize,
   276  	)
   277  
   278  	header := writer.Buffer()
   279  	header.Write(salt)
   280  
   281  	err = c.writeExtendedIdentityHeaders(header, salt)
   282  	if err != nil {
   283  		return err
   284  	}
   285  
   286  	var _fixedLengthBuffer [RequestHeaderFixedChunkLength]byte
   287  	fixedLengthBuffer := buf.With(_fixedLengthBuffer[:])
   288  	common.Must(fixedLengthBuffer.WriteByte(HeaderTypeClient))
   289  	common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint64(c.time().Unix())))
   290  	var paddingLen int
   291  	if len(payload) < MaxPaddingLength {
   292  		paddingLen = mRand.Intn(MaxPaddingLength) + 1
   293  	}
   294  	variableLengthHeaderLen := M.SocksaddrSerializer.AddrPortLen(c.destination) + 2 + paddingLen
   295  	payloadLen := len(payload)
   296  	variableLengthHeaderLen += payloadLen
   297  	common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint16(variableLengthHeaderLen)))
   298  	writer.WriteChunk(header, fixedLengthBuffer.Slice())
   299  
   300  	variableLengthBuffer := buf.NewSize(variableLengthHeaderLen)
   301  	err = ssv.FakeSockSerializer.WriteAddrPort(variableLengthBuffer, c.destination)
   302  	if err != nil {
   303  		return err
   304  	}
   305  	common.Must(binary.Write(variableLengthBuffer, binary.BigEndian, uint16(paddingLen)))
   306  	if paddingLen > 0 {
   307  		variableLengthBuffer.Extend(paddingLen)
   308  	}
   309  	if payloadLen > 0 {
   310  		common.Must1(variableLengthBuffer.Write(payload[:payloadLen]))
   311  	}
   312  	writer.WriteChunk(header, variableLengthBuffer.Slice())
   313  	variableLengthBuffer.Release()
   314  
   315  	err = writer.BufferedWriter(header.Len()).Flush()
   316  	if err != nil {
   317  		return E.Cause(err, "client handshake")
   318  	}
   319  
   320  	c.requestSalt = salt
   321  	c.writer = writer
   322  	return nil
   323  }
   324  
   325  func (c *clientConn) readResponse() error {
   326  	if c.reader != nil {
   327  		return nil
   328  	}
   329  
   330  	salt := buf.NewSize(c.keySaltLength)
   331  
   332  	_, err := salt.ReadFullFrom(c.Conn, salt.FreeLen())
   333  	if err != nil {
   334  		salt.Release()
   335  		return err
   336  	}
   337  
   338  	key := SessionKey(c.pskList[len(c.pskList)-1], salt.Bytes(), c.keySaltLength)
   339  	salt.Release()
   340  
   341  	readCipher, err := c.constructor(key)
   342  	if err != nil {
   343  		return err
   344  	}
   345  	reader := shadowaead.NewReader(
   346  		c.Conn,
   347  		readCipher,
   348  		MaxPacketSize,
   349  	)
   350  
   351  	err = reader.ReadWithLength(uint16(1 + 8 + c.keySaltLength + 2))
   352  	if err != nil {
   353  		return E.Cause(err, "read response fixed length chunk")
   354  	}
   355  
   356  	headerType, err := rw.ReadByte(reader)
   357  	if err != nil {
   358  		return err
   359  	}
   360  	if headerType != HeaderTypeServer /* && headerType != HeaderTypeServerEncrypted*/ {
   361  		return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeServer, ", got ", headerType)
   362  	}
   363  
   364  	var epoch uint64
   365  	err = binary.Read(reader, binary.BigEndian, &epoch)
   366  	if err != nil {
   367  		return err
   368  	}
   369  
   370  	diff := int(math.Abs(float64(c.time().Unix() - int64(epoch))))
   371  	if diff > 30 {
   372  		return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
   373  	}
   374  
   375  	requestSalt := buf.NewSize(c.keySaltLength)
   376  	_, err = requestSalt.ReadFullFrom(reader, requestSalt.FreeLen())
   377  	if err != nil {
   378  		return err
   379  	}
   380  
   381  	if bytes.Compare(requestSalt.Bytes(), c.requestSalt) > 0 {
   382  		return ErrBadRequestSalt
   383  	}
   384  	requestSalt.Release()
   385  	c.requestSalt = nil
   386  
   387  	var length uint16
   388  	err = binary.Read(reader, binary.BigEndian, &length)
   389  	if err != nil {
   390  		return err
   391  	}
   392  
   393  	err = reader.ReadWithLength(length)
   394  	if err != nil {
   395  		return err
   396  	}
   397  	if headerType == HeaderTypeServer {
   398  		c.reader = reader
   399  	}
   400  	return nil
   401  }
   402  
   403  func (c *clientConn) Read(p []byte) (n int, err error) {
   404  	if err = c.readResponse(); err != nil {
   405  		return
   406  	}
   407  	return c.reader.Read(p)
   408  }
   409  
   410  func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) {
   411  	if err = c.readResponse(); err != nil {
   412  		return
   413  	}
   414  	return bufio.Copy(w, c.reader)
   415  }
   416  
   417  func (c *clientConn) Write(p []byte) (n int, err error) {
   418  	if c.writer == nil {
   419  		err = c.writeRequest(p)
   420  		if err == nil {
   421  			n = len(p)
   422  		}
   423  		return
   424  	}
   425  	return c.writer.Write(p)
   426  }
   427  
   428  var _ N.VectorisedWriter = (*clientConn)(nil)
   429  
   430  func (c *clientConn) WriteVectorised(buffers []*buf.Buffer) error {
   431  	if c.writer != nil {
   432  		return c.writer.WriteVectorised(buffers)
   433  	}
   434  	err := c.writeRequest(buffers[0].Bytes())
   435  	if err != nil {
   436  		buf.ReleaseMulti(buffers)
   437  		return err
   438  	}
   439  	buffers[0].Release()
   440  	return c.writer.WriteVectorised(buffers[1:])
   441  }
   442  
   443  func (c *clientConn) NeedHandshake() bool {
   444  	return c.writer == nil
   445  }
   446  
   447  func (c *clientConn) NeedAdditionalReadDeadline() bool {
   448  	return true
   449  }
   450  
   451  func (c *clientConn) Upstream() any {
   452  	return c.Conn
   453  }
   454  
   455  func (c *clientConn) Close() error {
   456  	return common.Close(
   457  		c.Conn,
   458  		common.PtrOrNil(c.reader),
   459  		common.PtrOrNil(c.writer),
   460  	)
   461  }
   462  
   463  type clientPacketConn struct {
   464  	*Method
   465  	net.Conn
   466  	session *udpSession
   467  }
   468  
   469  func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   470  	defer buffer.Release()
   471  	var hdrLen int
   472  	if c.udpCipher != nil {
   473  		hdrLen = PacketNonceSize
   474  	}
   475  
   476  	var paddingLen int
   477  	if destination.Port == 53 && buffer.Len() < MaxPaddingLength {
   478  		paddingLen = mRand.Intn(MaxPaddingLength-buffer.Len()) + 1
   479  	}
   480  
   481  	hdrLen += 16 // packet header
   482  	pskLen := len(c.pskList)
   483  	if c.udpCipher == nil && pskLen > 1 {
   484  		hdrLen += (pskLen - 1) * aes.BlockSize
   485  	}
   486  	hdrLen += 1 // header type
   487  	hdrLen += 8 // timestamp
   488  	hdrLen += 2 // padding length
   489  	hdrLen += paddingLen
   490  	hdrLen += M.SocksaddrSerializer.AddrPortLen(destination)
   491  	header := buf.With(buffer.ExtendHeader(hdrLen))
   492  
   493  	var dataIndex int
   494  	if c.udpCipher != nil {
   495  		common.Must1(header.ReadFullFrom(c.session.rng, PacketNonceSize))
   496  		if pskLen > 1 {
   497  			panic("unsupported chacha extended header")
   498  		}
   499  		dataIndex = PacketNonceSize
   500  	} else {
   501  		dataIndex = aes.BlockSize
   502  	}
   503  
   504  	common.Must(
   505  		binary.Write(header, binary.BigEndian, c.session.sessionId),
   506  		binary.Write(header, binary.BigEndian, c.session.nextPacketId()),
   507  	)
   508  
   509  	if c.udpCipher == nil && pskLen > 1 {
   510  		for i, psk := range c.pskList {
   511  			dataIndex += aes.BlockSize
   512  			pskHash := c.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
   513  
   514  			identityHeader := header.Extend(aes.BlockSize)
   515  			xorWords(identityHeader, pskHash, header.To(aes.BlockSize))
   516  			b, err := c.blockConstructor(psk)
   517  			if err != nil {
   518  				return err
   519  			}
   520  			b.Encrypt(identityHeader, identityHeader)
   521  
   522  			if i == pskLen-2 {
   523  				break
   524  			}
   525  		}
   526  	}
   527  	common.Must(
   528  		header.WriteByte(HeaderTypeClient),
   529  		binary.Write(header, binary.BigEndian, uint64(c.time().Unix())),
   530  		binary.Write(header, binary.BigEndian, uint16(paddingLen)), // padding length
   531  	)
   532  
   533  	if paddingLen > 0 {
   534  		header.Extend(paddingLen)
   535  	}
   536  
   537  	err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
   538  	if err != nil {
   539  		return err
   540  	}
   541  	if c.udpCipher != nil {
   542  		c.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil)
   543  		buffer.Extend(shadowaead.Overhead)
   544  	} else {
   545  		packetHeader := buffer.To(aes.BlockSize)
   546  		c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil)
   547  		buffer.Extend(shadowaead.Overhead)
   548  		c.udpBlockEncryptCipher.Encrypt(packetHeader, packetHeader)
   549  	}
   550  	return common.Error(c.Write(buffer.Bytes()))
   551  }
   552  
   553  func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
   554  	n, err := c.Read(buffer.FreeBytes())
   555  	if err != nil {
   556  		return M.Socksaddr{}, err
   557  	}
   558  	buffer.Truncate(n)
   559  
   560  	var packetHeader []byte
   561  	if c.udpCipher != nil {
   562  		if buffer.Len() < PacketNonceSize+PacketMinimalHeaderSize {
   563  			return M.Socksaddr{}, ErrPacketTooShort
   564  		}
   565  		_, err = c.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
   566  		if err != nil {
   567  			return M.Socksaddr{}, E.Cause(err, "decrypt packet")
   568  		}
   569  		buffer.Advance(PacketNonceSize)
   570  		buffer.Truncate(buffer.Len() - shadowaead.Overhead)
   571  	} else {
   572  		if buffer.Len() < PacketMinimalHeaderSize {
   573  			return M.Socksaddr{}, ErrPacketTooShort
   574  		}
   575  		packetHeader = buffer.To(aes.BlockSize)
   576  		c.udpBlockDecryptCipher.Decrypt(packetHeader, packetHeader)
   577  	}
   578  
   579  	var sessionId, packetId uint64
   580  	err = binary.Read(buffer, binary.BigEndian, &sessionId)
   581  	if err != nil {
   582  		return M.Socksaddr{}, err
   583  	}
   584  	err = binary.Read(buffer, binary.BigEndian, &packetId)
   585  	if err != nil {
   586  		return M.Socksaddr{}, err
   587  	}
   588  
   589  	if sessionId == c.session.remoteSessionId {
   590  		if !c.session.window.Check(packetId) {
   591  			return M.Socksaddr{}, ErrPacketIdNotUnique
   592  		}
   593  	} else if sessionId == c.session.lastRemoteSessionId {
   594  		if !c.session.lastWindow.Check(packetId) {
   595  			return M.Socksaddr{}, ErrPacketIdNotUnique
   596  		}
   597  	}
   598  
   599  	var remoteCipher cipher.AEAD
   600  	if packetHeader != nil {
   601  		if sessionId == c.session.remoteSessionId {
   602  			remoteCipher = c.session.remoteCipher
   603  		} else if sessionId == c.session.lastRemoteSessionId {
   604  			remoteCipher = c.session.lastRemoteCipher
   605  		} else {
   606  			key := SessionKey(c.pskList[len(c.pskList)-1], packetHeader[:8], c.keySaltLength)
   607  			remoteCipher, err = c.constructor(key)
   608  			if err != nil {
   609  				return M.Socksaddr{}, err
   610  			}
   611  		}
   612  		_, err = remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil)
   613  		if err != nil {
   614  			return M.Socksaddr{}, E.Cause(err, "decrypt packet")
   615  		}
   616  		buffer.Truncate(buffer.Len() - shadowaead.Overhead)
   617  	}
   618  
   619  	var headerType byte
   620  	headerType, err = buffer.ReadByte()
   621  	if err != nil {
   622  		return M.Socksaddr{}, err
   623  	}
   624  	if headerType != HeaderTypeServer {
   625  		return M.Socksaddr{}, E.Extend(ErrBadHeaderType, "expected ", HeaderTypeServer, ", got ", headerType)
   626  	}
   627  
   628  	var epoch uint64
   629  	err = binary.Read(buffer, binary.BigEndian, &epoch)
   630  	if err != nil {
   631  		return M.Socksaddr{}, err
   632  	}
   633  
   634  	diff := int(math.Abs(float64(c.time().Unix() - int64(epoch))))
   635  	if diff > 30 {
   636  		return M.Socksaddr{}, E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
   637  	}
   638  
   639  	if sessionId == c.session.remoteSessionId {
   640  		c.session.window.Add(packetId)
   641  	} else if sessionId == c.session.lastRemoteSessionId {
   642  		c.session.lastWindow.Add(packetId)
   643  		c.session.lastRemoteSeen = c.time().Unix()
   644  	} else {
   645  		if c.session.remoteSessionId != 0 {
   646  			if c.time().Unix()-c.session.lastRemoteSeen < 60 {
   647  				return M.Socksaddr{}, ErrTooManyServerSessions
   648  			} else {
   649  				c.session.lastRemoteSessionId = c.session.remoteSessionId
   650  				c.session.lastWindow = c.session.window
   651  				c.session.lastRemoteSeen = c.time().Unix()
   652  				c.session.lastRemoteCipher = c.session.remoteCipher
   653  				c.session.window = SlidingWindow{}
   654  			}
   655  		}
   656  		c.session.remoteSessionId = sessionId
   657  		c.session.remoteCipher = remoteCipher
   658  		c.session.window.Add(packetId)
   659  	}
   660  
   661  	var clientSessionId uint64
   662  	err = binary.Read(buffer, binary.BigEndian, &clientSessionId)
   663  	if err != nil {
   664  		return M.Socksaddr{}, err
   665  	}
   666  
   667  	if clientSessionId != c.session.sessionId {
   668  		return M.Socksaddr{}, ErrBadClientSessionId
   669  	}
   670  
   671  	var paddingLen uint16
   672  	err = binary.Read(buffer, binary.BigEndian, &paddingLen)
   673  	if err != nil {
   674  		return M.Socksaddr{}, E.Cause(err, "read padding length")
   675  	}
   676  	buffer.Advance(int(paddingLen))
   677  
   678  	destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
   679  	if err != nil {
   680  		return M.Socksaddr{}, err
   681  	}
   682  	return destination, nil
   683  }
   684  
   685  func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
   686  	buffer := buf.With(p)
   687  	destination, err := c.ReadPacket(buffer)
   688  	if err != nil {
   689  		return
   690  	}
   691  	if destination.IsFqdn() {
   692  		addr = destination
   693  	} else {
   694  		addr = destination.UDPAddr()
   695  	}
   696  	n = copy(p, buffer.Bytes())
   697  	return
   698  }
   699  
   700  func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
   701  	destination := M.SocksaddrFromNet(addr)
   702  	var overHead int
   703  	if c.udpCipher != nil {
   704  		overHead = PacketNonceSize + shadowaead.Overhead
   705  	} else {
   706  		overHead = shadowaead.Overhead
   707  	}
   708  	overHead += 16 // packet header
   709  	pskLen := len(c.pskList)
   710  	if c.udpCipher == nil && pskLen > 1 {
   711  		overHead += (pskLen - 1) * aes.BlockSize
   712  	}
   713  	var paddingLen int
   714  	if destination.Port == 53 && len(p) < MaxPaddingLength {
   715  		paddingLen = mRand.Intn(MaxPaddingLength-len(p)) + 1
   716  	}
   717  	overHead += 1 // header type
   718  	overHead += 8 // timestamp
   719  	overHead += 2 // padding length
   720  	overHead += paddingLen
   721  	overHead += M.SocksaddrSerializer.AddrPortLen(destination)
   722  
   723  	buffer := buf.NewSize(overHead + len(p))
   724  	defer buffer.Release()
   725  
   726  	var dataIndex int
   727  	if c.udpCipher != nil {
   728  		common.Must1(buffer.ReadFullFrom(c.session.rng, PacketNonceSize))
   729  		if pskLen > 1 {
   730  			panic("unsupported chacha extended header")
   731  		}
   732  		dataIndex = PacketNonceSize
   733  	} else {
   734  		dataIndex = aes.BlockSize
   735  	}
   736  
   737  	common.Must(
   738  		binary.Write(buffer, binary.BigEndian, c.session.sessionId),
   739  		binary.Write(buffer, binary.BigEndian, c.session.nextPacketId()),
   740  	)
   741  
   742  	if c.udpCipher == nil && pskLen > 1 {
   743  		for i, psk := range c.pskList {
   744  			dataIndex += aes.BlockSize
   745  			pskHash := c.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
   746  
   747  			identityHeader := buffer.Extend(aes.BlockSize)
   748  			xorWords(identityHeader, pskHash, buffer.To(aes.BlockSize))
   749  			b, err := c.blockConstructor(psk)
   750  			if err != nil {
   751  				return 0, err
   752  			}
   753  			b.Encrypt(identityHeader, identityHeader)
   754  
   755  			if i == pskLen-2 {
   756  				break
   757  			}
   758  		}
   759  	}
   760  	common.Must(
   761  		buffer.WriteByte(HeaderTypeClient),
   762  		binary.Write(buffer, binary.BigEndian, uint64(c.time().Unix())),
   763  		binary.Write(buffer, binary.BigEndian, uint16(paddingLen)), // padding length
   764  	)
   765  
   766  	if paddingLen > 0 {
   767  		buffer.Extend(paddingLen)
   768  	}
   769  
   770  	err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination)
   771  	if err != nil {
   772  		return
   773  	}
   774  	common.Must1(buffer.Write(p))
   775  	if c.udpCipher != nil {
   776  		c.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil)
   777  		buffer.Extend(shadowaead.Overhead)
   778  	} else {
   779  		packetHeader := buffer.To(aes.BlockSize)
   780  		c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil)
   781  		buffer.Extend(shadowaead.Overhead)
   782  		c.udpBlockEncryptCipher.Encrypt(packetHeader, packetHeader)
   783  	}
   784  	err = common.Error(c.Write(buffer.Bytes()))
   785  	if err != nil {
   786  		return
   787  	}
   788  	return len(p), nil
   789  }
   790  
   791  func (c *clientPacketConn) FrontHeadroom() int {
   792  	var overHead int
   793  	if c.udpCipher != nil {
   794  		overHead = PacketNonceSize + shadowaead.Overhead
   795  	} else {
   796  		overHead = shadowaead.Overhead
   797  	}
   798  	overHead += 16 // packet header
   799  	pskLen := len(c.pskList)
   800  	if c.udpCipher == nil && pskLen > 1 {
   801  		overHead += (pskLen - 1) * aes.BlockSize
   802  	}
   803  	overHead += 1 // header type
   804  	overHead += 8 // timestamp
   805  	overHead += 2 // padding length
   806  	overHead += MaxPaddingLength
   807  	overHead += M.MaxSocksaddrLength
   808  	return overHead
   809  }
   810  
   811  func (c *clientPacketConn) RearHeadroom() int {
   812  	return shadowaead.Overhead
   813  }
   814  
   815  type udpSession struct {
   816  	sessionId           uint64
   817  	packetId            uint64
   818  	remoteSessionId     uint64
   819  	lastRemoteSessionId uint64
   820  	lastRemoteSeen      int64
   821  	cipher              cipher.AEAD
   822  	remoteCipher        cipher.AEAD
   823  	lastRemoteCipher    cipher.AEAD
   824  	window              SlidingWindow
   825  	lastWindow          SlidingWindow
   826  	rng                 io.Reader
   827  }
   828  
   829  func (s *udpSession) nextPacketId() uint64 {
   830  	return atomic.AddUint64(&s.packetId, 1)
   831  }
   832  
   833  func (m *Method) newUDPSession() *udpSession {
   834  	session := &udpSession{}
   835  	if m.udpCipher != nil {
   836  		session.rng = Blake3KeyedHash(rand.Reader)
   837  		common.Must(binary.Read(session.rng, binary.BigEndian, &session.sessionId))
   838  	} else {
   839  		common.Must(binary.Read(rand.Reader, binary.BigEndian, &session.sessionId))
   840  	}
   841  	session.packetId--
   842  	if m.udpCipher == nil {
   843  		sessionId := make([]byte, 8)
   844  		binary.BigEndian.PutUint64(sessionId, session.sessionId)
   845  		key := SessionKey(m.pskList[len(m.pskList)-1], sessionId, m.keySaltLength)
   846  		var err error
   847  		session.cipher, err = m.constructor(key)
   848  		if err != nil {
   849  			return nil
   850  		}
   851  	}
   852  	return session
   853  }
   854  
   855  func (c *clientPacketConn) Upstream() any {
   856  	return c.Conn
   857  }
   858  
   859  func (c *clientPacketConn) Close() error {
   860  	return common.Close(c.Conn)
   861  }
   862  
   863  func Blake3KeyedHash(reader io.Reader) io.Reader {
   864  	key := make([]byte, 32)
   865  	common.Must1(io.ReadFull(reader, key))
   866  	h := blake3.New(1024, key)
   867  	return h.XOF()
   868  }