github.com/metacubex/sing-shadowsocks2@v0.2.0/shadowaead_2022/method.go (about)

     1  package shadowaead_2022
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/aes"
     7  	"crypto/cipher"
     8  	"crypto/rand"
     9  	"encoding/base64"
    10  	"encoding/binary"
    11  	"math"
    12  	mRand "math/rand"
    13  	"net"
    14  	"os"
    15  	"strings"
    16  	"time"
    17  
    18  	C "github.com/metacubex/sing-shadowsocks2/cipher"
    19  	"github.com/metacubex/sing-shadowsocks2/internal/shadowio"
    20  	"github.com/sagernet/sing/common"
    21  	"github.com/sagernet/sing/common/buf"
    22  	"github.com/sagernet/sing/common/bufio"
    23  	E "github.com/sagernet/sing/common/exceptions"
    24  	M "github.com/sagernet/sing/common/metadata"
    25  	N "github.com/sagernet/sing/common/network"
    26  	"github.com/sagernet/sing/common/ntp"
    27  
    28  	"golang.org/x/crypto/chacha20poly1305"
    29  	"lukechampine.com/blake3"
    30  )
    31  
    32  var MethodList = []string{
    33  	"2022-blake3-aes-128-gcm",
    34  	"2022-blake3-aes-256-gcm",
    35  	"2022-blake3-chacha20-poly1305",
    36  }
    37  
    38  func init() {
    39  	C.RegisterMethod(MethodList, NewMethod)
    40  }
    41  
    42  type Method struct {
    43  	keySaltLength         int
    44  	timeFunc              func() time.Time
    45  	constructor           func(key []byte) (cipher.AEAD, error)
    46  	blockConstructor      func(key []byte) (cipher.Block, error)
    47  	udpCipher             cipher.AEAD
    48  	udpBlockEncryptCipher cipher.Block
    49  	udpBlockDecryptCipher cipher.Block
    50  	pskList               [][]byte
    51  	pskHash               []byte
    52  }
    53  
    54  func NewMethod(ctx context.Context, methodName string, options C.MethodOptions) (C.Method, error) {
    55  	m := &Method{
    56  		timeFunc: ntp.TimeFuncFromContext(ctx),
    57  		pskList:  options.KeyList,
    58  	}
    59  	if options.Password != "" {
    60  		var pskList [][]byte
    61  		keyStrList := strings.Split(options.Password, ":")
    62  		pskList = make([][]byte, len(keyStrList))
    63  		for i, keyStr := range keyStrList {
    64  			kb, err := base64.StdEncoding.DecodeString(keyStr)
    65  			if err != nil {
    66  				return nil, E.Cause(err, "decode key")
    67  			}
    68  			pskList[i] = kb
    69  		}
    70  		m.pskList = pskList
    71  	}
    72  	switch methodName {
    73  	case "2022-blake3-aes-128-gcm":
    74  		m.keySaltLength = 16
    75  		m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM)
    76  		m.blockConstructor = aes.NewCipher
    77  	case "2022-blake3-aes-256-gcm":
    78  		m.keySaltLength = 32
    79  		m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM)
    80  		m.blockConstructor = aes.NewCipher
    81  	case "2022-blake3-chacha20-poly1305":
    82  		if len(m.pskList) > 1 {
    83  			return nil, ErrNoEIH
    84  		}
    85  		m.keySaltLength = 32
    86  		m.constructor = chacha20poly1305.New
    87  	default:
    88  		return nil, os.ErrInvalid
    89  	}
    90  	if len(m.pskList) == 0 {
    91  		return nil, C.ErrMissingPassword
    92  	}
    93  	for _, key := range m.pskList {
    94  		if len(key) != m.keySaltLength {
    95  			return nil, E.New("bad key length, required ", m.keySaltLength, ", got ", len(key))
    96  		}
    97  	}
    98  	if len(m.pskList) > 1 {
    99  		pskHash := make([]byte, (len(m.pskList)-1)*aes.BlockSize)
   100  		for i, key := range m.pskList {
   101  			if i == 0 {
   102  				continue
   103  			}
   104  			keyHash := blake3.Sum512(key)
   105  			copy(pskHash[aes.BlockSize*(i-1):aes.BlockSize*i], keyHash[:aes.BlockSize])
   106  		}
   107  		m.pskHash = pskHash
   108  	}
   109  	var err error
   110  	switch methodName {
   111  	case "2022-blake3-aes-128-gcm", "2022-blake3-aes-256-gcm":
   112  		m.udpBlockEncryptCipher, err = aes.NewCipher(m.pskList[0])
   113  		if err != nil {
   114  			return nil, err
   115  		}
   116  		m.udpBlockDecryptCipher, err = aes.NewCipher(m.pskList[len(m.pskList)-1])
   117  		if err != nil {
   118  			return nil, err
   119  		}
   120  	case "2022-blake3-chacha20-poly1305":
   121  		m.udpCipher, err = chacha20poly1305.NewX(m.pskList[0])
   122  		if err != nil {
   123  			return nil, err
   124  		}
   125  	}
   126  	return m, nil
   127  }
   128  
   129  func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) {
   130  	shadowsocksConn := &clientConn{
   131  		Conn:        conn,
   132  		method:      m,
   133  		destination: destination,
   134  	}
   135  	return shadowsocksConn, shadowsocksConn.writeRequest(nil)
   136  }
   137  
   138  func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn {
   139  	return &clientConn{
   140  		Conn:        conn,
   141  		method:      m,
   142  		destination: destination,
   143  	}
   144  }
   145  
   146  func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn {
   147  	return &clientPacketConn{
   148  		AbstractConn: conn,
   149  		reader:       bufio.NewExtendedReader(conn),
   150  		writer:       bufio.NewExtendedWriter(conn),
   151  		method:       m,
   152  		session:      m.newUDPSession(),
   153  	}
   154  }
   155  
   156  func (m *Method) time() time.Time {
   157  	if m.timeFunc != nil {
   158  		return m.timeFunc()
   159  	} else {
   160  		return time.Now()
   161  	}
   162  }
   163  
   164  type clientConn struct {
   165  	net.Conn
   166  	method          *Method
   167  	destination     M.Socksaddr
   168  	requestSalt     []byte
   169  	reader          *shadowio.Reader
   170  	readWaitOptions N.ReadWaitOptions
   171  	writer          *shadowio.Writer
   172  	shadowio.WriterInterface
   173  }
   174  
   175  func (c *clientConn) writeRequest(payload []byte) error {
   176  	requestSalt := make([]byte, c.method.keySaltLength)
   177  	requestBuffer := buf.New()
   178  	defer requestBuffer.Release()
   179  	requestBuffer.WriteRandom(c.method.keySaltLength)
   180  	copy(requestSalt, requestBuffer.Bytes())
   181  	key := SessionKey(c.method.pskList[len(c.method.pskList)-1], requestSalt, c.method.keySaltLength)
   182  	writeCipher, err := c.method.constructor(key)
   183  	if err != nil {
   184  		return err
   185  	}
   186  	writer := shadowio.NewWriter(
   187  		c.Conn,
   188  		writeCipher,
   189  		nil,
   190  		buf.BufferSize-shadowio.PacketLengthBufferSize-shadowio.Overhead*2,
   191  	)
   192  	err = c.method.writeExtendedIdentityHeaders(requestBuffer, requestBuffer.To(c.method.keySaltLength))
   193  	if err != nil {
   194  		return err
   195  	}
   196  	fixedLengthBuffer := buf.With(requestBuffer.Extend(RequestHeaderFixedChunkLength + shadowio.Overhead))
   197  	common.Must(fixedLengthBuffer.WriteByte(HeaderTypeClient))
   198  	common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint64(c.method.time().Unix())))
   199  	variableLengthHeaderLen := M.SocksaddrSerializer.AddrPortLen(c.destination) + 2
   200  	var paddingLen int
   201  	if len(payload) < MaxPaddingLength {
   202  		paddingLen = mRand.Intn(MaxPaddingLength) + 1
   203  	}
   204  	variableLengthHeaderLen += paddingLen
   205  	maxPayloadLen := requestBuffer.FreeLen() - (variableLengthHeaderLen + shadowio.Overhead)
   206  	payloadLen := len(payload)
   207  	if payloadLen > maxPayloadLen {
   208  		payloadLen = maxPayloadLen
   209  	}
   210  	variableLengthHeaderLen += payloadLen
   211  	common.Must(binary.Write(fixedLengthBuffer, binary.BigEndian, uint16(variableLengthHeaderLen)))
   212  	writer.Encrypt(fixedLengthBuffer.Index(0), fixedLengthBuffer.Bytes())
   213  	fixedLengthBuffer.Extend(shadowio.Overhead)
   214  
   215  	variableLengthBuffer := buf.With(requestBuffer.Extend(variableLengthHeaderLen + shadowio.Overhead))
   216  	err = M.SocksaddrSerializer.WriteAddrPort(variableLengthBuffer, c.destination)
   217  	if err != nil {
   218  		return err
   219  	}
   220  	common.Must(binary.Write(variableLengthBuffer, binary.BigEndian, uint16(paddingLen)))
   221  	if paddingLen > 0 {
   222  		variableLengthBuffer.Extend(paddingLen)
   223  	}
   224  	if payloadLen > 0 {
   225  		common.Must1(variableLengthBuffer.Write(payload[:payloadLen]))
   226  	}
   227  	writer.Encrypt(variableLengthBuffer.Index(0), variableLengthBuffer.Bytes())
   228  	variableLengthBuffer.Extend(shadowio.Overhead)
   229  	_, err = c.Conn.Write(requestBuffer.Bytes())
   230  	if err != nil {
   231  		return err
   232  	}
   233  	if len(payload) > payloadLen {
   234  		_, err = writer.Write(payload[payloadLen:])
   235  		if err != nil {
   236  			return err
   237  		}
   238  	}
   239  	c.requestSalt = requestSalt
   240  	c.writer = writer
   241  	return nil
   242  }
   243  
   244  func (m *Method) writeExtendedIdentityHeaders(request *buf.Buffer, salt []byte) error {
   245  	pskLen := len(m.pskList)
   246  	if pskLen < 2 {
   247  		return nil
   248  	}
   249  	for i, psk := range m.pskList {
   250  		keyMaterial := make([]byte, m.keySaltLength*2)
   251  		copy(keyMaterial, psk)
   252  		copy(keyMaterial[m.keySaltLength:], salt)
   253  		identitySubkey := make([]byte, m.keySaltLength)
   254  		blake3.DeriveKey(identitySubkey, "shadowsocks 2022 identity subkey", keyMaterial)
   255  		pskHash := m.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
   256  		header := request.Extend(16)
   257  		b, err := m.blockConstructor(identitySubkey)
   258  		if err != nil {
   259  			return err
   260  		}
   261  		b.Encrypt(header, pskHash)
   262  		if i == pskLen-2 {
   263  			break
   264  		}
   265  	}
   266  	return nil
   267  }
   268  
   269  func (c *clientConn) readResponse() error {
   270  	salt := buf.NewSize(c.method.keySaltLength)
   271  	defer salt.Release()
   272  	_, err := salt.ReadFullFrom(c.Conn, c.method.keySaltLength)
   273  	if err != nil {
   274  		return err
   275  	}
   276  	key := SessionKey(c.method.pskList[len(c.method.pskList)-1], salt.Bytes(), c.method.keySaltLength)
   277  	readCipher, err := c.method.constructor(key)
   278  	if err != nil {
   279  		return err
   280  	}
   281  	reader := shadowio.NewReader(c.Conn, readCipher)
   282  	fixedResponseBuffer, err := reader.ReadFixedBuffer(1 + 8 + c.method.keySaltLength + 2)
   283  	if err != nil {
   284  		return err
   285  	}
   286  	headerType := common.Must1(fixedResponseBuffer.ReadByte())
   287  	if headerType != HeaderTypeServer {
   288  		return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeServer, ", got ", headerType)
   289  	}
   290  	var epoch uint64
   291  	common.Must(binary.Read(fixedResponseBuffer, binary.BigEndian, &epoch))
   292  	diff := int(math.Abs(float64(c.method.time().Unix() - int64(epoch))))
   293  	if diff > 30 {
   294  		return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
   295  	}
   296  	responseSalt := common.Must1(fixedResponseBuffer.ReadBytes(c.method.keySaltLength))
   297  	if !bytes.Equal(responseSalt, c.requestSalt) {
   298  		return ErrBadRequestSalt
   299  	}
   300  	var length uint16
   301  	common.Must(binary.Read(reader, binary.BigEndian, &length))
   302  	_, err = reader.ReadFixedBuffer(int(length))
   303  	if err != nil {
   304  		return err
   305  	}
   306  	reader.InitializeReadWaiter(c.readWaitOptions)
   307  	c.reader = reader
   308  	return nil
   309  }
   310  
   311  func (c *clientConn) Read(p []byte) (n int, err error) {
   312  	if c.reader == nil {
   313  		if err = c.readResponse(); err != nil {
   314  			return
   315  		}
   316  	}
   317  	return c.reader.Read(p)
   318  }
   319  
   320  func (c *clientConn) ReadBuffer(buffer *buf.Buffer) error {
   321  	if c.reader == nil {
   322  		err := c.readResponse()
   323  		if err != nil {
   324  			return err
   325  		}
   326  	}
   327  	return c.reader.ReadBuffer(buffer)
   328  }
   329  
   330  func (c *clientConn) Write(p []byte) (n int, err error) {
   331  	if c.writer == nil {
   332  		err = c.writeRequest(p)
   333  		if err == nil {
   334  			n = len(p)
   335  		}
   336  		return
   337  	}
   338  	return c.writer.Write(p)
   339  }
   340  
   341  func (c *clientConn) WriteBuffer(buffer *buf.Buffer) error {
   342  	if c.writer == nil {
   343  		defer buffer.Release()
   344  		return c.writeRequest(buffer.Bytes())
   345  	}
   346  	return c.writer.WriteBuffer(buffer)
   347  }
   348  
   349  func (c *clientConn) NeedHandshake() bool {
   350  	return c.writer == nil
   351  }
   352  
   353  func (c *clientConn) Upstream() any {
   354  	return c.Conn
   355  }
   356  
   357  func (c *clientConn) Close() error {
   358  	return common.Close(
   359  		c.Conn,
   360  		common.PtrOrNil(c.reader),
   361  		common.PtrOrNil(c.writer),
   362  	)
   363  }
   364  
   365  type clientPacketConn struct {
   366  	N.AbstractConn
   367  	reader  N.ExtendedReader
   368  	writer  N.ExtendedWriter
   369  	method  *Method
   370  	session *udpSession
   371  }
   372  
   373  func (m *Method) newUDPSession() *udpSession {
   374  	session := &udpSession{}
   375  	if m.udpCipher != nil {
   376  		session.rng = Blake3KeyedHash(rand.Reader)
   377  		common.Must(binary.Read(session.rng, binary.BigEndian, &session.sessionId))
   378  	} else {
   379  		common.Must(binary.Read(rand.Reader, binary.BigEndian, &session.sessionId))
   380  	}
   381  	session.packetId--
   382  	if m.udpCipher == nil {
   383  		sessionId := make([]byte, 8)
   384  		binary.BigEndian.PutUint64(sessionId, session.sessionId)
   385  		key := SessionKey(m.pskList[len(m.pskList)-1], sessionId, m.keySaltLength)
   386  		var err error
   387  		session.cipher, err = m.constructor(key)
   388  		if err != nil {
   389  			return nil
   390  		}
   391  	}
   392  	return session
   393  }
   394  
   395  func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   396  	var hdrLen int
   397  	if c.method.udpCipher != nil {
   398  		hdrLen = PacketNonceSize
   399  	}
   400  
   401  	var paddingLen int
   402  	if destination.Port == 53 && buffer.Len() < MaxPaddingLength {
   403  		paddingLen = mRand.Intn(MaxPaddingLength-buffer.Len()) + 1
   404  	}
   405  
   406  	hdrLen += 16 // packet header
   407  	pskLen := len(c.method.pskList)
   408  	if c.method.udpCipher == nil && pskLen > 1 {
   409  		hdrLen += (pskLen - 1) * aes.BlockSize
   410  	}
   411  	hdrLen += 1 // header type
   412  	hdrLen += 8 // timestamp
   413  	hdrLen += 2 // padding length
   414  	hdrLen += paddingLen
   415  	hdrLen += M.SocksaddrSerializer.AddrPortLen(destination)
   416  	header := buf.With(buffer.ExtendHeader(hdrLen))
   417  
   418  	var dataIndex int
   419  	if c.method.udpCipher != nil {
   420  		common.Must1(header.ReadFullFrom(c.session.rng, PacketNonceSize))
   421  		if pskLen > 1 {
   422  			panic("unsupported chacha extended header")
   423  		}
   424  		dataIndex = PacketNonceSize
   425  	} else {
   426  		dataIndex = aes.BlockSize
   427  	}
   428  
   429  	common.Must(
   430  		binary.Write(header, binary.BigEndian, c.session.sessionId),
   431  		binary.Write(header, binary.BigEndian, c.session.nextPacketId()),
   432  	)
   433  
   434  	if c.method.udpCipher == nil && pskLen > 1 {
   435  		for i, psk := range c.method.pskList {
   436  			dataIndex += aes.BlockSize
   437  			pskHash := c.method.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
   438  
   439  			identityHeader := header.Extend(aes.BlockSize)
   440  			xorWords(identityHeader, pskHash, header.To(aes.BlockSize))
   441  			b, err := c.method.blockConstructor(psk)
   442  			if err != nil {
   443  				return err
   444  			}
   445  			b.Encrypt(identityHeader, identityHeader)
   446  
   447  			if i == pskLen-2 {
   448  				break
   449  			}
   450  		}
   451  	}
   452  	common.Must(
   453  		header.WriteByte(HeaderTypeClient),
   454  		binary.Write(header, binary.BigEndian, uint64(c.method.time().Unix())),
   455  		binary.Write(header, binary.BigEndian, uint16(paddingLen)), // padding length
   456  	)
   457  
   458  	if paddingLen > 0 {
   459  		header.Extend(paddingLen)
   460  	}
   461  
   462  	err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
   463  	if err != nil {
   464  		return err
   465  	}
   466  	if c.method.udpCipher != nil {
   467  		c.method.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil)
   468  		buffer.Extend(shadowio.Overhead)
   469  	} else {
   470  		packetHeader := buffer.To(aes.BlockSize)
   471  		c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil)
   472  		buffer.Extend(shadowio.Overhead)
   473  		c.method.udpBlockEncryptCipher.Encrypt(packetHeader, packetHeader)
   474  	}
   475  	return c.writer.WriteBuffer(buffer)
   476  }
   477  
   478  func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
   479  	err = c.reader.ReadBuffer(buffer)
   480  	if err != nil {
   481  		return
   482  	}
   483  	return c.readPacket(buffer)
   484  }
   485  
   486  func (c *clientPacketConn) readPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
   487  	var packetHeader []byte
   488  	if c.method.udpCipher != nil {
   489  		if buffer.Len() < PacketNonceSize+PacketMinimalHeaderSize {
   490  			return M.Socksaddr{}, C.ErrPacketTooShort
   491  		}
   492  		_, err = c.method.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
   493  		if err != nil {
   494  			return M.Socksaddr{}, E.Cause(err, "decrypt packet")
   495  		}
   496  		buffer.Advance(PacketNonceSize)
   497  		buffer.Truncate(buffer.Len() - shadowio.Overhead)
   498  	} else {
   499  		if buffer.Len() < PacketMinimalHeaderSize {
   500  			return M.Socksaddr{}, C.ErrPacketTooShort
   501  		}
   502  		packetHeader = buffer.To(aes.BlockSize)
   503  		c.method.udpBlockDecryptCipher.Decrypt(packetHeader, packetHeader)
   504  	}
   505  
   506  	var sessionId, packetId uint64
   507  	err = binary.Read(buffer, binary.BigEndian, &sessionId)
   508  	if err != nil {
   509  		return M.Socksaddr{}, err
   510  	}
   511  	err = binary.Read(buffer, binary.BigEndian, &packetId)
   512  	if err != nil {
   513  		return M.Socksaddr{}, err
   514  	}
   515  
   516  	if sessionId == c.session.remoteSessionId {
   517  		if !c.session.window.Check(packetId) {
   518  			return M.Socksaddr{}, ErrPacketIdNotUnique
   519  		}
   520  	} else if sessionId == c.session.lastRemoteSessionId {
   521  		if !c.session.lastWindow.Check(packetId) {
   522  			return M.Socksaddr{}, ErrPacketIdNotUnique
   523  		}
   524  	}
   525  
   526  	var remoteCipher cipher.AEAD
   527  	if packetHeader != nil {
   528  		if sessionId == c.session.remoteSessionId {
   529  			remoteCipher = c.session.remoteCipher
   530  		} else if sessionId == c.session.lastRemoteSessionId {
   531  			remoteCipher = c.session.lastRemoteCipher
   532  		} else {
   533  			key := SessionKey(c.method.pskList[len(c.method.pskList)-1], packetHeader[:8], c.method.keySaltLength)
   534  			remoteCipher, err = c.method.constructor(key)
   535  			if err != nil {
   536  				return M.Socksaddr{}, err
   537  			}
   538  		}
   539  		_, err = remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil)
   540  		if err != nil {
   541  			return M.Socksaddr{}, E.Cause(err, "decrypt packet")
   542  		}
   543  		buffer.Truncate(buffer.Len() - shadowio.Overhead)
   544  	}
   545  
   546  	var headerType byte
   547  	headerType, err = buffer.ReadByte()
   548  	if err != nil {
   549  		return M.Socksaddr{}, err
   550  	}
   551  	if headerType != HeaderTypeServer {
   552  		return M.Socksaddr{}, E.Extend(ErrBadHeaderType, "expected ", HeaderTypeServer, ", got ", headerType)
   553  	}
   554  
   555  	var epoch uint64
   556  	err = binary.Read(buffer, binary.BigEndian, &epoch)
   557  	if err != nil {
   558  		return M.Socksaddr{}, err
   559  	}
   560  
   561  	diff := int(math.Abs(float64(c.method.time().Unix() - int64(epoch))))
   562  	if diff > 30 {
   563  		return M.Socksaddr{}, E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
   564  	}
   565  
   566  	if sessionId == c.session.remoteSessionId {
   567  		c.session.window.Add(packetId)
   568  	} else if sessionId == c.session.lastRemoteSessionId {
   569  		c.session.lastWindow.Add(packetId)
   570  		c.session.lastRemoteSeen = c.method.time().Unix()
   571  	} else {
   572  		if c.session.remoteSessionId != 0 {
   573  			if c.method.time().Unix()-c.session.lastRemoteSeen < 60 {
   574  				return M.Socksaddr{}, ErrTooManyServerSessions
   575  			} else {
   576  				c.session.lastRemoteSessionId = c.session.remoteSessionId
   577  				c.session.lastWindow = c.session.window
   578  				c.session.lastRemoteSeen = c.method.time().Unix()
   579  				c.session.lastRemoteCipher = c.session.remoteCipher
   580  				c.session.window = SlidingWindow{}
   581  			}
   582  		}
   583  		c.session.remoteSessionId = sessionId
   584  		c.session.remoteCipher = remoteCipher
   585  		c.session.window.Add(packetId)
   586  	}
   587  
   588  	var clientSessionId uint64
   589  	err = binary.Read(buffer, binary.BigEndian, &clientSessionId)
   590  	if err != nil {
   591  		return M.Socksaddr{}, err
   592  	}
   593  
   594  	if clientSessionId != c.session.sessionId {
   595  		return M.Socksaddr{}, ErrBadClientSessionId
   596  	}
   597  
   598  	var paddingLen uint16
   599  	err = binary.Read(buffer, binary.BigEndian, &paddingLen)
   600  	if err != nil {
   601  		return M.Socksaddr{}, E.Cause(err, "read padding length")
   602  	}
   603  	buffer.Advance(int(paddingLen))
   604  
   605  	destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer)
   606  	if err != nil {
   607  		return
   608  	}
   609  	return destination.Unwrap(), nil
   610  }
   611  
   612  func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
   613  	n, err = c.reader.Read(p)
   614  	if err != nil {
   615  		return
   616  	}
   617  	buffer := buf.As(p[:n])
   618  	destination, err := c.readPacket(buffer)
   619  	if destination.IsFqdn() {
   620  		addr = destination
   621  	} else {
   622  		addr = destination.UDPAddr()
   623  	}
   624  	n = copy(p, buffer.Bytes())
   625  	return
   626  }
   627  
   628  func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
   629  	destination := M.SocksaddrFromNet(addr)
   630  	var overHead int
   631  	if c.method.udpCipher != nil {
   632  		overHead = PacketNonceSize + shadowio.Overhead
   633  	} else {
   634  		overHead = shadowio.Overhead
   635  	}
   636  	overHead += 16 // packet header
   637  	pskLen := len(c.method.pskList)
   638  	if c.method.udpCipher == nil && pskLen > 1 {
   639  		overHead += (pskLen - 1) * aes.BlockSize
   640  	}
   641  	var paddingLen int
   642  	if destination.Port == 53 && len(p) < MaxPaddingLength {
   643  		paddingLen = mRand.Intn(MaxPaddingLength-len(p)) + 1
   644  	}
   645  	overHead += 1 // header type
   646  	overHead += 8 // timestamp
   647  	overHead += 2 // padding length
   648  	overHead += paddingLen
   649  	overHead += M.SocksaddrSerializer.AddrPortLen(destination)
   650  
   651  	buffer := buf.NewSize(overHead + len(p))
   652  	defer buffer.Release()
   653  
   654  	var dataIndex int
   655  	if c.method.udpCipher != nil {
   656  		common.Must1(buffer.ReadFullFrom(c.session.rng, PacketNonceSize))
   657  		if pskLen > 1 {
   658  			panic("unsupported chacha extended header")
   659  		}
   660  		dataIndex = PacketNonceSize
   661  	} else {
   662  		dataIndex = aes.BlockSize
   663  	}
   664  
   665  	common.Must(
   666  		binary.Write(buffer, binary.BigEndian, c.session.sessionId),
   667  		binary.Write(buffer, binary.BigEndian, c.session.nextPacketId()),
   668  	)
   669  
   670  	if c.method.udpCipher == nil && pskLen > 1 {
   671  		for i, psk := range c.method.pskList {
   672  			dataIndex += aes.BlockSize
   673  			pskHash := c.method.pskHash[aes.BlockSize*i : aes.BlockSize*(i+1)]
   674  
   675  			identityHeader := buffer.Extend(aes.BlockSize)
   676  			xorWords(identityHeader, pskHash, buffer.To(aes.BlockSize))
   677  			b, err := c.method.blockConstructor(psk)
   678  			if err != nil {
   679  				return 0, err
   680  			}
   681  			b.Encrypt(identityHeader, identityHeader)
   682  
   683  			if i == pskLen-2 {
   684  				break
   685  			}
   686  		}
   687  	}
   688  	common.Must(
   689  		buffer.WriteByte(HeaderTypeClient),
   690  		binary.Write(buffer, binary.BigEndian, uint64(c.method.time().Unix())),
   691  		binary.Write(buffer, binary.BigEndian, uint16(paddingLen)), // padding length
   692  	)
   693  
   694  	if paddingLen > 0 {
   695  		buffer.Extend(paddingLen)
   696  	}
   697  
   698  	err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination)
   699  	if err != nil {
   700  		return
   701  	}
   702  	common.Must1(buffer.Write(p))
   703  	if c.method.udpCipher != nil {
   704  		c.method.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil)
   705  		buffer.Extend(shadowio.Overhead)
   706  	} else {
   707  		packetHeader := buffer.To(aes.BlockSize)
   708  		c.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil)
   709  		buffer.Extend(shadowio.Overhead)
   710  		c.method.udpBlockEncryptCipher.Encrypt(packetHeader, packetHeader)
   711  	}
   712  	err = common.Error(c.writer.Write(buffer.Bytes()))
   713  	if err != nil {
   714  		return
   715  	}
   716  	return len(p), nil
   717  }
   718  
   719  func (c *clientPacketConn) FrontHeadroom() int {
   720  	var overHead int
   721  	if c.method.udpCipher != nil {
   722  		overHead = PacketNonceSize + shadowio.Overhead
   723  	} else {
   724  		overHead = shadowio.Overhead
   725  	}
   726  	overHead += 16 // packet header
   727  	pskLen := len(c.method.pskList)
   728  	if c.method.udpCipher == nil && pskLen > 1 {
   729  		overHead += (pskLen - 1) * aes.BlockSize
   730  	}
   731  	overHead += 1 // header type
   732  	overHead += 8 // timestamp
   733  	overHead += 2 // padding length
   734  	overHead += MaxPaddingLength
   735  	overHead += M.MaxSocksaddrLength
   736  	return overHead
   737  }
   738  
   739  func (c *clientPacketConn) RearHeadroom() int {
   740  	return shadowio.Overhead
   741  }
   742  
   743  func (c *clientPacketConn) Upstream() any {
   744  	return c.AbstractConn
   745  }
   746  
   747  func (c *clientPacketConn) Close() error {
   748  	return c.AbstractConn.Close()
   749  }