github.com/sagernet/sing-shadowsocks@v0.2.6/shadowaead_2022/service.go (about)

     1  package shadowaead_2022
     2  
     3  import (
     4  	"context"
     5  	"crypto/aes"
     6  	"crypto/cipher"
     7  	"crypto/rand"
     8  	"encoding/base64"
     9  	"encoding/binary"
    10  	"io"
    11  	"math"
    12  	mRand "math/rand"
    13  	"net"
    14  	"os"
    15  	"sync"
    16  	"sync/atomic"
    17  	"time"
    18  
    19  	"github.com/sagernet/sing-shadowsocks"
    20  	"github.com/sagernet/sing-shadowsocks/shadowaead"
    21  	"github.com/sagernet/sing/common"
    22  	"github.com/sagernet/sing/common/buf"
    23  	"github.com/sagernet/sing/common/cache"
    24  	E "github.com/sagernet/sing/common/exceptions"
    25  	M "github.com/sagernet/sing/common/metadata"
    26  	N "github.com/sagernet/sing/common/network"
    27  	"github.com/sagernet/sing/common/replay"
    28  	"github.com/sagernet/sing/common/udpnat"
    29  
    30  	"golang.org/x/crypto/chacha20poly1305"
    31  )
    32  
// Request validation errors.
var (
	// ErrInvalidRequest is not referenced in this file; presumably used
	// elsewhere in the package — confirm.
	ErrInvalidRequest = E.New("invalid request")
	// ErrNoPadding is returned by newConnection when the request header's
	// variable chunk carries neither padding nor initial payload.
	ErrNoPadding      = E.New("bad request: missing payload or padding")
	// ErrBadPadding reports a padding length that does not fit the request.
	ErrBadPadding     = E.New("bad request: damaged padding")
)

// Compile-time check that *Service implements shadowsocks.Service.
var _ shadowsocks.Service = (*Service)(nil)
    40  
// Service is a single-PSK Shadowsocks 2022 server that accepts TCP
// connections (NewConnection) and UDP packets (NewPacket) and forwards the
// decoded traffic to handler.
type Service struct {
	// name is the method string passed to NewService.
	name string
	// keySaltLength is the key and salt length in bytes (16 or 32).
	keySaltLength int
	// handler receives decoded connections and errors; it is also handed to
	// the UDP NAT.
	handler shadowsocks.Handler
	// timeFunc optionally overrides time.Now (see Service.time).
	timeFunc func() time.Time

	// constructor builds the AEAD for a derived session key.
	constructor func(key []byte) (cipher.AEAD, error)
	// blockConstructor is set for the AES methods only.
	blockConstructor func(key []byte) (cipher.Block, error)
	// udpCipher is XChaCha20-Poly1305 keyed by psk; set for
	// 2022-blake3-chacha20-poly1305 only.
	udpCipher cipher.AEAD
	// udpBlockCipher is AES keyed by psk and encrypts UDP packet headers;
	// set for the AES methods only.
	udpBlockCipher cipher.Block
	// psk is the pre-shared key, already reduced to keySaltLength bytes.
	psk []byte

	// replayFilter rejects TCP request salts that were already seen.
	replayFilter replay.Filter
	// udpNat and udpSessions are keyed by the client's UDP session ID.
	udpNat      *udpnat.Service[uint64]
	udpSessions *cache.LruCache[uint64, *serverUDPSession]
}
    57  
    58  func NewServiceWithPassword(method string, password string, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (shadowsocks.Service, error) {
    59  	if password == "" {
    60  		return nil, ErrMissingPSK
    61  	}
    62  	psk, err := base64.StdEncoding.DecodeString(password)
    63  	if err != nil {
    64  		return nil, E.Cause(err, "decode psk")
    65  	}
    66  	return NewService(method, psk, udpTimeout, handler, timeFunc)
    67  }
    68  
    69  func NewService(method string, psk []byte, udpTimeout int64, handler shadowsocks.Handler, timeFunc func() time.Time) (shadowsocks.Service, error) {
    70  	s := &Service{
    71  		name:     method,
    72  		handler:  handler,
    73  		timeFunc: timeFunc,
    74  
    75  		replayFilter: replay.NewSimple(60 * time.Second),
    76  		udpNat:       udpnat.New[uint64](udpTimeout, handler),
    77  		udpSessions: cache.New[uint64, *serverUDPSession](
    78  			cache.WithAge[uint64, *serverUDPSession](udpTimeout),
    79  			cache.WithUpdateAgeOnGet[uint64, *serverUDPSession](),
    80  		),
    81  	}
    82  
    83  	switch method {
    84  	case "2022-blake3-aes-128-gcm":
    85  		s.keySaltLength = 16
    86  		s.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM)
    87  		s.blockConstructor = aes.NewCipher
    88  	case "2022-blake3-aes-256-gcm":
    89  		s.keySaltLength = 32
    90  		s.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM)
    91  		s.blockConstructor = aes.NewCipher
    92  	case "2022-blake3-chacha20-poly1305":
    93  		s.keySaltLength = 32
    94  		s.constructor = chacha20poly1305.New
    95  	default:
    96  		return nil, os.ErrInvalid
    97  	}
    98  
    99  	if len(psk) != s.keySaltLength {
   100  		if len(psk) < s.keySaltLength {
   101  			return nil, shadowsocks.ErrBadKey
   102  		} else if len(psk) > s.keySaltLength {
   103  			psk = Key(psk, s.keySaltLength)
   104  		} else {
   105  			return nil, ErrMissingPSK
   106  		}
   107  	}
   108  
   109  	var err error
   110  	switch method {
   111  	case "2022-blake3-aes-128-gcm", "2022-blake3-aes-256-gcm":
   112  		s.udpBlockCipher, err = aes.NewCipher(psk)
   113  	case "2022-blake3-chacha20-poly1305":
   114  		s.udpCipher, err = chacha20poly1305.NewX(psk)
   115  	}
   116  	if err != nil {
   117  		return nil, err
   118  	}
   119  
   120  	s.psk = psk
   121  	return s, nil
   122  }
   123  
   124  func (s *Service) Name() string {
   125  	return s.name
   126  }
   127  
   128  func (s *Service) Password() string {
   129  	return base64.StdEncoding.EncodeToString(s.psk)
   130  }
   131  
   132  func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
   133  	err := s.newConnection(ctx, conn, metadata)
   134  	if err != nil {
   135  		err = &shadowsocks.ServerConnError{Conn: conn, Source: metadata.Source, Cause: err}
   136  	}
   137  	return err
   138  }
   139  
   140  func (s *Service) time() time.Time {
   141  	if s.timeFunc != nil {
   142  		return s.timeFunc()
   143  	} else {
   144  		return time.Now()
   145  	}
   146  }
   147  
   148  func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
   149  	header := make([]byte, s.keySaltLength+shadowaead.Overhead+RequestHeaderFixedChunkLength)
   150  
   151  	n, err := conn.Read(header)
   152  	if err != nil {
   153  		return E.Cause(err, "read header")
   154  	} else if n < len(header) {
   155  		return shadowaead.ErrBadHeader
   156  	}
   157  
   158  	requestSalt := header[:s.keySaltLength]
   159  
   160  	if !s.replayFilter.Check(requestSalt) {
   161  		return ErrSaltNotUnique
   162  	}
   163  
   164  	requestKey := SessionKey(s.psk, requestSalt, s.keySaltLength)
   165  	readCipher, err := s.constructor(requestKey)
   166  	if err != nil {
   167  		return err
   168  	}
   169  	reader := shadowaead.NewReader(
   170  		conn,
   171  		readCipher,
   172  		MaxPacketSize,
   173  	)
   174  
   175  	err = reader.ReadExternalChunk(header[s.keySaltLength:])
   176  	if err != nil {
   177  		return err
   178  	}
   179  
   180  	headerType, err := reader.ReadByte()
   181  	if err != nil {
   182  		return E.Cause(err, "read header")
   183  	}
   184  
   185  	if headerType != HeaderTypeClient {
   186  		return E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType)
   187  	}
   188  
   189  	var epoch uint64
   190  	err = binary.Read(reader, binary.BigEndian, &epoch)
   191  	if err != nil {
   192  		return err
   193  	}
   194  
   195  	diff := int(math.Abs(float64(s.time().Unix() - int64(epoch))))
   196  	if diff > 30 {
   197  		return E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
   198  	}
   199  
   200  	var length uint16
   201  	err = binary.Read(reader, binary.BigEndian, &length)
   202  	if err != nil {
   203  		return err
   204  	}
   205  
   206  	err = reader.ReadWithLength(length)
   207  	if err != nil {
   208  		return err
   209  	}
   210  
   211  	destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
   212  	if err != nil {
   213  		return err
   214  	}
   215  
   216  	var paddingLen uint16
   217  	err = binary.Read(reader, binary.BigEndian, &paddingLen)
   218  	if err != nil {
   219  		return err
   220  	}
   221  
   222  	if uint16(reader.Cached()) < paddingLen {
   223  		return ErrNoPadding
   224  	}
   225  
   226  	if paddingLen > 0 {
   227  		err = reader.Discard(int(paddingLen))
   228  		if err != nil {
   229  			return E.Cause(err, "discard padding")
   230  		}
   231  	} else if reader.Cached() == 0 {
   232  		return ErrNoPadding
   233  	}
   234  
   235  	protocolConn := &serverConn{
   236  		Service:     s,
   237  		Conn:        conn,
   238  		uPSK:        s.psk,
   239  		headerType:  headerType,
   240  		requestSalt: requestSalt,
   241  	}
   242  
   243  	protocolConn.reader = reader
   244  
   245  	metadata.Protocol = "shadowsocks"
   246  	metadata.Destination = destination
   247  	return s.handler.NewConnection(ctx, protocolConn, metadata)
   248  }
   249  
// serverConn is the decrypted view of an accepted TCP connection. Reads go
// through reader. The first write sends the response header (writeResponse)
// and installs writer for later writes.
type serverConn struct {
	*Service
	net.Conn
	// uPSK is the key the response session key is derived from.
	uPSK []byte
	// access serializes the first write, which creates writer.
	access sync.Mutex
	// headerType is the request header type read from the client.
	headerType byte
	reader     *shadowaead.Reader
	// writer is nil until the response header has been flushed.
	writer *shadowaead.Writer
	// requestSalt is the client's salt, echoed once in the response header
	// and then cleared.
	requestSalt []byte
}
   260  
// writeResponse sends the server's response header with payload as the first
// data chunk, and installs the stream writer used by later writes.
//
// It generates a random salt and derives the response session key from uPSK
// and that salt. It then writes the salt, an encrypted fixed chunk (header
// type, timestamp, the client's request salt, payload length) and, if
// non-empty, payload as its own chunk. It returns len(payload) on success.
// On failure c.writer is not set.
//
// NOTE(review): the length field is uint16(payloadLen), so a payload longer
// than 65535 bytes would be announced with a truncated length; presumably
// callers keep writes within MaxPacketSize — confirm.
// NOTE(review): c.requestSalt is cleared before Flush, so a retry after a
// failed flush would echo an empty request salt.
func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
	salt := buf.NewSize(c.keySaltLength)
	salt.WriteRandom(salt.FreeLen())

	key := SessionKey(c.uPSK, salt.Bytes(), c.keySaltLength)
	writeCipher, err := c.constructor(key)
	if err != nil {
		salt.Release()
		return
	}
	writer := shadowaead.NewWriter(
		c.Conn,
		writeCipher,
		MaxPacketSize,
	)
	// The salt is sent in the clear ahead of the first encrypted chunk.
	header := writer.Buffer()
	header.Write(salt.Bytes())

	salt.Release()

	headerType := byte(HeaderTypeServer)
	payloadLen := len(payload)

	// Fixed chunk: type (1) + timestamp (8) + request salt + length (2).
	headerFixedChunk := buf.NewSize(1 + 8 + c.keySaltLength + 2)
	common.Must(headerFixedChunk.WriteByte(headerType))
	common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint64(c.time().Unix())))
	common.Must1(headerFixedChunk.Write(c.requestSalt))
	common.Must(binary.Write(headerFixedChunk, binary.BigEndian, uint16(payloadLen)))

	writer.WriteChunk(header, headerFixedChunk.Bytes())
	headerFixedChunk.Release()
	c.requestSalt = nil

	if payloadLen > 0 {
		writer.WriteChunk(header, payload[:payloadLen])
	}

	// Send salt + header chunk + first payload chunk together.
	err = writer.BufferedWriter(header.Len()).Flush()
	if err != nil {
		return
	}

	switch headerType {
	case HeaderTypeServer:
		c.writer = writer
		// case HeaderTypeServerEncrypted:
		//	encryptedWriter := NewTLSEncryptedStreamWriter(writer)
		//	if payloadLen < len(payload) {
		//		_, err = encryptedWriter.Write(payload[payloadLen:])
		//		if err != nil {
		//			return
		//		}
		//	}
		//	c.writer = encryptedWriter
	}

	n = len(payload)
	return
}
   320  
   321  func (c *serverConn) Read(b []byte) (n int, err error) {
   322  	return c.reader.Read(b)
   323  }
   324  
   325  func (c *serverConn) Write(p []byte) (n int, err error) {
   326  	if c.writer != nil {
   327  		return c.writer.Write(p)
   328  	}
   329  	c.access.Lock()
   330  	if c.writer != nil {
   331  		c.access.Unlock()
   332  		return c.writer.Write(p)
   333  	}
   334  	defer c.access.Unlock()
   335  	return c.writeResponse(p)
   336  }
   337  
   338  func (c *serverConn) WriteVectorised(buffers []*buf.Buffer) error {
   339  	if c.writer != nil {
   340  		return c.writer.WriteVectorised(buffers)
   341  	}
   342  	c.access.Lock()
   343  	if c.writer != nil {
   344  		c.access.Unlock()
   345  		return c.writer.WriteVectorised(buffers)
   346  	}
   347  	defer c.access.Unlock()
   348  	_, err := c.writeResponse(buffers[0].Bytes())
   349  	if err != nil {
   350  		buf.ReleaseMulti(buffers)
   351  		return err
   352  	}
   353  	buffers[0].Release()
   354  	return c.writer.WriteVectorised(buffers[1:])
   355  }
   356  
   357  func (c *serverConn) Close() error {
   358  	return common.Close(
   359  		c.Conn,
   360  		common.PtrOrNil(c.reader),
   361  		common.PtrOrNil(c.writer),
   362  	)
   363  }
   364  
   365  func (c *serverConn) NeedAdditionalReadDeadline() bool {
   366  	return true
   367  }
   368  
   369  func (c *serverConn) Upstream() any {
   370  	return c.Conn
   371  }
   372  
   373  func (s *Service) WriteIsThreadUnsafe() {
   374  }
   375  
   376  func (s *Service) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
   377  	err := s.newPacket(ctx, conn, buffer, metadata)
   378  	if err != nil {
   379  		err = &shadowsocks.ServerPacketError{Source: metadata.Source, Cause: err}
   380  	}
   381  	return err
   382  }
   383  
   384  func (s *Service) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
   385  	var packetHeader []byte
   386  	if s.udpCipher != nil {
   387  		if buffer.Len() < PacketNonceSize+PacketMinimalHeaderSize {
   388  			return ErrPacketTooShort
   389  		}
   390  		_, err := s.udpCipher.Open(buffer.Index(PacketNonceSize), buffer.To(PacketNonceSize), buffer.From(PacketNonceSize), nil)
   391  		if err != nil {
   392  			return E.Cause(err, "decrypt packet header")
   393  		}
   394  		buffer.Advance(PacketNonceSize)
   395  		buffer.Truncate(buffer.Len() - shadowaead.Overhead)
   396  	} else {
   397  		if buffer.Len() < PacketMinimalHeaderSize {
   398  			return ErrPacketTooShort
   399  		}
   400  		packetHeader = buffer.To(aes.BlockSize)
   401  		s.udpBlockCipher.Decrypt(packetHeader, packetHeader)
   402  	}
   403  
   404  	var sessionId, packetId uint64
   405  	err := binary.Read(buffer, binary.BigEndian, &sessionId)
   406  	if err != nil {
   407  		return err
   408  	}
   409  	err = binary.Read(buffer, binary.BigEndian, &packetId)
   410  	if err != nil {
   411  		return err
   412  	}
   413  
   414  	session, loaded := s.udpSessions.LoadOrStore(sessionId, s.newUDPSession)
   415  	if !loaded {
   416  		session.remoteSessionId = sessionId
   417  		if packetHeader != nil {
   418  			key := SessionKey(s.psk, packetHeader[:8], s.keySaltLength)
   419  			session.remoteCipher, err = s.constructor(key)
   420  			if err != nil {
   421  				return err
   422  			}
   423  		}
   424  	}
   425  	goto process
   426  
   427  returnErr:
   428  	if !loaded {
   429  		s.udpSessions.Delete(sessionId)
   430  	}
   431  	return err
   432  
   433  process:
   434  	if !session.window.Check(packetId) {
   435  		err = ErrPacketIdNotUnique
   436  		goto returnErr
   437  	}
   438  
   439  	if packetHeader != nil {
   440  		_, err = session.remoteCipher.Open(buffer.Index(0), packetHeader[4:16], buffer.Bytes(), nil)
   441  		if err != nil {
   442  			err = E.Cause(err, "decrypt packet")
   443  			goto returnErr
   444  		}
   445  		buffer.Truncate(buffer.Len() - shadowaead.Overhead)
   446  	}
   447  
   448  	session.window.Add(packetId)
   449  
   450  	var headerType byte
   451  	headerType, err = buffer.ReadByte()
   452  	if err != nil {
   453  		err = E.Cause(err, "decrypt packet")
   454  		goto returnErr
   455  	}
   456  	if headerType != HeaderTypeClient {
   457  		err = E.Extend(ErrBadHeaderType, "expected ", HeaderTypeClient, ", got ", headerType)
   458  		goto returnErr
   459  	}
   460  
   461  	var epoch uint64
   462  	err = binary.Read(buffer, binary.BigEndian, &epoch)
   463  	if err != nil {
   464  		goto returnErr
   465  	}
   466  	diff := int(math.Abs(float64(s.time().Unix() - int64(epoch))))
   467  	if diff > 30 {
   468  		err = E.Extend(ErrBadTimestamp, "received ", epoch, ", diff ", diff, "s")
   469  		goto returnErr
   470  	}
   471  
   472  	var paddingLen uint16
   473  	err = binary.Read(buffer, binary.BigEndian, &paddingLen)
   474  	if err != nil {
   475  		err = E.Cause(err, "read padding length")
   476  		goto returnErr
   477  	}
   478  	buffer.Advance(int(paddingLen))
   479  
   480  	destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
   481  	if err != nil {
   482  		goto returnErr
   483  	}
   484  	metadata.Protocol = "shadowsocks"
   485  	metadata.Destination = destination
   486  	s.udpNat.NewPacket(ctx, sessionId, buffer, metadata, func(natConn N.PacketConn) N.PacketWriter {
   487  		return &serverPacketWriter{s, conn, natConn, session, s.udpBlockCipher}
   488  	})
   489  	return nil
   490  }
   491  
   492  func (s *Service) NewError(ctx context.Context, err error) {
   493  	s.handler.NewError(ctx, err)
   494  }
   495  
// serverPacketWriter encrypts NAT replies for one client UDP session and
// sends them back over the listening packet connection. newPacket builds it
// with a positional literal, so field order matters.
type serverPacketWriter struct {
	*Service
	// source is the connection the client packet arrived on.
	source N.PacketConn
	// nat is the per-session NAT connection; its LocalAddr is the reply target.
	nat     N.PacketConn
	session *serverUDPSession
	// udpBlockCipher encrypts the AES packet header (nil for chacha20).
	udpBlockCipher cipher.Block
}
   503  
   504  func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   505  	var hdrLen int
   506  	if w.udpCipher != nil {
   507  		hdrLen = PacketNonceSize
   508  	}
   509  
   510  	var paddingLen int
   511  	if destination.Port == 53 && buffer.Len() < MaxPaddingLength {
   512  		paddingLen = mRand.Intn(MaxPaddingLength-buffer.Len()) + 1
   513  	}
   514  
   515  	hdrLen += 16 // packet header
   516  	hdrLen += 1  // header type
   517  	hdrLen += 8  // timestamp
   518  	hdrLen += 8  // remote session id
   519  	hdrLen += 2  // padding length
   520  	hdrLen += paddingLen
   521  	hdrLen += M.SocksaddrSerializer.AddrPortLen(destination)
   522  	header := buf.With(buffer.ExtendHeader(hdrLen))
   523  
   524  	var dataIndex int
   525  	if w.udpCipher != nil {
   526  		common.Must1(header.ReadFullFrom(w.session.rng, PacketNonceSize))
   527  		dataIndex = PacketNonceSize
   528  	} else {
   529  		dataIndex = aes.BlockSize
   530  	}
   531  
   532  	common.Must(
   533  		binary.Write(header, binary.BigEndian, w.session.sessionId),
   534  		binary.Write(header, binary.BigEndian, w.session.nextPacketId()),
   535  		header.WriteByte(HeaderTypeServer),
   536  		binary.Write(header, binary.BigEndian, uint64(w.time().Unix())),
   537  		binary.Write(header, binary.BigEndian, w.session.remoteSessionId),
   538  		binary.Write(header, binary.BigEndian, uint16(paddingLen)), // padding length
   539  	)
   540  
   541  	if paddingLen > 0 {
   542  		header.Extend(paddingLen)
   543  	}
   544  
   545  	err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
   546  	if err != nil {
   547  		buffer.Release()
   548  		return err
   549  	}
   550  
   551  	if w.udpCipher != nil {
   552  		w.udpCipher.Seal(buffer.Index(dataIndex), buffer.To(dataIndex), buffer.From(dataIndex), nil)
   553  		buffer.Extend(shadowaead.Overhead)
   554  	} else {
   555  		packetHeader := buffer.To(aes.BlockSize)
   556  		w.session.cipher.Seal(buffer.Index(dataIndex), packetHeader[4:16], buffer.From(dataIndex), nil)
   557  		buffer.Extend(shadowaead.Overhead)
   558  		w.udpBlockCipher.Encrypt(packetHeader, packetHeader)
   559  	}
   560  	return w.source.WritePacket(buffer, M.SocksaddrFromNet(w.nat.LocalAddr()))
   561  }
   562  
   563  func (w *serverPacketWriter) FrontHeadroom() int {
   564  	var hdrLen int
   565  	if w.udpCipher != nil {
   566  		hdrLen = PacketNonceSize
   567  	}
   568  	hdrLen += 16 // packet header
   569  	hdrLen += 1  // header type
   570  	hdrLen += 8  // timestamp
   571  	hdrLen += 8  // remote session id
   572  	hdrLen += 2  // padding length
   573  	hdrLen += MaxPaddingLength
   574  	hdrLen += M.MaxSocksaddrLength
   575  	return hdrLen
   576  }
   577  
   578  func (w *serverPacketWriter) RearHeadroom() int {
   579  	return shadowaead.Overhead
   580  }
   581  
   582  func (w *serverPacketWriter) Upstream() any {
   583  	return w.source
   584  }
   585  
// serverUDPSession is the server's state for one client UDP session.
type serverUDPSession struct {
	// sessionId is the server's own random session ID.
	sessionId uint64
	// remoteSessionId is the client's session ID.
	remoteSessionId uint64
	// packetId is the last server packet ID used; see nextPacketId.
	packetId uint64
	// cipher seals server packets (AES methods only).
	cipher cipher.AEAD
	// remoteCipher opens client packets (AES methods only).
	remoteCipher cipher.AEAD
	// window tracks client packet IDs to reject replays.
	window SlidingWindow
	// rng supplies nonces and the session ID (chacha20-poly1305 only).
	rng io.Reader
}
   595  
   596  func (s *serverUDPSession) nextPacketId() uint64 {
   597  	return atomic.AddUint64(&s.packetId, 1)
   598  }
   599  
   600  func (s *Service) newUDPSession() *serverUDPSession {
   601  	session := &serverUDPSession{}
   602  	if s.udpCipher != nil {
   603  		session.rng = Blake3KeyedHash(rand.Reader)
   604  		common.Must(binary.Read(session.rng, binary.BigEndian, &session.sessionId))
   605  	} else {
   606  		common.Must(binary.Read(rand.Reader, binary.BigEndian, &session.sessionId))
   607  	}
   608  	session.packetId--
   609  	if s.udpCipher == nil {
   610  		sessionId := make([]byte, 8)
   611  		binary.BigEndian.PutUint64(sessionId, session.sessionId)
   612  		key := SessionKey(s.psk, sessionId, s.keySaltLength)
   613  		var err error
   614  		session.cipher, err = s.constructor(key)
   615  		common.Must(err)
   616  	}
   617  	return session
   618  }