github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/vmess/client.go (about)

     1  package vmess
     2  
     3  import (
     4  	"bytes"
     5  	"crypto"
     6  	"crypto/aes"
     7  	"crypto/cipher"
     8  	"crypto/md5"
     9  	crand "crypto/rand"
    10  	"crypto/sha256"
    11  	"encoding/binary"
    12  	"errors"
    13  	"fmt"
    14  	"hash/fnv"
    15  	"io"
    16  	"math/rand/v2"
    17  	"net"
    18  	"runtime"
    19  	"strings"
    20  	"time"
    21  
    22  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    23  	ssr "github.com/Asutorufa/yuhaiin/pkg/net/proxy/shadowsocksr/utils"
    24  	"github.com/Asutorufa/yuhaiin/pkg/utils/relay"
    25  	"github.com/Asutorufa/yuhaiin/pkg/utils/uuid"
    26  	"golang.org/x/crypto/chacha20poly1305"
    27  )
    28  
    29  // Request Options
    30  const (
    31  	OptBasicFormat byte = 0
    32  	OptChunkStream byte = 1
    33  	// OptReuseTCPConnection byte = 2
    34  	// OptMetadataObfuscate  byte = 4
    35  )
    36  
    37  // Security types
    38  const (
    39  	SecurityAES128GCM        byte = 3
    40  	SecurityChacha20Poly1305 byte = 4
    41  	SecurityNone             byte = 5
    42  )
    43  
    44  type CMD byte
    45  
    46  // CMD types
    47  const (
    48  	CmdTCP CMD = 1
    49  	CmdUDP CMD = 2
    50  )
    51  
    52  func (c CMD) Byte() byte { return byte(c) }
    53  
// Compile-time assertion that *Conn implements net.Conn.
var _ net.Conn = (*Conn)(nil)
    55  
    56  // Client vmess client
    57  type Client struct {
    58  	users    []*User
    59  	opt      byte
    60  	security byte
    61  
    62  	isAead bool
    63  }
    64  
    65  // Conn is a connection to vmess server
    66  type Conn struct {
    67  	user     *User
    68  	opt      byte
    69  	security byte
    70  
    71  	addr address
    72  
    73  	reqBodyIV   [16]byte
    74  	reqBodyKey  [16]byte
    75  	reqRespV    byte
    76  	respBodyIV  [16]byte
    77  	respBodyKey [16]byte
    78  
    79  	net.Conn
    80  	dataReader io.ReadCloser
    81  	dataWriter writer
    82  
    83  	isAead bool
    84  	CMD    CMD
    85  }
    86  
    87  // NewClient .
    88  func newClient(uuidStr, security string, alterID int) (*Client, error) {
    89  	uuid, err := uuid.ParseStd(uuidStr)
    90  	if err != nil {
    91  		return nil, err
    92  	}
    93  
    94  	c := &Client{isAead: alterID == 0}
    95  
    96  	user := NewUser(uuid)
    97  	c.users = append(c.users, user)
    98  	c.users = append(c.users, user.GenAlterIDUsers(alterID)...)
    99  
   100  	c.opt = OptChunkStream
   101  
   102  	security = strings.ToLower(security)
   103  	switch security {
   104  	case "aes-128-gcm":
   105  		c.security = SecurityAES128GCM
   106  	case "chacha20-poly1305":
   107  		c.security = SecurityChacha20Poly1305
   108  	case "none":
   109  		c.security = SecurityNone
   110  	case "auto":
   111  		fallthrough
   112  	case "":
   113  		c.security = SecurityChacha20Poly1305
   114  		if runtime.GOARCH == "amd64" || runtime.GOARCH == "s390x" || runtime.GOARCH == "arm64" {
   115  			c.security = SecurityAES128GCM
   116  		}
   117  		// NOTE: use basic format when no method specified
   118  		// c.opt = OptBasicFormat
   119  		// c.security = SecurityNone
   120  	default:
   121  		return nil, errors.New("unknown security type: " + security)
   122  	}
   123  
   124  	return c, nil
   125  }
   126  
   127  func (c *Client) NewConn(rc net.Conn, dst netapi.Address) (net.Conn, error) {
   128  	return c.newConn(rc, CmdTCP, dst)
   129  }
   130  
   131  func (c *Client) NewPacketConn(rc net.Conn, dst netapi.Address) (net.PacketConn, error) {
   132  	return c.newConn(rc, CmdUDP, dst)
   133  }
   134  
   135  // NewConn .
   136  func (c *Client) newConn(rc net.Conn, cmd CMD, dst netapi.Address) (*Conn, error) {
   137  	conn := &Conn{
   138  		isAead:   c.isAead,
   139  		user:     c.users[rand.IntN(len(c.users))],
   140  		opt:      c.opt,
   141  		security: c.security,
   142  		CMD:      cmd,
   143  		addr:     address{dst},
   144  	}
   145  
   146  	randBytes := make([]byte, 33)
   147  	_, _ = crand.Read(randBytes)
   148  
   149  	copy(conn.reqBodyIV[:], randBytes[:16])
   150  	copy(conn.reqBodyKey[:], randBytes[16:32])
   151  	conn.reqRespV = randBytes[32]
   152  
   153  	if !c.isAead {
   154  		conn.respBodyIV = md5.Sum(conn.reqBodyIV[:])
   155  		conn.respBodyKey = md5.Sum(conn.reqBodyKey[:])
   156  	} else {
   157  		rbIV := sha256.Sum256(conn.reqBodyIV[:])
   158  		copy(conn.respBodyIV[:], rbIV[:16])
   159  		rbKey := sha256.Sum256(conn.reqBodyKey[:])
   160  		copy(conn.respBodyKey[:], rbKey[:16])
   161  	}
   162  
   163  	// Request
   164  	req, err := conn.EncodeRequest()
   165  	if err != nil {
   166  		return nil, err
   167  	}
   168  
   169  	_, err = rc.Write(req)
   170  	if err != nil {
   171  		return nil, err
   172  	}
   173  
   174  	conn.Conn = rc
   175  
   176  	return conn, nil
   177  }
   178  
   179  func (c *Conn) RemoteAddr() net.Addr { return c.addr }
   180  
   181  // EncodeRequest encodes requests to network bytes
   182  func (c *Conn) EncodeRequest() ([]byte, error) {
   183  
   184  	buf := new(bytes.Buffer)
   185  
   186  	// Request
   187  	buf.WriteByte(1)           // Ver
   188  	buf.Write(c.reqBodyIV[:])  // IV
   189  	buf.Write(c.reqBodyKey[:]) // Key
   190  	buf.WriteByte(c.reqRespV)  // V
   191  	buf.WriteByte(c.opt)       // Opt
   192  
   193  	// pLen and Sec
   194  	paddingLen := rand.IntN(16)
   195  	buf.WriteByte(byte(paddingLen<<4) | c.security) // P(4bit) and Sec(4bit)
   196  
   197  	buf.WriteByte(0) // reserved
   198  
   199  	buf.WriteByte(c.CMD.Byte()) // cmd
   200  
   201  	// target
   202  	_ = binary.Write(buf, binary.BigEndian, uint16(c.addr.Port().Port())) // port
   203  
   204  	buf.WriteByte(byte(c.addr.Type())) // atyp
   205  	buf.Write(c.addr.Bytes())          // addr
   206  
   207  	// padding
   208  	if paddingLen > 0 {
   209  		_, _ = relay.CopyN(buf, crand.Reader, int64(paddingLen))
   210  	}
   211  
   212  	// F
   213  	fnv1a := fnv.New32a()
   214  	fnv1a.Write(buf.Bytes())
   215  	buf.Write(fnv1a.Sum(nil))
   216  
   217  	if !c.isAead {
   218  		now := time.Now().UTC()
   219  		block, err := aes.NewCipher(c.user.CmdKey[:])
   220  		if err != nil {
   221  			return nil, err
   222  		}
   223  		stream := cipher.NewCFBEncrypter(block, TimestampHash(now))
   224  		stream.XORKeyStream(buf.Bytes(), buf.Bytes())
   225  
   226  		abuf := new(bytes.Buffer)
   227  		ts := make([]byte, 8)
   228  		binary.BigEndian.PutUint64(ts, uint64(now.Unix()))
   229  		abuf.Write(ssr.Hmac(crypto.MD5, c.user.UUID.Bytes(), ts, nil))
   230  		abuf.Write(buf.Bytes())
   231  		return abuf.Bytes(), nil
   232  	}
   233  
   234  	// aead
   235  	var fixedLengthCmdKey [16]byte
   236  	copy(fixedLengthCmdKey[:], c.user.CmdKey[:])
   237  	vmessout := SealVMessAEADHeader(fixedLengthCmdKey, buf.Bytes())
   238  	return vmessout, nil
   239  }
   240  
   241  // DecodeRespHeader .
   242  func (c *Conn) DecodeRespHeader() error {
   243  	var buf []byte
   244  	if !c.isAead {
   245  		// !none aead
   246  		block, err := aes.NewCipher(c.respBodyKey[:])
   247  		if err != nil {
   248  			return err
   249  		}
   250  
   251  		stream := cipher.NewCFBDecrypter(block, c.respBodyIV[:])
   252  
   253  		buf = make([]byte, 4)
   254  		_, err = io.ReadFull(c.Conn, buf)
   255  		if err != nil {
   256  			return err
   257  		}
   258  
   259  		stream.XORKeyStream(buf, buf)
   260  	} else {
   261  		var err error
   262  		buf, err = DecodeResponseHeader(c.respBodyKey[:], c.respBodyIV[:], c.Conn)
   263  		if err != nil {
   264  			return fmt.Errorf("decode response header failed: %w", err)
   265  		}
   266  		if len(buf) < 4 {
   267  			return errors.New("unexpected buffer length")
   268  		}
   269  	}
   270  
   271  	if buf[0] != c.reqRespV {
   272  		return errors.New("unexpected response header")
   273  	}
   274  
   275  	// TODO: Dynamic port support
   276  	if buf[2] != 0 {
   277  		// dataLen := int32(buf[3])
   278  		return errors.New("dynamic port is not supported now")
   279  	}
   280  
   281  	return nil
   282  }
   283  
   284  func (c *Conn) Write(b []byte) (n int, err error) {
   285  	for {
   286  		if c.dataWriter != nil {
   287  			return c.dataWriter.Write(b)
   288  		}
   289  
   290  		c.initWriter()
   291  	}
   292  }
   293  
   294  func (c *Conn) initWriter() {
   295  	c.dataWriter = &connWriter{Conn: c.Conn}
   296  	if c.opt&OptChunkStream == OptChunkStream {
   297  		switch c.security {
   298  		case SecurityNone:
   299  			c.dataWriter = ChunkedWriter(c.Conn)
   300  
   301  		case SecurityAES128GCM:
   302  			block, _ := aes.NewCipher(c.reqBodyKey[:])
   303  			aead, _ := cipher.NewGCM(block)
   304  			c.dataWriter = AEADWriter(c.Conn, aead, c.reqBodyIV[:])
   305  
   306  		case SecurityChacha20Poly1305:
   307  			key := make([]byte, 32)
   308  			t := md5.Sum(c.reqBodyKey[:])
   309  			copy(key, t[:])
   310  			t = md5.Sum(key[:16])
   311  			copy(key[16:], t[:])
   312  			aead, _ := chacha20poly1305.New(key)
   313  			c.dataWriter = AEADWriter(c.Conn, aead, c.reqBodyIV[:])
   314  		}
   315  	}
   316  }
   317  
   318  func (c *Conn) Read(b []byte) (n int, err error) {
   319  	if c.dataReader != nil {
   320  		return c.dataReader.Read(b)
   321  	}
   322  
   323  	err = c.DecodeRespHeader()
   324  	if err != nil {
   325  		return 0, err
   326  	}
   327  
   328  	c.dataReader = c.Conn
   329  	if c.opt&OptChunkStream == OptChunkStream {
   330  		switch c.security {
   331  		case SecurityNone:
   332  			c.dataReader = ChunkedReader(c.Conn)
   333  
   334  		case SecurityAES128GCM:
   335  			block, err := aes.NewCipher(c.respBodyKey[:])
   336  			if err != nil {
   337  				return 0, fmt.Errorf("new aes cipher failed: %w", err)
   338  			}
   339  			aead, err := cipher.NewGCM(block)
   340  			if err != nil {
   341  				return 0, fmt.Errorf("new gcm failed: %w", err)
   342  			}
   343  			c.dataReader = AEADReader(c.Conn, aead, c.respBodyIV[:])
   344  
   345  		case SecurityChacha20Poly1305:
   346  			key := make([]byte, 32)
   347  			t := md5.Sum(c.respBodyKey[:])
   348  			copy(key, t[:])
   349  			t = md5.Sum(key[:16])
   350  			copy(key[16:], t[:])
   351  			aead, err := chacha20poly1305.New(key)
   352  			if err != nil {
   353  				return 0, fmt.Errorf("new chacha20poly1305 failed: %w", err)
   354  			}
   355  			c.dataReader = AEADReader(c.Conn, aead, c.respBodyIV[:])
   356  		}
   357  	}
   358  
   359  	return c.dataReader.Read(b)
   360  }
   361  
   362  func (c *Conn) Close() error {
   363  	if c.dataReader != nil {
   364  		defer c.dataReader.Close()
   365  	}
   366  
   367  	if c.dataWriter != nil {
   368  		defer c.dataWriter.Close()
   369  	}
   370  
   371  	return c.Conn.Close()
   372  }
   373  
   374  func (c *Conn) ReadFrom(b []byte) (int, net.Addr, error) {
   375  	n, err := c.Read(b)
   376  	return n, c.RemoteAddr(), err
   377  }
   378  
   379  func (c *Conn) WriteTo(b []byte, target net.Addr) (int, error) {
   380  	t, err := netapi.ParseSysAddr(target)
   381  	if err != nil {
   382  		return 0, err
   383  	}
   384  
   385  	if t.String() != c.addr.Address.String() {
   386  		return 0, fmt.Errorf("vmess only support symmetric NAT")
   387  	}
   388  
   389  	return c.Write(b)
   390  }