github.com/sagernet/sing-shadowsocks@v0.2.6/shadowaead/protocol.go (about)

     1  package shadowaead
     2  
import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/sha1"
	"fmt"
	"io"
	"net"

	"github.com/sagernet/sing-shadowsocks"
	"github.com/sagernet/sing/common"
	"github.com/sagernet/sing/common/buf"
	M "github.com/sagernet/sing/common/metadata"
	N "github.com/sagernet/sing/common/network"
	"github.com/sagernet/sing/common/rw"

	"golang.org/x/crypto/chacha20poly1305"
	"golang.org/x/crypto/hkdf"
)
    20  
    21  var List = []string{
    22  	"aes-128-gcm",
    23  	"aes-192-gcm",
    24  	"aes-256-gcm",
    25  	"chacha20-ietf-poly1305",
    26  	"xchacha20-ietf-poly1305",
    27  }
    28  
    29  var _ shadowsocks.Method = (*Method)(nil)
    30  
    31  func New(method string, key []byte, password string) (*Method, error) {
    32  	m := &Method{
    33  		name: method,
    34  	}
    35  	switch method {
    36  	case "aes-128-gcm":
    37  		m.keySaltLength = 16
    38  		m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM)
    39  	case "aes-192-gcm":
    40  		m.keySaltLength = 24
    41  		m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM)
    42  	case "aes-256-gcm":
    43  		m.keySaltLength = 32
    44  		m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM)
    45  	case "chacha20-ietf-poly1305":
    46  		m.keySaltLength = 32
    47  		m.constructor = chacha20poly1305.New
    48  	case "xchacha20-ietf-poly1305":
    49  		m.keySaltLength = 32
    50  		m.constructor = chacha20poly1305.NewX
    51  	}
    52  	if len(key) == m.keySaltLength {
    53  		m.key = key
    54  	} else if len(key) > 0 {
    55  		return nil, shadowsocks.ErrBadKey
    56  	} else if password == "" {
    57  		return nil, shadowsocks.ErrMissingPassword
    58  	} else {
    59  		m.key = shadowsocks.Key([]byte(password), m.keySaltLength)
    60  	}
    61  	return m, nil
    62  }
    63  
    64  func Kdf(key, iv []byte, buffer *buf.Buffer) {
    65  	kdf := hkdf.New(sha1.New, key, iv, []byte("ss-subkey"))
    66  	common.Must1(buffer.ReadFullFrom(kdf, buffer.FreeLen()))
    67  }
    68  
    69  func aeadCipher(block func(key []byte) (cipher.Block, error), aead func(block cipher.Block) (cipher.AEAD, error)) func(key []byte) (cipher.AEAD, error) {
    70  	return func(key []byte) (cipher.AEAD, error) {
    71  		b, err := block(key)
    72  		if err != nil {
    73  			return nil, err
    74  		}
    75  		return aead(b)
    76  	}
    77  }
    78  
    79  type Method struct {
    80  	name          string
    81  	keySaltLength int
    82  	constructor   func(key []byte) (cipher.AEAD, error)
    83  	key           []byte
    84  }
    85  
    86  func (m *Method) Name() string {
    87  	return m.name
    88  }
    89  
    90  func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) {
    91  	shadowsocksConn := &clientConn{
    92  		Conn:        conn,
    93  		Method:      m,
    94  		destination: destination,
    95  	}
    96  	return shadowsocksConn, shadowsocksConn.writeRequest(nil)
    97  }
    98  
    99  func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn {
   100  	return &clientConn{
   101  		Conn:        conn,
   102  		Method:      m,
   103  		destination: destination,
   104  	}
   105  }
   106  
   107  func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn {
   108  	return &clientPacketConn{m, conn}
   109  }
   110  
   111  type clientConn struct {
   112  	net.Conn
   113  	*Method
   114  	destination M.Socksaddr
   115  	reader      *Reader
   116  	writer      *Writer
   117  }
   118  
   119  func (c *clientConn) writeRequest(payload []byte) error {
   120  	salt := buf.NewSize(c.keySaltLength)
   121  	defer salt.Release()
   122  	salt.WriteRandom(c.keySaltLength)
   123  
   124  	key := buf.NewSize(c.keySaltLength)
   125  
   126  	Kdf(c.key, salt.Bytes(), key)
   127  	writeCipher, err := c.constructor(key.Bytes())
   128  	key.Release()
   129  	if err != nil {
   130  		return err
   131  	}
   132  	writer := NewWriter(c.Conn, writeCipher, MaxPacketSize)
   133  	header := writer.Buffer()
   134  	common.Must1(header.Write(salt.Bytes()))
   135  	bufferedWriter := writer.BufferedWriter(header.Len())
   136  
   137  	if len(payload) > 0 {
   138  		err = M.SocksaddrSerializer.WriteAddrPort(bufferedWriter, c.destination)
   139  		if err != nil {
   140  			return err
   141  		}
   142  
   143  		_, err = bufferedWriter.Write(payload)
   144  		if err != nil {
   145  			return err
   146  		}
   147  	} else {
   148  		err = M.SocksaddrSerializer.WriteAddrPort(bufferedWriter, c.destination)
   149  		if err != nil {
   150  			return err
   151  		}
   152  	}
   153  
   154  	err = bufferedWriter.Flush()
   155  	if err != nil {
   156  		return err
   157  	}
   158  
   159  	c.writer = writer
   160  	return nil
   161  }
   162  
   163  func (c *clientConn) readResponse() error {
   164  	salt := buf.NewSize(c.keySaltLength)
   165  	defer salt.Release()
   166  	_, err := salt.ReadFullFrom(c.Conn, c.keySaltLength)
   167  	if err != nil {
   168  		return err
   169  	}
   170  	key := buf.NewSize(c.keySaltLength)
   171  	defer key.Release()
   172  	Kdf(c.key, salt.Bytes(), key)
   173  	readCipher, err := c.constructor(key.Bytes())
   174  	if err != nil {
   175  		return err
   176  	}
   177  	c.reader = NewReader(
   178  		c.Conn,
   179  		readCipher,
   180  		MaxPacketSize,
   181  	)
   182  	return nil
   183  }
   184  
   185  func (c *clientConn) Read(p []byte) (n int, err error) {
   186  	if c.reader == nil {
   187  		if err = c.readResponse(); err != nil {
   188  			return
   189  		}
   190  	}
   191  	return c.reader.Read(p)
   192  }
   193  
   194  func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) {
   195  	if c.reader == nil {
   196  		if err = c.readResponse(); err != nil {
   197  			return
   198  		}
   199  	}
   200  	return c.reader.WriteTo(w)
   201  }
   202  
   203  func (c *clientConn) Write(p []byte) (n int, err error) {
   204  	if c.writer == nil {
   205  		err = c.writeRequest(p)
   206  		if err != nil {
   207  			return
   208  		}
   209  		return len(p), nil
   210  	}
   211  	return c.writer.Write(p)
   212  }
   213  
   214  func (c *clientConn) NeedHandshake() bool {
   215  	return c.writer == nil
   216  }
   217  
   218  func (c *clientConn) NeedAdditionalReadDeadline() bool {
   219  	return true
   220  }
   221  
   222  func (c *clientConn) Upstream() any {
   223  	return c.Conn
   224  }
   225  
   226  type clientPacketConn struct {
   227  	*Method
   228  	net.Conn
   229  }
   230  
   231  func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   232  	defer buffer.Release()
   233  	header := buf.With(buffer.ExtendHeader(c.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination)))
   234  	header.WriteRandom(c.keySaltLength)
   235  	err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
   236  	if err != nil {
   237  		return err
   238  	}
   239  	key := buf.NewSize(c.keySaltLength)
   240  	Kdf(c.key, buffer.To(c.keySaltLength), key)
   241  	writeCipher, err := c.constructor(key.Bytes())
   242  	key.Release()
   243  	if err != nil {
   244  		return err
   245  	}
   246  	writeCipher.Seal(buffer.Index(c.keySaltLength), rw.ZeroBytes[:writeCipher.NonceSize()], buffer.From(c.keySaltLength), nil)
   247  	buffer.Extend(Overhead)
   248  	return common.Error(c.Write(buffer.Bytes()))
   249  }
   250  
   251  func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
   252  	n, err := c.Read(buffer.FreeBytes())
   253  	if err != nil {
   254  		return M.Socksaddr{}, err
   255  	}
   256  	buffer.Truncate(n)
   257  	if buffer.Len() < c.keySaltLength {
   258  		return M.Socksaddr{}, io.ErrShortBuffer
   259  	}
   260  	key := buf.NewSize(c.keySaltLength)
   261  	Kdf(c.key, buffer.To(c.keySaltLength), key)
   262  	readCipher, err := c.constructor(key.Bytes())
   263  	key.Release()
   264  	if err != nil {
   265  		return M.Socksaddr{}, err
   266  	}
   267  	packet, err := readCipher.Open(buffer.Index(c.keySaltLength), rw.ZeroBytes[:readCipher.NonceSize()], buffer.From(c.keySaltLength), nil)
   268  	if err != nil {
   269  		return M.Socksaddr{}, err
   270  	}
   271  	buffer.Advance(c.keySaltLength)
   272  	buffer.Truncate(len(packet))
   273  	if err != nil {
   274  		return M.Socksaddr{}, err
   275  	}
   276  	return M.SocksaddrSerializer.ReadAddrPort(buffer)
   277  }
   278  
   279  func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
   280  	buffer := buf.With(p)
   281  	destination, err := c.ReadPacket(buffer)
   282  	if err != nil {
   283  		return
   284  	}
   285  	if destination.IsFqdn() {
   286  		addr = destination
   287  	} else {
   288  		addr = destination.UDPAddr()
   289  	}
   290  	n = copy(p, buffer.Bytes())
   291  	return
   292  }
   293  
   294  func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
   295  	destination := M.SocksaddrFromNet(addr)
   296  	buffer := buf.NewSize(c.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination) + len(p) + Overhead)
   297  	defer buffer.Release()
   298  	buffer.Resize(c.keySaltLength+M.SocksaddrSerializer.AddrPortLen(destination), 0)
   299  	common.Must1(buffer.Write(p))
   300  	err = c.WritePacket(buffer, destination)
   301  	if err != nil {
   302  		return
   303  	}
   304  	return len(p), nil
   305  }
   306  
   307  func (c *clientPacketConn) FrontHeadroom() int {
   308  	return c.keySaltLength + M.MaxSocksaddrLength
   309  }
   310  
   311  func (c *clientPacketConn) RearHeadroom() int {
   312  	return Overhead
   313  }
   314  
   315  func (c *clientPacketConn) ReaderMTU() int {
   316  	return MaxPacketSize
   317  }
   318  
   319  func (c *clientPacketConn) WriterMTU() int {
   320  	return MaxPacketSize
   321  }
   322  
   323  func (c *clientPacketConn) Upstream() any {
   324  	return c.Conn
   325  }