github.com/MerlinKodo/sing-shadowsocks2@v0.1.6/shadowaead_2022/method.go (about)

     1  package shadowaead_2022
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/aes"
     7  	"crypto/cipher"
     8  	"crypto/rand"
     9  	"encoding/base64"
    10  	"encoding/binary"
    11  	"math"
    12  	mRand "math/rand"
    13  	"net"
    14  	"os"
    15  	"strings"
    16  	"time"
    17  
    18  	C "github.com/MerlinKodo/sing-shadowsocks2/cipher"
    19  	"github.com/MerlinKodo/sing-shadowsocks2/internal/shadowio"
    20  	"github.com/sagernet/sing/common"
    21  	"github.com/sagernet/sing/common/buf"
    22  	"github.com/sagernet/sing/common/bufio"
    23  	E "github.com/sagernet/sing/common/exceptions"
    24  	M "github.com/sagernet/sing/common/metadata"
    25  	N "github.com/sagernet/sing/common/network"
    26  	"github.com/sagernet/sing/common/ntp"
    27  
    28  	"golang.org/x/crypto/chacha20poly1305"
    29  	"lukechampine.com/blake3"
    30  )
    31  
// MethodList lists the Shadowsocks 2022 method names implemented by this
// package. init registers NewMethod as the constructor for every entry.
var MethodList = []string{
	"2022-blake3-aes-128-gcm",
	"2022-blake3-aes-256-gcm",
	"2022-blake3-chacha20-poly1305",
}

// init registers all names in MethodList with the cipher registry so that
// they resolve to NewMethod.
func init() {
	C.RegisterMethod(MethodList, NewMethod)
}
    41  
// Method holds the configuration shared by every client connection created
// from it: key/salt size, cipher constructors, pre-keyed UDP header ciphers
// and the pre-shared key (PSK) list.
type Method struct {
	// keySaltLength is the required length of every PSK and of every salt.
	keySaltLength         int
	// timeFunc supplies the current time for header timestamps; nil means
	// time.Now (see Method.time).
	timeFunc              func() time.Time
	// constructor builds a per-session AEAD from a derived session key.
	constructor           func(key []byte) (cipher.AEAD, error)
	// blockConstructor builds the block cipher used for identity headers.
	// Only set for the AES-GCM methods; nil for chacha20-poly1305.
	blockConstructor      func(key []byte) (cipher.Block, error)
	// udpCipher is the XChaCha20-Poly1305 AEAD keyed with the first PSK.
	// Only set for chacha20-poly1305; its non-nil-ness selects that UDP format.
	udpCipher             cipher.AEAD
	// udpBlockEncryptCipher (keyed with the first PSK) encrypts and
	// udpBlockDecryptCipher (keyed with the last PSK) decrypts the 16-byte
	// UDP packet header. Only set for the AES-GCM methods.
	udpBlockEncryptCipher cipher.Block
	udpBlockDecryptCipher cipher.Block
	// pskList is the ordered PSK list; the last entry derives session keys.
	pskList               [][]byte
	// pskHash concatenates, for every PSK after the first, the first
	// aes.BlockSize bytes of its BLAKE3-512 hash, in list order.
	pskHash               []byte
}
    53  
    54  func NewMethod(ctx context.Context, methodName string, options C.MethodOptions) (C.Method, error) {
    55  	m := &Method{
    56  		timeFunc: ntp.TimeFuncFromContext(ctx),
    57  		pskList:  options.KeyList,
    58  	}
    59  	if options.Password != "" {
    60  		var pskList [][]byte
    61  		keyStrList := strings.Split(options.Password, ":")
    62  		pskList = make([][]byte, len(keyStrList))
    63  		for i, keyStr := range keyStrList {
    64  			kb, err := base64.StdEncoding.DecodeString(keyStr)
    65  			if err != nil {
    66  				return nil, E.Cause(err, "decode key")
    67  			}
    68  			pskList[i] = kb
    69  		}
    70  		m.pskList = pskList
    71  	}
    72  	switch methodName {
    73  	case "2022-blake3-aes-128-gcm":
    74  		m.keySaltLength = 16
    75  		m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM)
    76  		m.blockConstructor = aes.NewCipher
    77  	case "2022-blake3-aes-256-gcm":
    78  		m.keySaltLength = 32
    79  		m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM)
    80  		m.blockConstructor = aes.NewCipher
    81  	case "2022-blake3-chacha20-poly1305":
    82  		if len(options.KeyList) > 1 {
    83  			return nil, ErrNoEIH
    84  		}
    85  		m.keySaltLength = 32
    86  		m.constructor = chacha20poly1305.New
    87  	default:
    88  		return nil, os.ErrInvalid
    89  	}
    90  	if len(m.pskList) == 0 {
    91  		return nil, C.ErrMissingPassword
    92  	}
    93  	for _, key := range m.pskList {
    94  		if len(key) != m.keySaltLength {
    95  			return nil, E.New("bad key length, required ", m.keySaltLength, ", got ", len(key))
    96  		}
    97  	}
    98  	if len(m.pskList) > 1 {
    99  		pskHash := make([]byte, (len(m.pskList)-1)*aes.BlockSize)
   100  		for i, key := range m.pskList {
   101  			if i == 0 {
   102  				continue
   103  			}
   104  			keyHash := blake3.Sum512(key)
   105  			copy(pskHash[aes.BlockSize*(i-1):aes.BlockSize*i], keyHash[:aes.BlockSize])
   106  		}
   107  		m.pskHash = pskHash
   108  	}
   109  	var err error
   110  	switch methodName {
   111  	case "2022-blake3-aes-128-gcm", "2022-blake3-aes-256-gcm":
   112  		m.udpBlockEncryptCipher, err = aes.NewCipher(m.pskList[0])
   113  		if err != nil {
   114  			return nil, err
   115  		}
   116  		m.udpBlockDecryptCipher, err = aes.NewCipher(m.pskList[len(m.pskList)-1])
   117  		if err != nil {
   118  			return nil, err
   119  		}
   120  	case "2022-blake3-chacha20-poly1305":
   121  		m.udpCipher, err = chacha20poly1305.NewX(m.pskList[0])
   122  		if err != nil {
   123  			return nil, err
   124  		}
   125  	}
   126  	return m, nil
   127  }
   128  
   129  func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) {
   130  	shadowsocksConn := &clientConn{
   131  		Conn:        conn,
   132  		method:      m,
   133  		destination: destination,
   134  	}
   135  	return shadowsocksConn, shadowsocksConn.writeRequest(nil)
   136  }
   137  
   138  func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn {
   139  	return &clientConn{
   140  		Conn:        conn,
   141  		method:      m,
   142  		destination: destination,
   143  	}
   144  }
   145  
   146  func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn {
   147  	pc := &clientPacketConn{
   148  		AbstractConn: conn,
   149  		reader:       bufio.NewExtendedReader(conn),
   150  		writer:       bufio.NewExtendedWriter(conn),
   151  		method:       m,
   152  		session:      m.newUDPSession(),
   153  	}
   154  	if waitRead, isWaitRead := N.CastReader[shadowio.WaitReadReader](conn); isWaitRead {
   155  		return &clientWaitPacketConn{
   156  			clientPacketConn: pc,
   157  			waitRead:         waitRead,
   158  		}
   159  	}
   160  	return pc
   161  }
   162  
   163  func (m *Method) time() time.Time {
   164  	if m.timeFunc != nil {
   165  		return m.timeFunc()
   166  	} else {
   167  		return time.Now()
   168  	}
   169  }
   170  
// clientConn is the TCP client side of a Shadowsocks 2022 stream.
// writeRequest sends the request header, either from DialConn or on the first
// write after DialEarlyConn. readResponse consumes the response header on the
// first read.
type clientConn struct {
	net.Conn
	method      *Method
	destination M.Socksaddr
	// requestSalt is the salt sent in the request header; readResponse
	// requires the server to echo it back.
	requestSalt []byte
	// reader is nil until the response header has been read.
	reader      *shadowio.Reader
	// writer is nil until the request header has been sent.
	writer      *shadowio.Writer
	// NOTE(review): never assigned in this file; presumably embedded only to
	// advertise the shadowio.WriterInterface method set — confirm.
	shadowio.WriterInterface
}
   180  
// writeRequest performs the client handshake. It sends the request header,
// carrying as much of payload as fits, and installs the stream writer.
//
// The header is assembled in one buffer and sent with a single Write:
//
//	salt | identity headers (only with more than one PSK) |
//	sealed fixed-length chunk (type, timestamp, variable-chunk length) |
//	sealed variable-length chunk (address, padding length, padding, payload prefix)
//
// Payload that did not fit is then written through the new stream writer.
// c.requestSalt and c.writer are assigned only after every write succeeded,
// so on error NeedHandshake still reports true.
func (c *clientConn) writeRequest(payload []byte) error {
	requestSalt := make([]byte, c.method.keySaltLength)
	requestBuffer := buf.New()
	defer requestBuffer.Release()
	// The random salt opens the request. Keep a copy because readResponse
	// checks that the server echoes it.
	requestBuffer.WriteRandom(c.method.keySaltLength)
	copy(requestSalt, requestBuffer.Bytes())
	// The session key is derived from the last PSK and the salt.
	key := SessionKey(c.method.pskList[len(c.method.pskList)-1], requestSalt, c.method.keySaltLength)
	writeCipher, err := c.method.constructor(key)
	if err != nil {
		return err
	}
	writer := shadowio.NewWriter(
		c.Conn,
		writeCipher,
		nil,
		buf.BufferSize-shadowio.PacketLengthBufferSize-shadowio.Overhead*2,
	)
	err = c.method.writeExtendedIdentityHeaders(requestBuffer, requestBuffer.To(c.method.keySaltLength))
	if err != nil {
		return err
	}
	// Fixed-length chunk: header type, Unix timestamp and the length of the
	// variable-length chunk, with room reserved for the AEAD tag.
	fixedLengthBuffer := buf.With(requestBuffer.Extend(RequestHeaderFixedChunkLength + shadowio.Overhead))
	common.Must(fixedLengthBuffer.WriteByte(HeaderTypeClient))
	common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint64(c.method.time().Unix())))
	// Socks address plus the 2-byte padding length field.
	variableLengthHeaderLen := M.SocksaddrSerializer.AddrPortLen(c.destination) + 2
	var paddingLen int
	// Random padding is only added when the initial payload is short.
	if len(payload) < MaxPaddingLength {
		paddingLen = mRand.Intn(MaxPaddingLength) + 1
	}
	variableLengthHeaderLen += paddingLen
	// Embed only as much payload as still fits in requestBuffer after the
	// variable-length header and its tag; the remainder is written below.
	maxPayloadLen := requestBuffer.FreeLen() - (variableLengthHeaderLen + shadowio.Overhead)
	payloadLen := len(payload)
	if payloadLen > maxPayloadLen {
		payloadLen = maxPayloadLen
	}
	variableLengthHeaderLen += payloadLen
	common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint16(variableLengthHeaderLen)))
	writer.Encrypt(fixedLengthBuffer.Index(0), fixedLengthBuffer.Bytes())
	// Cover the tag presumably appended in place by Encrypt (space for it
	// was reserved above).
	fixedLengthBuffer.Extend(shadowio.Overhead)

	variableLengthBuffer := buf.With(requestBuffer.Extend(variableLengthHeaderLen + shadowio.Overhead))
	err = M.SocksaddrSerializer.WriteAddrPort(variableLengthBuffer, c.destination)
	if err != nil {
		return err
	}
	common.Must(binary.Write(variableLengthBuffer, binary.BigEndian, uint16(paddingLen)))
	if paddingLen > 0 {
		// Padding bytes are whatever the buffer already holds; they are
		// encrypted below.
		variableLengthBuffer.Extend(paddingLen)
	}
	if payloadLen > 0 {
		common.Must1(variableLengthBuffer.Write(payload[:payloadLen]))
	}
	writer.Encrypt(variableLengthBuffer.Index(0), variableLengthBuffer.Bytes())
	variableLengthBuffer.Extend(shadowio.Overhead)
	_, err = c.Conn.Write(requestBuffer.Bytes())
	if err != nil {
		return err
	}
	if len(payload) > payloadLen {
		_, err = writer.Write(payload[payloadLen:])
		if err != nil {
			return err
		}
	}
	c.requestSalt = requestSalt
	c.writer = writer
	return nil
}
   249  
   250  func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) error {
   251  	pskLen := len(m.pskList)
   252  	if pskLen < 2 {
   253  		return nil
   254  	}
   255  	for i, psk := range m.pskList {
   256  		keyMaterial := make([]byte, m.keySaltLength*2)
   257  		copy(keyMaterial, psk)
   258  		copy(keyMaterial[m.keySaltLength:], salt)
   259  		identitySubkey := make([]byte, m.keySaltLength)
   260  		blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial)
   261  		pskHash := m.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
   262  		header := request.Extend(16)
   263  		b, err := m.blockConstructor(identitySubkey)
   264  		if err != nil {
   265  			return err
   266  		}
   267  		b.Encrypt(header, pskHash)
   268  		if i == pskLen-2 {
   269  			break
   270  		}
   271  	}
   272  	return nil
   273  }
   274  
// readResponse performs the client side of the response handshake. It reads
// the server salt, derives the read session key from the last PSK and the
// salt, and validates the fixed-length response header:
//   - the header type must be HeaderTypeServer;
//   - the timestamp must be within 30 seconds of the local clock;
//   - the echoed salt must equal the request salt.
// It then reads the first variable-length chunk, whose size is announced in
// the fixed header, and installs the stream reader.
func (c *clientConn) readResponse() error {
	salt := buf.NewSize(c.method.keySaltLength)
	defer salt.Release()
	_, err := salt.ReadFullFrom(c.Conn, c.method.keySaltLength)
	if err != nil {
		return err
	}
	key := SessionKey(c.method.pskList[len(c.method.pskList)-1], salt.Bytes(), c.method.keySaltLength)
	readCipher, err := c.method.constructor(key)
	if err != nil {
		return err
	}
	reader := shadowio.NewReader(c.Conn, readCipher)
	// Fixed chunk: type (1) + timestamp (8) + request salt + length (2).
	fixedResponseBuffer, err := reader.ReadFixedBuffer(1 + 8 + c.method.keySaltLength + 2)
	if err != nil {
		return err
	}
	headerType := common.Must1(fixedResponseBuffer.ReadByte())
	if headerType != HeaderTypeServer {
		return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeServer, ", got ", headerType)
	}
	var epoch uint64
	common.Must(binary.Read(fixedResponseBuffer, binary.BigEndian, &epoch))
	diff := int(math.Abs(float64(c.method.time().Unix() - int64(epoch))))
	if diff > 30 {
		return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
	}
	responseSalt := common.Must1(fixedResponseBuffer.ReadBytes(c.method.keySaltLength))
	if !bytes.Equal(responseSalt, c.requestSalt) {
		return ErrBadRequestSalt
	}
	var length uint16
	// NOTE(review): the length field is read from reader, not from
	// fixedResponseBuffer. This presumably relies on shadowio.Reader serving
	// the remaining 2 bytes of the fixed chunk from an internal cache
	// (common.Must panics otherwise) — confirm in shadowio.
	common.Must(binary.Read(reader, binary.BigEndian, &length))
	// The first chunk is not used here. Presumably the reader keeps it
	// buffered for the next Read — confirm in shadowio.
	_, err = reader.ReadFixedBuffer(int(length))
	if err != nil {
		return err
	}
	c.reader = reader
	return nil
}
   315  
   316  func (c *clientConn) Read(p []byte) (n int, err error) {
   317  	if c.reader == nil {
   318  		if err = c.readResponse(); err != nil {
   319  			return
   320  		}
   321  	}
   322  	return c.reader.Read(p)
   323  }
   324  
   325  func (c *clientConn) ReadBuffer(buffer *buf.Buffer) error {
   326  	if c.reader == nil {
   327  		err := c.readResponse()
   328  		if err != nil {
   329  			return err
   330  		}
   331  	}
   332  	return c.reader.ReadBuffer(buffer)
   333  }
   334  
   335  func (c *clientConn) ReadBufferThreadSafe() (buffer *buf.Buffer, err error) {
   336  	if c.reader == nil {
   337  		err = c.readResponse()
   338  		if err != nil {
   339  			return
   340  		}
   341  
   342  	}
   343  	return c.reader.ReadBufferThreadSafe()
   344  }
   345  
   346  func (c *clientConn) Write(p []byte) (n int, err error) {
   347  	if c.writer == nil {
   348  		err = c.writeRequest(p)
   349  		if err == nil {
   350  			n = len(p)
   351  		}
   352  		return
   353  	}
   354  	return c.writer.Write(p)
   355  }
   356  
   357  func (c *clientConn) WriteBuffer(buffer *buf.Buffer) error {
   358  	if c.writer == nil {
   359  		defer buffer.Release()
   360  		return c.writeRequest(buffer.Bytes())
   361  	}
   362  	return c.writer.WriteBuffer(buffer)
   363  }
   364  
   365  func (c *clientConn) NeedHandshake() bool {
   366  	return c.writer == nil
   367  }
   368  
   369  func (c *clientConn) Upstream() any {
   370  	return c.Conn
   371  }
   372  
   373  func (c *clientConn) Close() error {
   374  	return common.Close(
   375  		c.Conn,
   376  		common.PtrOrNil(c.reader),
   377  		common.PtrOrNil(c.writer),
   378  	)
   379  }
   380  
// clientPacketConn is the UDP client side of Shadowsocks 2022. Every packet
// is framed and encrypted in place on the wrapped connection.
type clientPacketConn struct {
	N.AbstractConn
	reader  N.ExtendedReader
	writer  N.ExtendedWriter
	method  *Method
	// session holds this client's session ID, packet counter and the
	// replay/session-tracking state for server sessions.
	session *udpSession
}
   388  
   389  func (m *Method) newUDPSession() *udpSession {
   390  	session := &udpSession{}
   391  	if m.udpCipher != nil {
   392  		session.rng = Blake3KeyedHash(rand.Reader)
   393  		common.Must(binary.Read(session.rng, binary.BigEndian, &session.sessionId))
   394  	} else {
   395  		common.Must(binary.Read(rand.Reader, binary.BigEndian, &session.sessionId))
   396  	}
   397  	session.packetId--
   398  	if m.udpCipher == nil {
   399  		sessionId := make([]byte, 8)
   400  		binary.BigEndian.PutUint64(sessionId, session.sessionId)
   401  		key := SessionKey(m.pskList[len(m.pskList)-1], sessionId, m.keySaltLength)
   402  		var err error
   403  		session.cipher, err = m.constructor(key)
   404  		if err != nil {
   405  			return nil
   406  		}
   407  	}
   408  	return session
   409  }
   410  
// WritePacket frames buffer's payload as a client-to-server packet for
// destination, encrypts it in place and writes it to the wrapped connection.
//
// Plaintext layout (optional parts in brackets):
//
//	[nonce (chacha20)] | packet header: session ID, packet ID |
//	[identity headers (AES with multiple PSKs)] | type | timestamp |
//	padding length | padding | address | payload
//
// The header is created with buffer.ExtendHeader(hdrLen), and the AEAD tag is
// sealed in place after the payload. buffer therefore presumably needs the
// room reported by FrontHeadroom/RearHeadroom. Random padding is only added
// for port-53 destinations with payloads shorter than MaxPaddingLength.
func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
	var hdrLen int
	if c.method.udpCipher != nil {
		hdrLen = PacketNonceSize
	}

	var paddingLen int
	if destination.Port == 53 && buffer.Len() < MaxPaddingLength {
		paddingLen = mRand.Intn(MaxPaddingLength-buffer.Len()) + 1
	}

	hdrLen += 16 // packet header
	pskLen := len(c.method.pskList)
	if c.method.udpCipher == nil && pskLen > 1 {
		hdrLen += (pskLen - 1) * aes.BlockSize
	}
	hdrLen += 1 // header type
	hdrLen += 8 // timestamp
	hdrLen += 2 // padding length
	hdrLen += paddingLen
	hdrLen += M.SocksaddrSerializer.AddrPortLen(destination)
	header := buf.With(buffer.ExtendHeader(hdrLen))

	// dataIndex marks where the AEAD-protected region starts.
	var dataIndex int
	if c.method.udpCipher != nil {
		common.Must1(header.ReadFullFrom(c.session.rng, PacketNonceSize))
		// NewMethod is meant to reject multiple PSKs for chacha20.
		if pskLen > 1 {
			panic("unsupported chacha extended header")
		}
		dataIndex = PacketNonceSize
	} else {
		dataIndex = aes.BlockSize
	}

	common.Must(
		binary.Write(header, binary.BigEndian, c.session.sessionId),
		binary.Write(header, binary.BigEndian, c.session.nextPacketId()),
	)

	if c.method.udpCipher == nil && pskLen > 1 {
		// Identity header i = AES_{psk i}(pskHash slot i XOR plaintext
		// packet header), for i = 0 .. pskLen-2.
		for i, psk := range c.method.pskList {
			dataIndex += aes.BlockSize
			pskHash := c.method.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]

			identityHeader := header.Extend(aes.BlockSize)
			xorWords(identityHeader, pskHash, header.To(aes.BlockSize))
			b, err := c.method.blockConstructor(psk)
			if err != nil {
				return err
			}
			b.Encrypt(identityHeader, identityHeader)

			if i == pskLen-2 {
				break
			}
		}
	}
	common.Must(
		header.WriteByte(HeaderTypeClient),
		binary.Write(header, binary.BigEndian, uint64(c.method.time().Unix())),
		binary.Write(header, binary.BigEndian, uint16(paddingLen)), // padding length
	)

	if paddingLen > 0 {
		header.Extend(paddingLen)
	}

	err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
	if err != nil {
		return err
	}
	if c.method.udpCipher != nil {
		// XChaCha20-Poly1305 over everything after the nonce.
		c.method.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil)
		buffer.Extend(shadowio.Overhead)
	} else {
		// Seal the body with the session AEAD, using the last 12 bytes of
		// the plaintext packet header as nonce. Then encrypt the packet
		// header block itself with the first PSK.
		packetHeader := buffer.To(aes.BlockSize)
		c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil)
		buffer.Extend(shadowio.Overhead)
		c.method.udpBlockEncryptCipher.Encrypt(packetHeader, packetHeader)
	}
	return c.writer.WriteBuffer(buffer)
}
   493  
   494  func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
   495  	err = c.reader.ReadBuffer(buffer)
   496  	if err != nil {
   497  		return
   498  	}
   499  	return c.readPacket(buffer)
   500  }
   501  
   502  func (c *clientPacketConn) readPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
   503  	var packetHeader []byte
   504  	if c.method.udpCipher != nil {
   505  		if buffer.Len() < PacketNonceSize+PacketMinimalHeaderSize {
   506  			return M.Socksaddr{}, C.ErrPacketTooShort
   507  		}
   508  		_, err = c.method.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
   509  		if err != nil {
   510  			return M.Socksaddr{}, E.Cause(err, "decrypt packet")
   511  		}
   512  		buffer.Advance(PacketNonceSize)
   513  		buffer.Truncate(buffer.Len() - shadowio.Overhead)
   514  	} else {
   515  		if buffer.Len() < PacketMinimalHeaderSize {
   516  			return M.Socksaddr{}, C.ErrPacketTooShort
   517  		}
   518  		packetHeader = buffer.To(aes.BlockSize)
   519  		c.method.udpBlockDecryptCipher.Decrypt(packetHeader, packetHeader)
   520  	}
   521  
   522  	var sessionId, packetId uint64
   523  	err = binary.Read(buffer, binary.BigEndian, &sessionId)
   524  	if err != nil {
   525  		return M.Socksaddr{}, err
   526  	}
   527  	err = binary.Read(buffer, binary.BigEndian, &packetId)
   528  	if err != nil {
   529  		return M.Socksaddr{}, err
   530  	}
   531  
   532  	if sessionId == c.session.remoteSessionId {
   533  		if !c.session.window.Check(packetId) {
   534  			return M.Socksaddr{}, ErrPacketIdNotUnique
   535  		}
   536  	} else if sessionId == c.session.lastRemoteSessionId {
   537  		if !c.session.lastWindow.Check(packetId) {
   538  			return M.Socksaddr{}, ErrPacketIdNotUnique
   539  		}
   540  	}
   541  
   542  	var remoteCipher cipher.AEAD
   543  	if packetHeader != nil {
   544  		if sessionId == c.session.remoteSessionId {
   545  			remoteCipher = c.session.remoteCipher
   546  		} else if sessionId == c.session.lastRemoteSessionId {
   547  			remoteCipher = c.session.lastRemoteCipher
   548  		} else {
   549  			key := SessionKey(c.method.pskList[len(c.method.pskList)-1], packetHeader[:8], c.method.keySaltLength)
   550  			remoteCipher, err = c.method.constructor(key)
   551  			if err != nil {
   552  				return M.Socksaddr{}, err
   553  			}
   554  		}
   555  		_, err = remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil)
   556  		if err != nil {
   557  			return M.Socksaddr{}, E.Cause(err, "decrypt packet")
   558  		}
   559  		buffer.Truncate(buffer.Len() - shadowio.Overhead)
   560  	}
   561  
   562  	var headerType byte
   563  	headerType, err = buffer.ReadByte()
   564  	if err != nil {
   565  		return M.Socksaddr{}, err
   566  	}
   567  	if headerType != HeaderTypeServer {
   568  		return M.Socksaddr{}, E.Extend(ErrBadHeaderType, "expected ", HeaderTypeServer, ", got ", headerType)
   569  	}
   570  
   571  	var epoch uint64
   572  	err = binary.Read(buffer, binary.BigEndian, &epoch)
   573  	if err != nil {
   574  		return M.Socksaddr{}, err
   575  	}
   576  
   577  	diff := int(math.Abs(float64(c.method.time().Unix() - int64(epoch))))
   578  	if diff > 30 {
   579  		return M.Socksaddr{}, E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
   580  	}
   581  
   582  	if sessionId == c.session.remoteSessionId {
   583  		c.session.window.Add(packetId)
   584  	} else if sessionId == c.session.lastRemoteSessionId {
   585  		c.session.lastWindow.Add(packetId)
   586  		c.session.lastRemoteSeen = c.method.time().Unix()
   587  	} else {
   588  		if c.session.remoteSessionId != 0 {
   589  			if c.method.time().Unix()-c.session.lastRemoteSeen < 60 {
   590  				return M.Socksaddr{}, ErrTooManyServerSessions
   591  			} else {
   592  				c.session.lastRemoteSessionId = c.session.remoteSessionId
   593  				c.session.lastWindow = c.session.window
   594  				c.session.lastRemoteSeen = c.method.time().Unix()
   595  				c.session.lastRemoteCipher = c.session.remoteCipher
   596  				c.session.window = SlidingWindow{}
   597  			}
   598  		}
   599  		c.session.remoteSessionId = sessionId
   600  		c.session.remoteCipher = remoteCipher
   601  		c.session.window.Add(packetId)
   602  	}
   603  
   604  	var clientSessionId uint64
   605  	err = binary.Read(buffer, binary.BigEndian, &clientSessionId)
   606  	if err != nil {
   607  		return M.Socksaddr{}, err
   608  	}
   609  
   610  	if clientSessionId != c.session.sessionId {
   611  		return M.Socksaddr{}, ErrBadClientSessionId
   612  	}
   613  
   614  	var paddingLen uint16
   615  	err = binary.Read(buffer, binary.BigEndian, &paddingLen)
   616  	if err != nil {
   617  		return M.Socksaddr{}, E.Cause(err, "read padding length")
   618  	}
   619  	buffer.Advance(int(paddingLen))
   620  
   621  	destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer)
   622  	if err != nil {
   623  		return
   624  	}
   625  	return destination.Unwrap(), nil
   626  }
   627  
   628  func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
   629  	n, err = c.reader.Read(p)
   630  	if err != nil {
   631  		return
   632  	}
   633  	buffer := buf.As(p[:n])
   634  	destination, err := c.readPacket(buffer)
   635  	if destination.IsFqdn() {
   636  		addr = destination
   637  	} else {
   638  		addr = destination.UDPAddr()
   639  	}
   640  	n = copy(p, buffer.Bytes())
   641  	return
   642  }
   643  
// WriteTo frames p as a client-to-server packet for addr in a freshly
// allocated buffer, encrypts it and writes it to the wrapped connection. It
// returns len(p) on success. The packet layout matches WritePacket:
//
//	[nonce (chacha20)] | packet header: session ID, packet ID |
//	[identity headers (AES with multiple PSKs)] | type | timestamp |
//	padding length | padding | address | payload | AEAD tag
func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
	destination := M.SocksaddrFromNet(addr)
	// overHead is the exact size of everything except the payload, including
	// the AEAD tag.
	var overHead int
	if c.method.udpCipher != nil {
		overHead = PacketNonceSize + shadowio.Overhead
	} else {
		overHead = shadowio.Overhead
	}
	overHead += 16 // packet header
	pskLen := len(c.method.pskList)
	if c.method.udpCipher == nil && pskLen > 1 {
		overHead += (pskLen - 1) * aes.BlockSize
	}
	var paddingLen int
	if destination.Port == 53 && len(p) < MaxPaddingLength {
		paddingLen = mRand.Intn(MaxPaddingLength-len(p)) + 1
	}
	overHead += 1 // header type
	overHead += 8 // timestamp
	overHead += 2 // padding length
	overHead += paddingLen
	overHead += M.SocksaddrSerializer.AddrPortLen(destination)

	buffer := buf.NewSize(overHead + len(p))
	defer buffer.Release()

	// dataIndex marks where the AEAD-protected region starts.
	var dataIndex int
	if c.method.udpCipher != nil {
		common.Must1(buffer.ReadFullFrom(c.session.rng, PacketNonceSize))
		// NewMethod is meant to reject multiple PSKs for chacha20.
		if pskLen > 1 {
			panic("unsupported chacha extended header")
		}
		dataIndex = PacketNonceSize
	} else {
		dataIndex = aes.BlockSize
	}

	common.Must(
		binary.Write(buffer, binary.BigEndian, c.session.sessionId),
		binary.Write(buffer, binary.BigEndian, c.session.nextPacketId()),
	)

	if c.method.udpCipher == nil && pskLen > 1 {
		// Identity header i = AES_{psk i}(pskHash slot i XOR plaintext
		// packet header), for i = 0 .. pskLen-2.
		for i, psk := range c.method.pskList {
			dataIndex += aes.BlockSize
			pskHash := c.method.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]

			identityHeader := buffer.Extend(aes.BlockSize)
			xorWords(identityHeader, pskHash, buffer.To(aes.BlockSize))
			b, err := c.method.blockConstructor(psk)
			if err != nil {
				return 0, err
			}
			b.Encrypt(identityHeader, identityHeader)

			if i == pskLen-2 {
				break
			}
		}
	}
	common.Must(
		buffer.WriteByte(HeaderTypeClient),
		binary.Write(buffer, binary.BigEndian, uint64(c.method.time().Unix())),
		binary.Write(buffer, binary.BigEndian, uint16(paddingLen)), // padding length
	)

	if paddingLen > 0 {
		buffer.Extend(paddingLen)
	}

	err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination)
	if err != nil {
		return
	}
	common.Must1(buffer.Write(p))
	if c.method.udpCipher != nil {
		c.method.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil)
		buffer.Extend(shadowio.Overhead)
	} else {
		// Seal with the session AEAD, using the last 12 bytes of the
		// plaintext packet header as nonce. Then encrypt that header block
		// with the first PSK.
		packetHeader := buffer.To(aes.BlockSize)
		c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil)
		buffer.Extend(shadowio.Overhead)
		c.method.udpBlockEncryptCipher.Encrypt(packetHeader, packetHeader)
	}
	err = common.Error(c.writer.Write(buffer.Bytes()))
	if err != nil {
		return
	}
	return len(p), nil
}
   734  
   735  func (c *clientPacketConn) FrontHeadroom() int {
   736  	var overHead int
   737  	if c.method.udpCipher != nil {
   738  		overHead = PacketNonceSize + shadowio.Overhead
   739  	} else {
   740  		overHead = shadowio.Overhead
   741  	}
   742  	overHead += 16 // packet header
   743  	pskLen := len(c.method.pskList)
   744  	if c.method.udpCipher == nil && pskLen > 1 {
   745  		overHead += (pskLen - 1) * aes.BlockSize
   746  	}
   747  	overHead += 1 // header type
   748  	overHead += 8 // timestamp
   749  	overHead += 2 // padding length
   750  	overHead += MaxPaddingLength
   751  	overHead += M.MaxSocksaddrLength
   752  	return overHead
   753  }
   754  
   755  func (c *clientPacketConn) RearHeadroom() int {
   756  	return shadowio.Overhead
   757  }
   758  
   759  func (c *clientPacketConn) Upstream() any {
   760  	return c.AbstractConn
   761  }
   762  
   763  func (c *clientPacketConn) Close() error {
   764  	return c.AbstractConn.Close()
   765  }
   766  
// Compile-time check that clientWaitPacketConn implements shadowio.WaitReadFrom.
var _ shadowio.WaitReadFrom = (*clientWaitPacketConn)(nil)

// clientWaitPacketConn is a clientPacketConn whose wrapped connection
// supports wait-read. Packets are decoded directly in the buffers returned by
// waitRead.WaitRead.
type clientWaitPacketConn struct {
	*clientPacketConn
	waitRead shadowio.WaitRead
}
   773  
   774  func (c *clientWaitPacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) {
   775  	data, put, err = c.waitRead.WaitRead()
   776  	if err != nil {
   777  		return
   778  	}
   779  	if len(data) <= 0 {
   780  		err = C.ErrPacketTooShort
   781  		return
   782  	}
   783  	buffer := buf.As(data)
   784  	var destination M.Socksaddr
   785  	destination, err = c.readPacket(buffer)
   786  	if err != nil {
   787  		if put != nil {
   788  			put()
   789  		}
   790  		put = nil
   791  		data = nil
   792  		return
   793  	}
   794  	if destination.IsFqdn() {
   795  		addr = destination
   796  	} else {
   797  		addr = destination.UDPAddr()
   798  	}
   799  	data = buffer.Bytes()
   800  	return
   801  }