github.com/MerlinKodo/sing-shadowsocks@v0.2.6/shadowaead_2022/protocol.go (about)

     1  package shadowaead_2022
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/aes"
     6  	"crypto/cipher"
     7  	"crypto/rand"
     8  	"crypto/sha256"
     9  	"encoding/base64"
    10  	"encoding/binary"
    11  	"io"
    12  	"math"
    13  	mRand "math/rand"
    14  	"net"
    15  	"os"
    16  	"strings"
    17  	"sync/atomic"
    18  	"time"
    19  
    20  	shadowsocks "github.com/MerlinKodo/sing-shadowsocks"
    21  	"github.com/MerlinKodo/sing-shadowsocks/shadowaead"
    22  	"github.com/sagernet/sing/common"
    23  	"github.com/sagernet/sing/common/buf"
    24  	"github.com/sagernet/sing/common/bufio"
    25  	E "github.com/sagernet/sing/common/exceptions"
    26  	M "github.com/sagernet/sing/common/metadata"
    27  	N "github.com/sagernet/sing/common/network"
    28  	"github.com/sagernet/sing/common/random"
    29  	"github.com/sagernet/sing/common/rw"
    30  
    31  	"golang.org/x/crypto/chacha20poly1305"
    32  	"lukechampine.com/blake3"
    33  )
    34  
const (
	// HeaderTypeClient is the header type byte written into client requests.
	HeaderTypeClient = 0
	// HeaderTypeServer is the header type byte required in server responses.
	HeaderTypeServer = 1
	// MaxPaddingLength bounds the random padding added to short TCP requests
	// and to UDP packets destined for port 53.
	MaxPaddingLength = 900
	// PacketNonceSize is the length of the nonce prefixed to each UDP packet
	// when the XChaCha20-Poly1305 UDP cipher is in use.
	PacketNonceSize = 24
	// MaxPacketSize is the maximum chunk size given to the shadowaead
	// reader and writer.
	MaxPacketSize = 65535
	// RequestHeaderFixedChunkLength is the size of the fixed request chunk:
	// header type (1) + timestamp (8) + variable-length header length (2).
	RequestHeaderFixedChunkLength = 1 + 8 + 2
	// PacketMinimalHeaderSize is the shortest UDP packet (excluding any
	// nonce) that ReadPacket will attempt to decrypt.
	PacketMinimalHeaderSize = 30
)
    44  
// Errors reported while configuring a method or validating traffic.
var (
	// ErrMissingPSK is returned when no pre-shared key is supplied.
	ErrMissingPSK = E.New("missing psk")
	// ErrBadHeaderType is returned when a decrypted header has an unexpected type byte.
	ErrBadHeaderType = E.New("bad header type")
	// ErrBadTimestamp is returned when a header timestamp differs from the local clock by more than 30 seconds.
	ErrBadTimestamp = E.New("bad timestamp")
	// ErrBadRequestSalt is returned when a TCP response does not echo the client's request salt.
	ErrBadRequestSalt = E.New("bad request salt")
	// ErrSaltNotUnique reports a reused salt (not referenced in this part of the file).
	ErrSaltNotUnique = E.New("salt not unique")
	// ErrBadClientSessionId is returned when a UDP response names a different client session.
	ErrBadClientSessionId = E.New("bad client session id")
	// ErrPacketIdNotUnique is returned when a UDP packet id was already seen in its session window.
	ErrPacketIdNotUnique = E.New("packet id not unique")
	// ErrTooManyServerSessions is returned when the server session id changes again within 60 seconds.
	ErrTooManyServerSessions = E.New("server session changed more than once during the last minute")
	// ErrPacketTooShort is returned when a UDP packet is too short to contain the required headers.
	ErrPacketTooShort = E.New("packet too short")
)
    56  
// List holds the method names that New accepts.
var List = []string{
	"2022-blake3-aes-128-gcm",
	"2022-blake3-aes-256-gcm",
	"2022-blake3-chacha20-poly1305",
}
    62  
// init calls random.InitializeSeed at package load.
// NOTE(review): presumably this seeds the global math/rand source (mRand) that
// is used for padding lengths below; confirm against the sing random package.
func init() {
	random.InitializeSeed()
}
    66  
    67  func NewWithPassword(method string, password string, timeFunc func() time.Time) (shadowsocks.Method, error) {
    68  	var pskList [][]byte
    69  	if password == "" {
    70  		return nil, ErrMissingPSK
    71  	}
    72  	keyStrList := strings.Split(password, ":")
    73  	pskList = make([][]byte, len(keyStrList))
    74  	for i, keyStr := range keyStrList {
    75  		kb, err := base64.StdEncoding.DecodeString(keyStr)
    76  		if err != nil {
    77  			return nil, E.Cause(err, "decode key")
    78  		}
    79  		pskList[i] = kb
    80  	}
    81  	return New(method, pskList, timeFunc)
    82  }
    83  
    84  func New(method string, pskList [][]byte, timeFunc func() time.Time) (shadowsocks.Method, error) {
    85  	m := &Method{
    86  		name:     method,
    87  		timeFunc: timeFunc,
    88  	}
    89  
    90  	switch method {
    91  	case "2022-blake3-aes-128-gcm":
    92  		m.keySaltLength = 16
    93  		m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM)
    94  		m.blockConstructor = aes.NewCipher
    95  	case "2022-blake3-aes-256-gcm":
    96  		m.keySaltLength = 32
    97  		m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM)
    98  		m.blockConstructor = aes.NewCipher
    99  	case "2022-blake3-chacha20-poly1305":
   100  		if len(pskList) > 1 {
   101  			return nil, os.ErrInvalid
   102  		}
   103  		m.keySaltLength = 32
   104  		m.constructor = chacha20poly1305.New
   105  	}
   106  
   107  	if len(pskList) == 0 {
   108  		return nil, ErrMissingPSK
   109  	}
   110  
   111  	for i, psk := range pskList {
   112  		if len(psk) < m.keySaltLength {
   113  			return nil, shadowsocks.ErrBadKey
   114  		} else if len(psk) > m.keySaltLength {
   115  			pskList[i] = Key(psk, m.keySaltLength)
   116  		}
   117  	}
   118  
   119  	if len(pskList) > 1 {
   120  		pskHash := make([]byte, (len(pskList)-1)*aes.BlockSize)
   121  		for i, psk := range pskList {
   122  			if i == 0 {
   123  				continue
   124  			}
   125  			hash := blake3.Sum512(psk)
   126  			copy(pskHash[aes.BlockSize*(i-1):aes.BlockSize*i], hash[:aes.BlockSize])
   127  		}
   128  		m.pskHash = pskHash
   129  	}
   130  
   131  	var err error
   132  	switch method {
   133  	case "2022-blake3-aes-128-gcm", "2022-blake3-aes-256-gcm":
   134  		m.udpBlockEncryptCipher, err = aes.NewCipher(pskList[0])
   135  		if err != nil {
   136  			return nil, err
   137  		}
   138  		m.udpBlockDecryptCipher, err = aes.NewCipher(pskList[len(pskList)-1])
   139  		if err != nil {
   140  			return nil, err
   141  		}
   142  	case "2022-blake3-chacha20-poly1305":
   143  		m.udpCipher, err = chacha20poly1305.NewX(pskList[0])
   144  		if err != nil {
   145  			return nil, err
   146  		}
   147  	}
   148  
   149  	m.pskList = pskList
   150  	return m, nil
   151  }
   152  
   153  func Key(key []byte, keyLength int) []byte {
   154  	psk := sha256.Sum256(key)
   155  	return psk[:keyLength]
   156  }
   157  
   158  func SessionKey(psk []byte, salt []byte, keyLength int) []byte {
   159  	sessionKey := make([]byte, len(psk)+len(salt))
   160  	copy(sessionKey, psk)
   161  	copy(sessionKey[len(psk):], salt)
   162  	outKey := make([]byte, keyLength)
   163  	blake3.DeriveKey(outKey, "shadowsocks 2022 session subkey", sessionKey)
   164  	return outKey
   165  }
   166  
   167  func aeadCipher(block func(key []byte) (cipher.Block, error), aead func(block cipher.Block) (cipher.AEAD, error)) func(key []byte) (cipher.AEAD, error) {
   168  	return func(key []byte) (cipher.AEAD, error) {
   169  		b, err := block(key)
   170  		if err != nil {
   171  			return nil, err
   172  		}
   173  		return aead(b)
   174  	}
   175  }
   176  
// Method is a configured Shadowsocks 2022 method, created by New.
type Method struct {
	name          string // method name passed to New
	keySaltLength int    // key and salt length in bytes (16 or 32, set by New)
	timeFunc      func() time.Time // optional clock; nil means time.Now (see time)

	// constructor builds the per-session AEAD from a derived session key.
	constructor func(key []byte) (cipher.AEAD, error)
	// blockConstructor builds the block cipher used for identity headers
	// (set only for the AES methods).
	blockConstructor func(key []byte) (cipher.Block, error)
	// udpCipher is the XChaCha20-Poly1305 AEAD keyed with the first PSK
	// (chacha method only).
	udpCipher cipher.AEAD
	// udpBlockEncryptCipher is AES keyed with the first PSK; it encrypts
	// outgoing UDP packet headers.
	udpBlockEncryptCipher cipher.Block
	// udpBlockDecryptCipher is AES keyed with the last PSK; it decrypts
	// incoming UDP packet headers.
	udpBlockDecryptCipher cipher.Block
	// pskList holds the PSKs normalized to keySaltLength. The last entry
	// derives session keys.
	pskList [][]byte
	// pskHash concatenates the first 16 bytes of BLAKE3-512 of pskList[1:].
	pskHash []byte
}
   190  
// Name returns the method name this Method was created with.
func (m *Method) Name() string {
	return m.name
}
   194  
   195  func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) {
   196  	shadowsocksConn := &clientConn{
   197  		Method:      m,
   198  		Conn:        conn,
   199  		destination: destination,
   200  	}
   201  	return shadowsocksConn, shadowsocksConn.writeRequest(nil)
   202  }
   203  
   204  func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn {
   205  	return &clientConn{
   206  		Method:      m,
   207  		Conn:        conn,
   208  		destination: destination,
   209  	}
   210  }
   211  
   212  func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn {
   213  	return &clientPacketConn{m, conn, m.newUDPSession()}
   214  }
   215  
// clientConn is the client side of a Shadowsocks 2022 TCP stream.
type clientConn struct {
	*Method
	net.Conn
	destination M.Socksaddr // target address written into the request header
	requestSalt []byte      // salt sent in the request; checked against the response, then cleared
	reader      *shadowaead.Reader // set once the response header has been read
	writer      *shadowaead.Writer // set once the request header has been flushed
}
   224  
   225  func (m *Method) time() time.Time {
   226  	if m.timeFunc != nil {
   227  		return m.timeFunc()
   228  	} else {
   229  		return time.Now()
   230  	}
   231  }
   232  
   233  func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) error {
   234  	pskLen := len(m.pskList)
   235  	if pskLen < 2 {
   236  		return nil
   237  	}
   238  	for i, psk := range m.pskList {
   239  		keyMaterial := make([]byte, m.keySaltLength*2)
   240  		copy(keyMaterial, psk)
   241  		copy(keyMaterial[m.keySaltLength:], salt)
   242  		identitySubkey := buf.NewSize(m.keySaltLength)
   243  		identitySubkey.Extend(identitySubkey.FreeLen())
   244  		blake3.DeriveKey(identitySubkey.Bytes(), "shadowsocks 2022 identity subkey", keyMaterial)
   245  
   246  		pskHash := m.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
   247  
   248  		header := request.Extend(16)
   249  		b, err := m.blockConstructor(identitySubkey.Bytes())
   250  		if err != nil {
   251  			return err
   252  		}
   253  		b.Encrypt(header, pskHash)
   254  		identitySubkey.Release()
   255  		if i == pskLen-2 {
   256  			break
   257  		}
   258  	}
   259  	return nil
   260  }
   261  
   262  func (c *clientConn) writeRequest(payload []byte) error {
   263  	salt := make([]byte, c.keySaltLength)
   264  	common.Must1(io.ReadFull(rand.Reader, salt))
   265  
   266  	key := SessionKey(c.pskList[len(c.pskList)-1], salt, c.keySaltLength)
   267  	writeCipher, err := c.constructor(key)
   268  	if err != nil {
   269  		return err
   270  	}
   271  	writer := shadowaead.NewWriter(
   272  		c.Conn,
   273  		writeCipher,
   274  		MaxPacketSize,
   275  	)
   276  
   277  	header := writer.Buffer()
   278  	header.Write(salt)
   279  
   280  	err = c.writeExtendedIdentityHeaders(header, salt)
   281  	if err != nil {
   282  		return err
   283  	}
   284  
   285  	var _fixedLengthBuffer [RequestHeaderFixedChunkLength]byte
   286  	fixedLengthBuffer := buf.With(_fixedLengthBuffer[:])
   287  	common.Must(fixedLengthBuffer.WriteByte(HeaderTypeClient))
   288  	common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint64(c.time().Unix())))
   289  	var paddingLen int
   290  	if len(payload) < MaxPaddingLength {
   291  		paddingLen = mRand.Intn(MaxPaddingLength) + 1
   292  	}
   293  	variableLengthHeaderLen := M.SocksaddrSerializer.AddrPortLen(c.destination) + 2 + paddingLen
   294  	payloadLen := len(payload)
   295  	variableLengthHeaderLen += payloadLen
   296  	common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint16(variableLengthHeaderLen)))
   297  	writer.WriteChunk(header, fixedLengthBuffer.Slice())
   298  
   299  	variableLengthBuffer := buf.NewSize(variableLengthHeaderLen)
   300  	err = M.SocksaddrSerializer.WriteAddrPort(variableLengthBuffer, c.destination)
   301  	if err != nil {
   302  		return err
   303  	}
   304  	common.Must(binary.Write(variableLengthBuffer, binary.BigEndian, uint16(paddingLen)))
   305  	if paddingLen > 0 {
   306  		variableLengthBuffer.Extend(paddingLen)
   307  	}
   308  	if payloadLen > 0 {
   309  		common.Must1(variableLengthBuffer.Write(payload[:payloadLen]))
   310  	}
   311  	writer.WriteChunk(header, variableLengthBuffer.Slice())
   312  	variableLengthBuffer.Release()
   313  
   314  	err = writer.BufferedWriter(header.Len()).Flush()
   315  	if err != nil {
   316  		return E.Cause(err, "client handshake")
   317  	}
   318  
   319  	c.requestSalt = salt
   320  	c.writer = writer
   321  	return nil
   322  }
   323  
   324  func (c *clientConn) readResponse() error {
   325  	if c.reader != nil {
   326  		return nil
   327  	}
   328  
   329  	salt := buf.NewSize(c.keySaltLength)
   330  
   331  	_, err := salt.ReadFullFrom(c.Conn, salt.FreeLen())
   332  	if err != nil {
   333  		salt.Release()
   334  		return err
   335  	}
   336  
   337  	key := SessionKey(c.pskList[len(c.pskList)-1], salt.Bytes(), c.keySaltLength)
   338  	salt.Release()
   339  
   340  	readCipher, err := c.constructor(key)
   341  	if err != nil {
   342  		return err
   343  	}
   344  	reader := shadowaead.NewReader(
   345  		c.Conn,
   346  		readCipher,
   347  		MaxPacketSize,
   348  	)
   349  
   350  	err = reader.ReadWithLength(uint16(1 + 8 + c.keySaltLength + 2))
   351  	if err != nil {
   352  		return E.Cause(err, "read response fixed length chunk")
   353  	}
   354  
   355  	headerType, err := rw.ReadByte(reader)
   356  	if err != nil {
   357  		return err
   358  	}
   359  	if headerType != HeaderTypeServer /* && headerType != HeaderTypeServerEncrypted*/ {
   360  		return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeServer, ", got ", headerType)
   361  	}
   362  
   363  	var epoch uint64
   364  	err = binary.Read(reader, binary.BigEndian, &epoch)
   365  	if err != nil {
   366  		return err
   367  	}
   368  
   369  	diff := int(math.Abs(float64(c.time().Unix() - int64(epoch))))
   370  	if diff > 30 {
   371  		return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
   372  	}
   373  
   374  	requestSalt := buf.NewSize(c.keySaltLength)
   375  	_, err = requestSalt.ReadFullFrom(reader, requestSalt.FreeLen())
   376  	if err != nil {
   377  		return err
   378  	}
   379  
   380  	if bytes.Compare(requestSalt.Bytes(), c.requestSalt) > 0 {
   381  		return ErrBadRequestSalt
   382  	}
   383  	requestSalt.Release()
   384  	c.requestSalt = nil
   385  
   386  	var length uint16
   387  	err = binary.Read(reader, binary.BigEndian, &length)
   388  	if err != nil {
   389  		return err
   390  	}
   391  
   392  	err = reader.ReadWithLength(length)
   393  	if err != nil {
   394  		return err
   395  	}
   396  	if headerType == HeaderTypeServer {
   397  		c.reader = reader
   398  	}
   399  	return nil
   400  }
   401  
   402  func (c *clientConn) Read(p []byte) (n int, err error) {
   403  	if err = c.readResponse(); err != nil {
   404  		return
   405  	}
   406  	return c.reader.Read(p)
   407  }
   408  
   409  func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) {
   410  	if err = c.readResponse(); err != nil {
   411  		return
   412  	}
   413  	return bufio.Copy(w, c.reader)
   414  }
   415  
   416  func (c *clientConn) Write(p []byte) (n int, err error) {
   417  	if c.writer == nil {
   418  		err = c.writeRequest(p)
   419  		if err == nil {
   420  			n = len(p)
   421  		}
   422  		return
   423  	}
   424  	return c.writer.Write(p)
   425  }
   426  
   427  var _ N.VectorisedWriter = (*clientConn)(nil)
   428  
   429  func (c *clientConn) WriteVectorised(buffers []*buf.Buffer) error {
   430  	if c.writer != nil {
   431  		return c.writer.WriteVectorised(buffers)
   432  	}
   433  	err := c.writeRequest(buffers[0].Bytes())
   434  	if err != nil {
   435  		buf.ReleaseMulti(buffers)
   436  		return err
   437  	}
   438  	buffers[0].Release()
   439  	return c.writer.WriteVectorised(buffers[1:])
   440  }
   441  
// NeedHandshake reports whether the request header has not been sent yet.
func (c *clientConn) NeedHandshake() bool {
	return c.writer == nil
}
   445  
// NeedAdditionalReadDeadline always returns true for this connection type.
// NOTE(review): presumably this signals to sing's copy helpers that a read
// deadline must be managed externally; confirm against the N package.
func (c *clientConn) NeedAdditionalReadDeadline() bool {
	return true
}
   449  
// Upstream returns the wrapped underlying connection.
func (c *clientConn) Upstream() any {
	return c.Conn
}
   453  
// Close closes the underlying connection together with the stream reader and
// writer. PtrOrNil is presumably used so that a reader or writer that was
// never created is skipped rather than closed through a nil pointer.
func (c *clientConn) Close() error {
	return common.Close(
		c.Conn,
		common.PtrOrNil(c.reader),
		common.PtrOrNil(c.writer),
	)
}
   461  
// clientPacketConn is the client side of a Shadowsocks 2022 UDP association,
// carried over an underlying packet-oriented net.Conn.
type clientPacketConn struct {
	*Method
	net.Conn
	session *udpSession // per-association ids, ciphers and replay windows
}
   467  
// WritePacket encrypts the payload in buffer and sends it to destination as
// one UDP packet. The header is written in place into buffer's front
// headroom, so buffer must have been allocated with at least FrontHeadroom
// bytes ahead of the payload. buffer is always released.
//
// Layout before encryption:
//
//	[nonce (chacha only)] sessionId(8) packetId(8) [identity headers (AES, >1 PSK)]
//	type(1) timestamp(8) paddingLen(2) padding address payload
//
// For chacha, everything after the nonce is sealed with udpCipher, using the
// nonce as the AEAD nonce. For AES, the body after the separate header and
// identity headers is sealed with the session cipher, using header bytes 4..16
// as the nonce. The 16-byte separate header is then block-encrypted in place.
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
	defer buffer.Release()
	var hdrLen int
	if c.udpCipher != nil {
		hdrLen = PacketNonceSize
	}

	// Only DNS packets (port 53) shorter than MaxPaddingLength are padded.
	var paddingLen int
	if destination.Port == 53 && buffer.Len() < MaxPaddingLength {
		paddingLen = mRand.Intn(MaxPaddingLength-buffer.Len()) + 1
	}

	hdrLen += 16 // packet header
	pskLen := len(c.pskList)
	if c.udpCipher == nil && pskLen > 1 {
		hdrLen += (pskLen - 1) * aes.BlockSize
	}
	hdrLen += 1 // header type
	hdrLen += 8 // timestamp
	hdrLen += 2 // padding length
	hdrLen += paddingLen
	hdrLen += M.SocksaddrSerializer.AddrPortLen(destination)
	// Claim hdrLen bytes in front of the payload and write the header there.
	header := buf.With(buffer.ExtendHeader(hdrLen))

	// dataIndex marks where the AEAD-sealed region starts.
	var dataIndex int
	if c.udpCipher != nil {
		common.Must1(header.ReadFullFrom(c.session.rng, PacketNonceSize))
		if pskLen > 1 {
			// New rejects chacha with more than one PSK, so this should be unreachable.
			panic("unsupported chacha extended header")
		}
		dataIndex = PacketNonceSize
	} else {
		dataIndex = aes.BlockSize
	}

	common.Must(
		binary.Write(header, binary.BigEndian, c.session.sessionId),
		binary.Write(header, binary.BigEndian, c.session.nextPacketId()),
	)

	// Identity header i = Encrypt_{psk_i}(pskHash_i XOR plaintext separate header),
	// written for every PSK except the last.
	if c.udpCipher == nil && pskLen > 1 {
		for i, psk := range c.pskList {
			dataIndex += aes.BlockSize
			pskHash := c.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]

			identityHeader := header.Extend(aes.BlockSize)
			xorWords(identityHeader, pskHash, header.To(aes.BlockSize))
			b, err := c.blockConstructor(psk)
			if err != nil {
				return err
			}
			b.Encrypt(identityHeader, identityHeader)

			if i == pskLen-2 {
				break
			}
		}
	}
	common.Must(
		header.WriteByte(HeaderTypeClient),
		binary.Write(header, binary.BigEndian, uint64(c.time().Unix())),
		binary.Write(header, binary.BigEndian, uint16(paddingLen)), // padding length
	)

	// The padding bytes are reserved here but not explicitly written.
	if paddingLen > 0 {
		header.Extend(paddingLen)
	}

	err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
	if err != nil {
		return err
	}
	// Seal in place from dataIndex, then extend the buffer to cover the tag.
	if c.udpCipher != nil {
		c.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil)
		buffer.Extend(shadowaead.Overhead)
	} else {
		packetHeader := buffer.To(aes.BlockSize)
		c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil)
		buffer.Extend(shadowaead.Overhead)
		// The separate header is encrypted last because it served as the nonce above.
		c.udpBlockEncryptCipher.Encrypt(packetHeader, packetHeader)
	}
	return common.Error(c.Write(buffer.Bytes()))
}
   551  
   552  func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
   553  	n, err := c.Read(buffer.FreeBytes())
   554  	if err != nil {
   555  		return M.Socksaddr{}, err
   556  	}
   557  	buffer.Truncate(n)
   558  
   559  	var packetHeader []byte
   560  	if c.udpCipher != nil {
   561  		if buffer.Len() < PacketNonceSize+PacketMinimalHeaderSize {
   562  			return M.Socksaddr{}, ErrPacketTooShort
   563  		}
   564  		_, err = c.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
   565  		if err != nil {
   566  			return M.Socksaddr{}, E.Cause(err, "decrypt packet")
   567  		}
   568  		buffer.Advance(PacketNonceSize)
   569  		buffer.Truncate(buffer.Len() - shadowaead.Overhead)
   570  	} else {
   571  		if buffer.Len() < PacketMinimalHeaderSize {
   572  			return M.Socksaddr{}, ErrPacketTooShort
   573  		}
   574  		packetHeader = buffer.To(aes.BlockSize)
   575  		c.udpBlockDecryptCipher.Decrypt(packetHeader, packetHeader)
   576  	}
   577  
   578  	var sessionId, packetId uint64
   579  	err = binary.Read(buffer, binary.BigEndian, &sessionId)
   580  	if err != nil {
   581  		return M.Socksaddr{}, err
   582  	}
   583  	err = binary.Read(buffer, binary.BigEndian, &packetId)
   584  	if err != nil {
   585  		return M.Socksaddr{}, err
   586  	}
   587  
   588  	if sessionId == c.session.remoteSessionId {
   589  		if !c.session.window.Check(packetId) {
   590  			return M.Socksaddr{}, ErrPacketIdNotUnique
   591  		}
   592  	} else if sessionId == c.session.lastRemoteSessionId {
   593  		if !c.session.lastWindow.Check(packetId) {
   594  			return M.Socksaddr{}, ErrPacketIdNotUnique
   595  		}
   596  	}
   597  
   598  	var remoteCipher cipher.AEAD
   599  	if packetHeader != nil {
   600  		if sessionId == c.session.remoteSessionId {
   601  			remoteCipher = c.session.remoteCipher
   602  		} else if sessionId == c.session.lastRemoteSessionId {
   603  			remoteCipher = c.session.lastRemoteCipher
   604  		} else {
   605  			key := SessionKey(c.pskList[len(c.pskList)-1], packetHeader[:8], c.keySaltLength)
   606  			remoteCipher, err = c.constructor(key)
   607  			if err != nil {
   608  				return M.Socksaddr{}, err
   609  			}
   610  		}
   611  		_, err = remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil)
   612  		if err != nil {
   613  			return M.Socksaddr{}, E.Cause(err, "decrypt packet")
   614  		}
   615  		buffer.Truncate(buffer.Len() - shadowaead.Overhead)
   616  	}
   617  
   618  	var headerType byte
   619  	headerType, err = buffer.ReadByte()
   620  	if err != nil {
   621  		return M.Socksaddr{}, err
   622  	}
   623  	if headerType != HeaderTypeServer {
   624  		return M.Socksaddr{}, E.Extend(ErrBadHeaderType, "expected ", HeaderTypeServer, ", got ", headerType)
   625  	}
   626  
   627  	var epoch uint64
   628  	err = binary.Read(buffer, binary.BigEndian, &epoch)
   629  	if err != nil {
   630  		return M.Socksaddr{}, err
   631  	}
   632  
   633  	diff := int(math.Abs(float64(c.time().Unix() - int64(epoch))))
   634  	if diff > 30 {
   635  		return M.Socksaddr{}, E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
   636  	}
   637  
   638  	if sessionId == c.session.remoteSessionId {
   639  		c.session.window.Add(packetId)
   640  	} else if sessionId == c.session.lastRemoteSessionId {
   641  		c.session.lastWindow.Add(packetId)
   642  		c.session.lastRemoteSeen = c.time().Unix()
   643  	} else {
   644  		if c.session.remoteSessionId != 0 {
   645  			if c.time().Unix()-c.session.lastRemoteSeen < 60 {
   646  				return M.Socksaddr{}, ErrTooManyServerSessions
   647  			} else {
   648  				c.session.lastRemoteSessionId = c.session.remoteSessionId
   649  				c.session.lastWindow = c.session.window
   650  				c.session.lastRemoteSeen = c.time().Unix()
   651  				c.session.lastRemoteCipher = c.session.remoteCipher
   652  				c.session.window = SlidingWindow{}
   653  			}
   654  		}
   655  		c.session.remoteSessionId = sessionId
   656  		c.session.remoteCipher = remoteCipher
   657  		c.session.window.Add(packetId)
   658  	}
   659  
   660  	var clientSessionId uint64
   661  	err = binary.Read(buffer, binary.BigEndian, &clientSessionId)
   662  	if err != nil {
   663  		return M.Socksaddr{}, err
   664  	}
   665  
   666  	if clientSessionId != c.session.sessionId {
   667  		return M.Socksaddr{}, ErrBadClientSessionId
   668  	}
   669  
   670  	var paddingLen uint16
   671  	err = binary.Read(buffer, binary.BigEndian, &paddingLen)
   672  	if err != nil {
   673  		return M.Socksaddr{}, E.Cause(err, "read padding length")
   674  	}
   675  	buffer.Advance(int(paddingLen))
   676  
   677  	destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
   678  	if err != nil {
   679  		return M.Socksaddr{}, err
   680  	}
   681  	return destination, nil
   682  }
   683  
   684  func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
   685  	buffer := buf.With(p)
   686  	destination, err := c.ReadPacket(buffer)
   687  	if err != nil {
   688  		return
   689  	}
   690  	if destination.IsFqdn() {
   691  		addr = destination
   692  	} else {
   693  		addr = destination.UDPAddr()
   694  	}
   695  	n = copy(p, buffer.Bytes())
   696  	return
   697  }
   698  
// WriteTo encrypts p as one UDP packet to addr and sends it. It returns len(p)
// on success. Unlike WritePacket, it allocates a fresh buffer sized exactly for
// header + payload + tag, and copies p into it.
//
// The packet layout and encryption steps are the same as in WritePacket:
// an optional chacha nonce, session and packet ids, AES identity headers,
// then type, timestamp, padding length, padding, address and payload.
func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
	destination := M.SocksaddrFromNet(addr)
	var overHead int
	if c.udpCipher != nil {
		overHead = PacketNonceSize + shadowaead.Overhead
	} else {
		overHead = shadowaead.Overhead
	}
	overHead += 16 // packet header
	pskLen := len(c.pskList)
	if c.udpCipher == nil && pskLen > 1 {
		overHead += (pskLen - 1) * aes.BlockSize
	}
	// Only DNS packets (port 53) shorter than MaxPaddingLength are padded.
	var paddingLen int
	if destination.Port == 53 && len(p) < MaxPaddingLength {
		paddingLen = mRand.Intn(MaxPaddingLength-len(p)) + 1
	}
	overHead += 1 // header type
	overHead += 8 // timestamp
	overHead += 2 // padding length
	overHead += paddingLen
	overHead += M.SocksaddrSerializer.AddrPortLen(destination)

	buffer := buf.NewSize(overHead + len(p))
	defer buffer.Release()

	// dataIndex marks where the AEAD-sealed region starts.
	var dataIndex int
	if c.udpCipher != nil {
		common.Must1(buffer.ReadFullFrom(c.session.rng, PacketNonceSize))
		if pskLen > 1 {
			// New rejects chacha with more than one PSK, so this should be unreachable.
			panic("unsupported chacha extended header")
		}
		dataIndex = PacketNonceSize
	} else {
		dataIndex = aes.BlockSize
	}

	common.Must(
		binary.Write(buffer, binary.BigEndian, c.session.sessionId),
		binary.Write(buffer, binary.BigEndian, c.session.nextPacketId()),
	)

	// Identity header i = Encrypt_{psk_i}(pskHash_i XOR plaintext separate header),
	// written for every PSK except the last.
	if c.udpCipher == nil && pskLen > 1 {
		for i, psk := range c.pskList {
			dataIndex += aes.BlockSize
			pskHash := c.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]

			identityHeader := buffer.Extend(aes.BlockSize)
			xorWords(identityHeader, pskHash, buffer.To(aes.BlockSize))
			b, err := c.blockConstructor(psk)
			if err != nil {
				return 0, err
			}
			b.Encrypt(identityHeader, identityHeader)

			if i == pskLen-2 {
				break
			}
		}
	}
	common.Must(
		buffer.WriteByte(HeaderTypeClient),
		binary.Write(buffer, binary.BigEndian, uint64(c.time().Unix())),
		binary.Write(buffer, binary.BigEndian, uint16(paddingLen)), // padding length
	)

	// The padding bytes are reserved here but not explicitly written.
	if paddingLen > 0 {
		buffer.Extend(paddingLen)
	}

	err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination)
	if err != nil {
		return
	}
	common.Must1(buffer.Write(p))
	// Seal in place from dataIndex, then extend the buffer to cover the tag.
	if c.udpCipher != nil {
		c.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil)
		buffer.Extend(shadowaead.Overhead)
	} else {
		packetHeader := buffer.To(aes.BlockSize)
		c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil)
		buffer.Extend(shadowaead.Overhead)
		// The separate header is encrypted last because it served as the nonce above.
		c.udpBlockEncryptCipher.Encrypt(packetHeader, packetHeader)
	}
	err = common.Error(c.Write(buffer.Bytes()))
	if err != nil {
		return
	}
	return len(p), nil
}
   789  
   790  func (c *clientPacketConn) FrontHeadroom() int {
   791  	var overHead int
   792  	if c.udpCipher != nil {
   793  		overHead = PacketNonceSize + shadowaead.Overhead
   794  	} else {
   795  		overHead = shadowaead.Overhead
   796  	}
   797  	overHead += 16 // packet header
   798  	pskLen := len(c.pskList)
   799  	if c.udpCipher == nil && pskLen > 1 {
   800  		overHead += (pskLen - 1) * aes.BlockSize
   801  	}
   802  	overHead += 1 // header type
   803  	overHead += 8 // timestamp
   804  	overHead += 2 // padding length
   805  	overHead += MaxPaddingLength
   806  	overHead += M.MaxSocksaddrLength
   807  	return overHead
   808  }
   809  
// RearHeadroom reports the space needed after the payload for the AEAD tag.
func (c *clientPacketConn) RearHeadroom() int {
	return shadowaead.Overhead
}
   813  
// udpSession holds per-association UDP state for one client.
type udpSession struct {
	sessionId           uint64 // this client's session id, chosen randomly in newUDPSession
	packetId            uint64 // last used packet id; advanced atomically by nextPacketId
	remoteSessionId     uint64 // current server session id (0 until the first packet)
	lastRemoteSessionId uint64 // previous server session id after a switch
	lastRemoteSeen      int64  // Unix seconds when the previous server session was last seen or replaced
	cipher              cipher.AEAD // outgoing AEAD (AES methods only)
	remoteCipher        cipher.AEAD // AEAD for the current server session (AES methods only)
	lastRemoteCipher    cipher.AEAD // AEAD for the previous server session (AES methods only)
	window              SlidingWindow // replay window for remoteSessionId
	lastWindow          SlidingWindow // replay window for lastRemoteSessionId
	rng                 io.Reader     // BLAKE3 XOF stream for nonces and the session id (chacha only)
}
   827  
// nextPacketId atomically increments and returns the packet counter.
// newUDPSession starts the counter at MaxUint64, so the first id is 0.
func (s *udpSession) nextPacketId() uint64 {
	return atomic.AddUint64(&s.packetId, 1)
}
   831  
   832  func (m *Method) newUDPSession() *udpSession {
   833  	session := &udpSession{}
   834  	if m.udpCipher != nil {
   835  		session.rng = Blake3KeyedHash(rand.Reader)
   836  		common.Must(binary.Read(session.rng, binary.BigEndian, &session.sessionId))
   837  	} else {
   838  		common.Must(binary.Read(rand.Reader, binary.BigEndian, &session.sessionId))
   839  	}
   840  	session.packetId--
   841  	if m.udpCipher == nil {
   842  		sessionId := make([]byte, 8)
   843  		binary.BigEndian.PutUint64(sessionId, session.sessionId)
   844  		key := SessionKey(m.pskList[len(m.pskList)-1], sessionId, m.keySaltLength)
   845  		var err error
   846  		session.cipher, err = m.constructor(key)
   847  		if err != nil {
   848  			return nil
   849  		}
   850  	}
   851  	return session
   852  }
   853  
// Upstream returns the wrapped underlying connection.
func (c *clientPacketConn) Upstream() any {
	return c.Conn
}
   857  
// Close closes the underlying connection.
func (c *clientPacketConn) Close() error {
	return common.Close(c.Conn)
}
   861  
   862  func Blake3KeyedHash(reader io.Reader) io.Reader {
   863  	key := make([]byte, 32)
   864  	common.Must1(io.ReadFull(reader, key))
   865  	h := blake3.New(1024, key)
   866  	return h.XOF()
   867  }