github.com/sagernet/sing-shadowsocks2@v0.2.0/shadowaead/method.go (about)

     1  package shadowaead
     2  
     3  import (
     4  	"context"
     5  	"crypto/aes"
     6  	"crypto/cipher"
     7  	"net"
     8  
     9  	C "github.com/sagernet/sing-shadowsocks2/cipher"
    10  	"github.com/sagernet/sing-shadowsocks2/internal/legacykey"
    11  	"github.com/sagernet/sing-shadowsocks2/internal/shadowio"
    12  	"github.com/sagernet/sing/common"
    13  	"github.com/sagernet/sing/common/buf"
    14  	"github.com/sagernet/sing/common/bufio"
    15  	E "github.com/sagernet/sing/common/exceptions"
    16  	M "github.com/sagernet/sing/common/metadata"
    17  	N "github.com/sagernet/sing/common/network"
    18  	"github.com/sagernet/sing/common/rw"
    19  
    20  	"golang.org/x/crypto/chacha20poly1305"
    21  )
    22  
    23  var MethodList = []string{
    24  	"aes-128-gcm",
    25  	"aes-192-gcm",
    26  	"aes-256-gcm",
    27  	"chacha20-ietf-poly1305",
    28  	"xchacha20-ietf-poly1305",
    29  }
    30  
    31  func init() {
    32  	C.RegisterMethod(MethodList, func(ctx context.Context, methodName string, options C.MethodOptions) (C.Method, error) {
    33  		return NewMethod(ctx, methodName, options)
    34  	})
    35  }
    36  
    37  type Method struct {
    38  	keySaltLength int
    39  	constructor   func(key []byte) (cipher.AEAD, error)
    40  	key           []byte
    41  }
    42  
    43  func NewMethod(ctx context.Context, methodName string, options C.MethodOptions) (*Method, error) {
    44  	m := &Method{}
    45  	switch methodName {
    46  	case "aes-128-gcm":
    47  		m.keySaltLength = 16
    48  		m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM)
    49  	case "aes-192-gcm":
    50  		m.keySaltLength = 24
    51  		m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM)
    52  	case "aes-256-gcm":
    53  		m.keySaltLength = 32
    54  		m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM)
    55  	case "chacha20-ietf-poly1305":
    56  		m.keySaltLength = 32
    57  		m.constructor = chacha20poly1305.New
    58  	case "xchacha20-ietf-poly1305":
    59  		m.keySaltLength = 32
    60  		m.constructor = chacha20poly1305.NewX
    61  	}
    62  	if len(options.Key) == m.keySaltLength {
    63  		m.key = options.Key
    64  	} else if len(options.Key) > 0 {
    65  		return nil, E.New("bad key length, required ", m.keySaltLength, ", got ", len(options.Key))
    66  	} else if options.Password == "" {
    67  		return nil, C.ErrMissingPassword
    68  	} else {
    69  		m.key = legacykey.Key([]byte(options.Password), m.keySaltLength)
    70  	}
    71  	return m, nil
    72  }
    73  
    74  func aeadCipher(block func(key []byte) (cipher.Block, error), aead func(block cipher.Block) (cipher.AEAD, error)) func(key []byte) (cipher.AEAD, error) {
    75  	return func(key []byte) (cipher.AEAD, error) {
    76  		b, err := block(key)
    77  		if err != nil {
    78  			return nil, err
    79  		}
    80  		return aead(b)
    81  	}
    82  }
    83  
    84  func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) {
    85  	ssConn := &clientConn{
    86  		Conn:        conn,
    87  		method:      m,
    88  		destination: destination,
    89  	}
    90  	return ssConn, ssConn.writeRequest(nil)
    91  }
    92  
    93  func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn {
    94  	return &clientConn{
    95  		Conn:        conn,
    96  		method:      m,
    97  		destination: destination,
    98  	}
    99  }
   100  
   101  func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn {
   102  	return &clientPacketConn{
   103  		AbstractConn: conn,
   104  		reader:       bufio.NewExtendedReader(conn),
   105  		writer:       bufio.NewExtendedWriter(conn),
   106  		method:       m,
   107  	}
   108  }
   109  
   110  var _ N.ExtendedConn = (*clientConn)(nil)
   111  
   112  type clientConn struct {
   113  	net.Conn
   114  	method          *Method
   115  	destination     M.Socksaddr
   116  	reader          *shadowio.Reader
   117  	readWaitOptions N.ReadWaitOptions
   118  	writer          *shadowio.Writer
   119  	shadowio.WriterInterface
   120  }
   121  
   122  func (c *clientConn) writeRequest(payload []byte) error {
   123  	requestBuffer := buf.New()
   124  	requestBuffer.WriteRandom(c.method.keySaltLength)
   125  	key := make([]byte, c.method.keySaltLength)
   126  	legacykey.Kdf(c.method.key, requestBuffer.Bytes(), key)
   127  	writeCipher, err := c.method.constructor(key)
   128  	if err != nil {
   129  		return err
   130  	}
   131  	bufferedRequestWriter := bufio.NewBufferedWriter(c.Conn, requestBuffer)
   132  	requestContentWriter := shadowio.NewWriter(bufferedRequestWriter, writeCipher, nil, MaxPacketSize)
   133  	bufferedRequestContentWriter := bufio.NewBufferedWriter(requestContentWriter, buf.New())
   134  	err = M.SocksaddrSerializer.WriteAddrPort(bufferedRequestContentWriter, c.destination)
   135  	if err != nil {
   136  		return err
   137  	}
   138  	_, err = bufferedRequestContentWriter.Write(payload)
   139  	if err != nil {
   140  		return err
   141  	}
   142  	err = bufferedRequestContentWriter.Fallthrough()
   143  	if err != nil {
   144  		return err
   145  	}
   146  	err = bufferedRequestWriter.Fallthrough()
   147  	if err != nil {
   148  		return err
   149  	}
   150  	c.writer = shadowio.NewWriter(c.Conn, writeCipher, requestContentWriter.TakeNonce(), MaxPacketSize)
   151  	return nil
   152  }
   153  
   154  func (c *clientConn) readResponse() error {
   155  	buffer := buf.NewSize(c.method.keySaltLength)
   156  	defer buffer.Release()
   157  	_, err := buffer.ReadFullFrom(c.Conn, c.method.keySaltLength)
   158  	if err != nil {
   159  		return err
   160  	}
   161  	legacykey.Kdf(c.method.key, buffer.Bytes(), buffer.Bytes())
   162  	readCipher, err := c.method.constructor(buffer.Bytes())
   163  	if err != nil {
   164  		return err
   165  	}
   166  	reader := shadowio.NewReader(c.Conn, readCipher)
   167  	reader.InitializeReadWaiter(c.readWaitOptions)
   168  	c.reader = reader
   169  	return nil
   170  }
   171  
   172  func (c *clientConn) Read(p []byte) (n int, err error) {
   173  	if c.reader == nil {
   174  		err = c.readResponse()
   175  		if err != nil {
   176  			return
   177  		}
   178  	}
   179  	return c.reader.Read(p)
   180  }
   181  
   182  func (c *clientConn) ReadBuffer(buffer *buf.Buffer) error {
   183  	if c.reader == nil {
   184  		err := c.readResponse()
   185  		if err != nil {
   186  			return err
   187  		}
   188  	}
   189  	return c.reader.ReadBuffer(buffer)
   190  }
   191  
   192  func (c *clientConn) Write(p []byte) (n int, err error) {
   193  	if c.writer == nil {
   194  		err = c.writeRequest(p)
   195  		if err == nil {
   196  			n = len(p)
   197  		}
   198  		return
   199  	}
   200  	return c.writer.Write(p)
   201  }
   202  
   203  func (c *clientConn) WriteBuffer(buffer *buf.Buffer) error {
   204  	if c.writer == nil {
   205  		defer buffer.Release()
   206  		return c.writeRequest(buffer.Bytes())
   207  	}
   208  	return c.writer.WriteBuffer(buffer)
   209  }
   210  
   211  func (c *clientConn) NeedHandshake() bool {
   212  	return c.writer == nil
   213  }
   214  
   215  func (c *clientConn) Upstream() any {
   216  	return c.Conn
   217  }
   218  
   219  func (c *clientConn) WriterMTU() int {
   220  	return MaxPacketSize
   221  }
   222  
   223  type clientPacketConn struct {
   224  	N.AbstractConn
   225  	reader N.ExtendedReader
   226  	writer N.ExtendedWriter
   227  	method *Method
   228  }
   229  
   230  func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
   231  	err = c.reader.ReadBuffer(buffer)
   232  	if err != nil {
   233  		return
   234  	}
   235  	return c.readPacket(buffer)
   236  }
   237  
   238  func (c *clientPacketConn) readPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
   239  	if buffer.Len() < c.method.keySaltLength {
   240  		return M.Socksaddr{}, C.ErrPacketTooShort
   241  	}
   242  	key := buf.NewSize(c.method.keySaltLength)
   243  	legacykey.Kdf(c.method.key, buffer.To(c.method.keySaltLength), key.Extend(c.method.keySaltLength))
   244  	readCipher, err := c.method.constructor(key.Bytes())
   245  	key.Release()
   246  	if err != nil {
   247  		return
   248  	}
   249  	packet, err := readCipher.Open(buffer.Index(c.method.keySaltLength), rw.ZeroBytes[:readCipher.NonceSize()], buffer.From(c.method.keySaltLength), nil)
   250  	if err != nil {
   251  		return
   252  	}
   253  	buffer.Advance(c.method.keySaltLength)
   254  	buffer.Truncate(len(packet))
   255  	if err != nil {
   256  		return
   257  	}
   258  	destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer)
   259  	if err != nil {
   260  		return
   261  	}
   262  	return destination.Unwrap(), nil
   263  }
   264  
   265  func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   266  	header := buf.With(buffer.ExtendHeader(c.method.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination)))
   267  	header.WriteRandom(c.method.keySaltLength)
   268  	err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
   269  	if err != nil {
   270  		return err
   271  	}
   272  	key := buf.NewSize(c.method.keySaltLength)
   273  	legacykey.Kdf(c.method.key, header.To(c.method.keySaltLength), key.Extend(c.method.keySaltLength))
   274  	writeCipher, err := c.method.constructor(key.Bytes())
   275  	key.Release()
   276  	if err != nil {
   277  		return err
   278  	}
   279  	writeCipher.Seal(buffer.Index(c.method.keySaltLength), rw.ZeroBytes[:writeCipher.NonceSize()], buffer.From(c.method.keySaltLength), nil)
   280  	buffer.Extend(shadowio.Overhead)
   281  	return c.writer.WriteBuffer(buffer)
   282  }
   283  
   284  func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
   285  	n, err = c.reader.Read(p)
   286  	if err != nil {
   287  		return
   288  	}
   289  	if n < c.method.keySaltLength {
   290  		err = C.ErrPacketTooShort
   291  		return
   292  	}
   293  	key := buf.NewSize(c.method.keySaltLength)
   294  	legacykey.Kdf(c.method.key, p[:c.method.keySaltLength], key.Extend(c.method.keySaltLength))
   295  	readCipher, err := c.method.constructor(key.Bytes())
   296  	key.Release()
   297  	if err != nil {
   298  		return
   299  	}
   300  	packet, err := readCipher.Open(p[c.method.keySaltLength:c.method.keySaltLength], rw.ZeroBytes[:readCipher.NonceSize()], p[c.method.keySaltLength:n], nil)
   301  	if err != nil {
   302  		return
   303  	}
   304  	packetContent := buf.As(packet)
   305  	destination, err := M.SocksaddrSerializer.ReadAddrPort(packetContent)
   306  	if err != nil {
   307  		return
   308  	}
   309  	if !destination.IsFqdn() {
   310  		addr = destination.UDPAddr()
   311  	} else {
   312  		addr = destination
   313  	}
   314  	n = copy(p, packetContent.Bytes())
   315  	return
   316  }
   317  
   318  func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
   319  	destination := M.SocksaddrFromNet(addr)
   320  	buffer := buf.NewSize(c.method.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination) + len(p) + shadowio.Overhead)
   321  	defer buffer.Release()
   322  	buffer.WriteRandom(c.method.keySaltLength)
   323  	err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination)
   324  	if err != nil {
   325  		return
   326  	}
   327  	common.Must1(buffer.Write(p))
   328  	key := buf.NewSize(c.method.keySaltLength)
   329  	legacykey.Kdf(c.method.key, buffer.To(c.method.keySaltLength), key.Extend(c.method.keySaltLength))
   330  	writeCipher, err := c.method.constructor(key.Bytes())
   331  	key.Release()
   332  	if err != nil {
   333  		return
   334  	}
   335  	writeCipher.Seal(buffer.Index(c.method.keySaltLength), rw.ZeroBytes[:writeCipher.NonceSize()], buffer.From(c.method.keySaltLength), nil)
   336  	buffer.Extend(shadowio.Overhead)
   337  	_, err = c.writer.Write(buffer.Bytes())
   338  	if err != nil {
   339  		return
   340  	}
   341  	return len(p), nil
   342  }
   343  
   344  func (c *clientPacketConn) FrontHeadroom() int {
   345  	return c.method.keySaltLength + M.MaxSocksaddrLength
   346  }
   347  
   348  func (c *clientPacketConn) RearHeadroom() int {
   349  	return shadowio.Overhead
   350  }
   351  
   352  func (c *clientPacketConn) ReaderMTU() int {
   353  	return MaxPacketSize
   354  }
   355  
   356  func (c *clientPacketConn) WriterMTU() int {
   357  	return MaxPacketSize
   358  }
   359  
   360  func (c *clientPacketConn) Upstream() any {
   361  	return c.AbstractConn
   362  }