github.com/MerlinKodo/sing-shadowsocks@v0.2.6/shadowaead_2022/service.go (about)

     1  package shadowaead_2022
     2  
     3  import (
     4  	"context"
     5  	"crypto/aes"
     6  	"crypto/cipher"
     7  	"crypto/rand"
     8  	"encoding/base64"
     9  	"encoding/binary"
    10  	"io"
    11  	"math"
    12  	mRand "math/rand"
    13  	"net"
    14  	"os"
    15  	"sync"
    16  	"sync/atomic"
    17  	"time"
    18  
    19  	shadowsocks "github.com/MerlinKodo/sing-shadowsocks"
    20  	"github.com/MerlinKodo/sing-shadowsocks/shadowaead"
    21  	"github.com/sagernet/sing/common"
    22  	"github.com/sagernet/sing/common/buf"
    23  	"github.com/sagernet/sing/common/cache"
    24  	E "github.com/sagernet/sing/common/exceptions"
    25  	M "github.com/sagernet/sing/common/metadata"
    26  	N "github.com/sagernet/sing/common/network"
    27  	"github.com/sagernet/sing/common/replay"
    28  	"github.com/sagernet/sing/common/udpnat"
    29  
    30  	"golang.org/x/crypto/chacha20poly1305"
    31  )
    32  
    33  var (
    34  	ErrNoPadding  = E.New("bad request: missing payload or padding")
    35  	ErrBadPadding = E.New("bad request: damaged padding")
    36  )
    37  
    38  var _ shadowsocks.Service = (*Service)(nil)
    39  
    40  type Service struct {
    41  	name          string
    42  	keySaltLength int
    43  	handler       shadowsocks.Handler
    44  	timeFunc      func() time.Time
    45  
    46  	constructor      func(key []byte) (cipher.AEAD, error)
    47  	blockConstructor func(key []byte) (cipher.Block, error)
    48  	udpCipher        cipher.AEAD
    49  	udpBlockCipher   cipher.Block
    50  	psk              []byte
    51  
    52  	replayFilter replay.Filter
    53  	udpNat       *udpnat.Service[uint64]
    54  	udpSessions  *cache.LruCache[uint64, *serverUDPSession]
    55  }
    56  
    57  func NewServiceWithPassword(method string, password string, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (shadowsocks.Service, error) {
    58  	if password == "" {
    59  		return nil, ErrMissingPSK
    60  	}
    61  	psk, err := base64.StdEncoding.DecodeString(password)
    62  	if err != nil {
    63  		return nil, E.Cause(err, "decode psk")
    64  	}
    65  	return NewService(method, psk, udpTimeout, handler, timeFunc)
    66  }
    67  
    68  func NewService(method string, psk []byte, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (shadowsocks.Service, error) {
    69  	s := &Service{
    70  		name:     method,
    71  		handler:  handler,
    72  		timeFunc: timeFunc,
    73  
    74  		replayFilter: replay.NewSimple(60 * time.Second),
    75  		udpNat:       udpnat.New[uint64](udpTimeout, handler),
    76  		udpSessions: cache.New[uint64, *serverUDPSession](
    77  			cache.WithAge[uint64, *serverUDPSession](udpTimeout),
    78  			cache.WithUpdateAgeOnGet[uint64, *serverUDPSession](),
    79  		),
    80  	}
    81  
    82  	switch method {
    83  	case "2022-blake3-aes-128-gcm":
    84  		s.keySaltLength = 16
    85  		s.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM)
    86  		s.blockConstructor = aes.NewCipher
    87  	case "2022-blake3-aes-256-gcm":
    88  		s.keySaltLength = 32
    89  		s.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM)
    90  		s.blockConstructor = aes.NewCipher
    91  	case "2022-blake3-chacha20-poly1305":
    92  		s.keySaltLength = 32
    93  		s.constructor = chacha20poly1305.New
    94  	default:
    95  		return nil, os.ErrInvalid
    96  	}
    97  
    98  	if len(psk) != s.keySaltLength {
    99  		if len(psk) < s.keySaltLength {
   100  			return nil, shadowsocks.ErrBadKey
   101  		} else if len(psk) > s.keySaltLength {
   102  			psk = Key(psk, s.keySaltLength)
   103  		} else {
   104  			return nil, ErrMissingPSK
   105  		}
   106  	}
   107  
   108  	var err error
   109  	switch method {
   110  	case "2022-blake3-aes-128-gcm", "2022-blake3-aes-256-gcm":
   111  		s.udpBlockCipher, err = aes.NewCipher(psk)
   112  	case "2022-blake3-chacha20-poly1305":
   113  		s.udpCipher, err = chacha20poly1305.NewX(psk)
   114  	}
   115  	if err != nil {
   116  		return nil, err
   117  	}
   118  
   119  	s.psk = psk
   120  	return s, nil
   121  }
   122  
   123  func (s *Service) Name() string {
   124  	return s.name
   125  }
   126  
   127  func (s *Service) Password() string {
   128  	return base64.StdEncoding.EncodeToString(s.psk)
   129  }
   130  
   131  func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
   132  	err := s.newConnection(ctx, conn, metadata)
   133  	if err != nil {
   134  		err = &shadowsocks.ServerConnError{Conn: conn, Source: metadata.Source, Cause: err}
   135  	}
   136  	return err
   137  }
   138  
   139  func (s *Service) time() time.Time {
   140  	if s.timeFunc != nil {
   141  		return s.timeFunc()
   142  	} else {
   143  		return time.Now()
   144  	}
   145  }
   146  
   147  func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
   148  	header := make([]byte, s.keySaltLength+shadowaead.Overhead+RequestHeaderFixedChunkLength)
   149  
   150  	n, err := conn.Read(header)
   151  	if err != nil {
   152  		return E.Cause(err, "read header")
   153  	} else if n < len(header) {
   154  		return shadowaead.ErrBadHeader
   155  	}
   156  
   157  	requestSalt := header[:s.keySaltLength]
   158  
   159  	if !s.replayFilter.Check(requestSalt) {
   160  		return ErrSaltNotUnique
   161  	}
   162  
   163  	requestKey := SessionKey(s.psk, requestSalt, s.keySaltLength)
   164  	readCipher, err := s.constructor(requestKey)
   165  	if err != nil {
   166  		return err
   167  	}
   168  	reader := shadowaead.NewReader(
   169  		conn,
   170  		readCipher,
   171  		MaxPacketSize,
   172  	)
   173  
   174  	err = reader.ReadExternalChunk(header[s.keySaltLength:])
   175  	if err != nil {
   176  		return err
   177  	}
   178  
   179  	headerType, err := reader.ReadByte()
   180  	if err != nil {
   181  		return E.Cause(err, "read header")
   182  	}
   183  
   184  	if headerType != HeaderTypeClient {
   185  		return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType)
   186  	}
   187  
   188  	var epoch uint64
   189  	err = binary.Read(reader, binary.BigEndian, &epoch)
   190  	if err != nil {
   191  		return err
   192  	}
   193  
   194  	diff := int(math.Abs(float64(s.time().Unix() - int64(epoch))))
   195  	if diff > 30 {
   196  		return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
   197  	}
   198  
   199  	var length uint16
   200  	err = binary.Read(reader, binary.BigEndian, &length)
   201  	if err != nil {
   202  		return err
   203  	}
   204  
   205  	err = reader.ReadWithLength(length)
   206  	if err != nil {
   207  		return err
   208  	}
   209  
   210  	destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
   211  	if err != nil {
   212  		return err
   213  	}
   214  
   215  	var paddingLen uint16
   216  	err = binary.Read(reader, binary.BigEndian, &paddingLen)
   217  	if err != nil {
   218  		return err
   219  	}
   220  
   221  	if uint16(reader.Cached()) < paddingLen {
   222  		return ErrNoPadding
   223  	}
   224  
   225  	if paddingLen > 0 {
   226  		err = reader.Discard(int(paddingLen))
   227  		if err != nil {
   228  			return E.Cause(err, "discard padding")
   229  		}
   230  	} else if reader.Cached() == 0 {
   231  		return ErrNoPadding
   232  	}
   233  
   234  	protocolConn := &serverConn{
   235  		Service:     s,
   236  		Conn:        conn,
   237  		uPSK:        s.psk,
   238  		headerType:  headerType,
   239  		requestSalt: requestSalt,
   240  	}
   241  
   242  	protocolConn.reader = reader
   243  
   244  	metadata.Protocol = "shadowsocks"
   245  	metadata.Destination = destination
   246  	return s.handler.NewConnection(ctx, protocolConn, metadata)
   247  }
   248  
   249  type serverConn struct {
   250  	*Service
   251  	net.Conn
   252  	uPSK        []byte
   253  	access      sync.Mutex
   254  	headerType  byte
   255  	reader      *shadowaead.Reader
   256  	writer      *shadowaead.Writer
   257  	requestSalt []byte
   258  }
   259  
   260  func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
   261  	salt := buf.NewSize(c.keySaltLength)
   262  	salt.WriteRandom(salt.FreeLen())
   263  
   264  	key := SessionKey(c.uPSK, salt.Bytes(), c.keySaltLength)
   265  	writeCipher, err := c.constructor(key)
   266  	if err != nil {
   267  		salt.Release()
   268  		return
   269  	}
   270  	writer := shadowaead.NewWriter(
   271  		c.Conn,
   272  		writeCipher,
   273  		MaxPacketSize,
   274  	)
   275  	header := writer.Buffer()
   276  	header.Write(salt.Bytes())
   277  
   278  	salt.Release()
   279  
   280  	headerType := byte(HeaderTypeServer)
   281  	payloadLen := len(payload)
   282  
   283  	headerFixedChunk := buf.NewSize(1 + 8 + c.keySaltLength + 2)
   284  	common.Must(headerFixedChunk.WriteByte(headerType))
   285  	common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint64(c.time().Unix())))
   286  	common.Must1(headerFixedChunk.Write(c.requestSalt))
   287  	common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint16(payloadLen)))
   288  
   289  	writer.WriteChunk(header, headerFixedChunk.Slice())
   290  	headerFixedChunk.Release()
   291  	c.requestSalt = nil
   292  
   293  	if payloadLen > 0 {
   294  		writer.WriteChunk(header, payload[:payloadLen])
   295  	}
   296  
   297  	err = writer.BufferedWriter(header.Len()).Flush()
   298  	if err != nil {
   299  		return
   300  	}
   301  
   302  	switch headerType {
   303  	case HeaderTypeServer:
   304  		c.writer = writer
   305  		// case HeaderTypeServerEncrypted:
   306  		//	encryptedWriter := NewTLSEncryptedStreamWriter(writer)
   307  		//	if payloadLen < len(payload) {
   308  		//		_, err = encryptedWriter.Write(payload[payloadLen:])
   309  		//		if err != nil {
   310  		//			return
   311  		//		}
   312  		//	}
   313  		//	c.writer = encryptedWriter
   314  	}
   315  
   316  	n = len(payload)
   317  	return
   318  }
   319  
   320  func (c *serverConn) Read(b []byte) (n int, err error) {
   321  	return c.reader.Read(b)
   322  }
   323  
   324  func (c *serverConn) Write(p []byte) (n int, err error) {
   325  	if c.writer != nil {
   326  		return c.writer.Write(p)
   327  	}
   328  	c.access.Lock()
   329  	if c.writer != nil {
   330  		c.access.Unlock()
   331  		return c.writer.Write(p)
   332  	}
   333  	defer c.access.Unlock()
   334  	return c.writeResponse(p)
   335  }
   336  
   337  func (c *serverConn) WriteVectorised(buffers []*buf.Buffer) error {
   338  	if c.writer != nil {
   339  		return c.writer.WriteVectorised(buffers)
   340  	}
   341  	c.access.Lock()
   342  	if c.writer != nil {
   343  		c.access.Unlock()
   344  		return c.writer.WriteVectorised(buffers)
   345  	}
   346  	defer c.access.Unlock()
   347  	_, err := c.writeResponse(buffers[0].Bytes())
   348  	if err != nil {
   349  		buf.ReleaseMulti(buffers)
   350  		return err
   351  	}
   352  	buffers[0].Release()
   353  	return c.writer.WriteVectorised(buffers[1:])
   354  }
   355  
   356  func (c *serverConn) Close() error {
   357  	return common.Close(
   358  		c.Conn,
   359  		common.PtrOrNil(c.reader),
   360  		common.PtrOrNil(c.writer),
   361  	)
   362  }
   363  
   364  func (c *serverConn) NeedAdditionalReadDeadline() bool {
   365  	return true
   366  }
   367  
   368  func (c *serverConn) Upstream() any {
   369  	return c.Conn
   370  }
   371  
   372  func (s *Service) WriteIsThreadUnsafe() {
   373  }
   374  
   375  func (s *Service) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
   376  	err := s.newPacket(ctx, conn, buffer, metadata)
   377  	if err != nil {
   378  		err = &shadowsocks.ServerPacketError{Source: metadata.Source, Cause: err}
   379  	}
   380  	return err
   381  }
   382  
   383  func (s *Service) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
   384  	var packetHeader []byte
   385  	if s.udpCipher != nil {
   386  		if buffer.Len() < PacketNonceSize+PacketMinimalHeaderSize {
   387  			return ErrPacketTooShort
   388  		}
   389  		_, err := s.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
   390  		if err != nil {
   391  			return E.Cause(err, "decrypt packet header")
   392  		}
   393  		buffer.Advance(PacketNonceSize)
   394  		buffer.Truncate(buffer.Len() - shadowaead.Overhead)
   395  	} else {
   396  		if buffer.Len() < PacketMinimalHeaderSize {
   397  			return ErrPacketTooShort
   398  		}
   399  		packetHeader = buffer.To(aes.BlockSize)
   400  		s.udpBlockCipher.Decrypt(packetHeader, packetHeader)
   401  	}
   402  
   403  	var sessionId, packetId uint64
   404  	err := binary.Read(buffer, binary.BigEndian, &sessionId)
   405  	if err != nil {
   406  		return err
   407  	}
   408  	err = binary.Read(buffer, binary.BigEndian, &packetId)
   409  	if err != nil {
   410  		return err
   411  	}
   412  
   413  	session, loaded := s.udpSessions.LoadOrStore(sessionId, s.newUDPSession)
   414  	if !loaded {
   415  		session.remoteSessionId = sessionId
   416  		if packetHeader != nil {
   417  			key := SessionKey(s.psk, packetHeader[:8], s.keySaltLength)
   418  			session.remoteCipher, err = s.constructor(key)
   419  			if err != nil {
   420  				return err
   421  			}
   422  		}
   423  	}
   424  	goto process
   425  
   426  returnErr:
   427  	if !loaded {
   428  		s.udpSessions.Delete(sessionId)
   429  	}
   430  	return err
   431  
   432  process:
   433  	if !session.window.Check(packetId) {
   434  		err = ErrPacketIdNotUnique
   435  		goto returnErr
   436  	}
   437  
   438  	if packetHeader != nil {
   439  		_, err = session.remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil)
   440  		if err != nil {
   441  			err = E.Cause(err, "decrypt packet")
   442  			goto returnErr
   443  		}
   444  		buffer.Truncate(buffer.Len() - shadowaead.Overhead)
   445  	}
   446  
   447  	session.window.Add(packetId)
   448  
   449  	var headerType byte
   450  	headerType, err = buffer.ReadByte()
   451  	if err != nil {
   452  		err = E.Cause(err, "decrypt packet")
   453  		goto returnErr
   454  	}
   455  	if headerType != HeaderTypeClient {
   456  		err = E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType)
   457  		goto returnErr
   458  	}
   459  
   460  	var epoch uint64
   461  	err = binary.Read(buffer, binary.BigEndian, &epoch)
   462  	if err != nil {
   463  		goto returnErr
   464  	}
   465  	diff := int(math.Abs(float64(s.time().Unix() - int64(epoch))))
   466  	if diff > 30 {
   467  		err = E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
   468  		goto returnErr
   469  	}
   470  
   471  	var paddingLen uint16
   472  	err = binary.Read(buffer, binary.BigEndian, &paddingLen)
   473  	if err != nil {
   474  		err = E.Cause(err, "read padding length")
   475  		goto returnErr
   476  	}
   477  	buffer.Advance(int(paddingLen))
   478  
   479  	destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
   480  	if err != nil {
   481  		goto returnErr
   482  	}
   483  	metadata.Protocol = "shadowsocks"
   484  	metadata.Destination = destination
   485  	s.udpNat.NewPacket(ctx, sessionId, buffer, metadata, func(natConn N.PacketConn) N.PacketWriter {
   486  		return &serverPacketWriter{s, conn, natConn, session, s.udpBlockCipher}
   487  	})
   488  	return nil
   489  }
   490  
   491  func (s *Service) NewError(ctx context.Context, err error) {
   492  	s.handler.NewError(ctx, err)
   493  }
   494  
   495  type serverPacketWriter struct {
   496  	*Service
   497  	source         N.PacketConn
   498  	nat            N.PacketConn
   499  	session        *serverUDPSession
   500  	udpBlockCipher cipher.Block
   501  }
   502  
   503  func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   504  	var hdrLen int
   505  	if w.udpCipher != nil {
   506  		hdrLen = PacketNonceSize
   507  	}
   508  
   509  	var paddingLen int
   510  	if destination.Port == 53 && buffer.Len() < MaxPaddingLength {
   511  		paddingLen = mRand.Intn(MaxPaddingLength-buffer.Len()) + 1
   512  	}
   513  
   514  	hdrLen += 16 // packet header
   515  	hdrLen += 1  // header type
   516  	hdrLen += 8  // timestamp
   517  	hdrLen += 8  // remote session id
   518  	hdrLen += 2  // padding length
   519  	hdrLen += paddingLen
   520  	hdrLen += M.SocksaddrSerializer.AddrPortLen(destination)
   521  	header := buf.With(buffer.ExtendHeader(hdrLen))
   522  
   523  	var dataIndex int
   524  	if w.udpCipher != nil {
   525  		common.Must1(header.ReadFullFrom(w.session.rng, PacketNonceSize))
   526  		dataIndex = PacketNonceSize
   527  	} else {
   528  		dataIndex = aes.BlockSize
   529  	}
   530  
   531  	common.Must(
   532  		binary.Write(header, binary.BigEndian, w.session.sessionId),
   533  		binary.Write(header, binary.BigEndian, w.session.nextPacketId()),
   534  		header.WriteByte(HeaderTypeServer),
   535  		binary.Write(header, binary.BigEndian, uint64(w.time().Unix())),
   536  		binary.Write(header, binary.BigEndian, w.session.remoteSessionId),
   537  		binary.Write(header, binary.BigEndian, uint16(paddingLen)), // padding length
   538  	)
   539  
   540  	if paddingLen > 0 {
   541  		header.Extend(paddingLen)
   542  	}
   543  
   544  	err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
   545  	if err != nil {
   546  		buffer.Release()
   547  		return err
   548  	}
   549  
   550  	if w.udpCipher != nil {
   551  		w.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil)
   552  		buffer.Extend(shadowaead.Overhead)
   553  	} else {
   554  		packetHeader := buffer.To(aes.BlockSize)
   555  		w.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil)
   556  		buffer.Extend(shadowaead.Overhead)
   557  		w.udpBlockCipher.Encrypt(packetHeader, packetHeader)
   558  	}
   559  	return w.source.WritePacket(buffer, M.SocksaddrFromNet(w.nat.LocalAddr()))
   560  }
   561  
   562  func (w *serverPacketWriter) FrontHeadroom() int {
   563  	var hdrLen int
   564  	if w.udpCipher != nil {
   565  		hdrLen = PacketNonceSize
   566  	}
   567  	hdrLen += 16 // packet header
   568  	hdrLen += 1  // header type
   569  	hdrLen += 8  // timestamp
   570  	hdrLen += 8  // remote session id
   571  	hdrLen += 2  // padding length
   572  	hdrLen += MaxPaddingLength
   573  	hdrLen += M.MaxSocksaddrLength
   574  	return hdrLen
   575  }
   576  
   577  func (w *serverPacketWriter) RearHeadroom() int {
   578  	return shadowaead.Overhead
   579  }
   580  
   581  func (w *serverPacketWriter) Upstream() any {
   582  	return w.source
   583  }
   584  
   585  type serverUDPSession struct {
   586  	sessionId       uint64
   587  	remoteSessionId uint64
   588  	packetId        uint64
   589  	cipher          cipher.AEAD
   590  	remoteCipher    cipher.AEAD
   591  	window          SlidingWindow
   592  	rng             io.Reader
   593  }
   594  
   595  func (s *serverUDPSession) nextPacketId() uint64 {
   596  	return atomic.AddUint64(&s.packetId, 1)
   597  }
   598  
   599  func (s *Service) newUDPSession() *serverUDPSession {
   600  	session := &serverUDPSession{}
   601  	if s.udpCipher != nil {
   602  		session.rng = Blake3KeyedHash(rand.Reader)
   603  		common.Must(binary.Read(session.rng, binary.BigEndian, &session.sessionId))
   604  	} else {
   605  		common.Must(binary.Read(rand.Reader, binary.BigEndian, &session.sessionId))
   606  	}
   607  	session.packetId--
   608  	if s.udpCipher == nil {
   609  		sessionId := make([]byte, 8)
   610  		binary.BigEndian.PutUint64(sessionId, session.sessionId)
   611  		key := SessionKey(s.psk, sessionId, s.keySaltLength)
   612  		var err error
   613  		session.cipher, err = s.constructor(key)
   614  		common.Must(err)
   615  	}
   616  	return session
   617  }