github.com/laof/lite-speed-test@v0.0.0-20230930011949-1f39b7037845/outbound/vmess.go (about)

     1  package outbound
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"net"
     9  	"net/http"
    10  	"strconv"
    11  	"strings"
    12  
    13  	C "github.com/laof/lite-speed-test/constant"
    14  	"github.com/laof/lite-speed-test/log"
    15  	"github.com/laof/lite-speed-test/stats"
    16  	"github.com/laof/lite-speed-test/transport/dialer"
    17  	"github.com/laof/lite-speed-test/transport/resolver"
    18  	"github.com/laof/lite-speed-test/transport/socks5"
    19  	"github.com/laof/lite-speed-test/transport/vmess"
    20  	"github.com/laof/lite-speed-test/utils"
    21  )
    22  
// Vmess is an outbound adapter that reaches destinations through a VMess
// server. Instances are built by NewVmess.
type Vmess struct {
	*Base
	// client runs the VMess protocol on top of the transport that
	// StreamConn builds.
	client *vmess.Client
	// option is the configuration NewVmess was called with; StreamConn
	// reads it on every call to choose and configure the transport.
	option *VmessOption
}
    28  
// VmessOption is the configuration for a Vmess outbound adapter.
//
// NOTE(review): the `proxy` struct tags presumably name the config keys read
// by a reflection-based decoder elsewhere in the project — keep field names,
// order and tags stable unless that decoder is checked.
type VmessOption struct {
	// Name becomes the adapter name; Server and Port are joined into the
	// address the adapter dials (see NewVmess).
	Name   string `proxy:"name,omitempty"`
	Server string `proxy:"server"`
	Port   uint16 `proxy:"port"`

	// UUID, AlterID and Cipher configure the VMess client. AlterID == 0
	// selects VMess AEAD, and Cipher is lower-cased before use. Password is
	// not read by the code in this file.
	UUID     string `proxy:"uuid,omitempty"`
	Password string `proxy:"password,omitempty"`
	AlterID  int    `proxy:"alterId,omitempty"`
	Cipher   string `proxy:"cipher,omitempty"`

	// TLS wraps the transport in TLS and is required for the "h2" network.
	// UDP is stored on the adapter's Base.
	TLS bool `proxy:"tls,omitempty"`
	UDP bool `proxy:"udp,omitempty"`

	// Network selects the transport built by StreamConn: "ws", "http", "h2",
	// or anything else for plain TCP. HTTPOpts is used by "http", HTTP2Opts
	// by "h2", and WSPath/WSHeaders by "ws".
	Network   string            `proxy:"network,omitempty"`
	HTTPOpts  HTTPOptions       `proxy:"http-opts,omitempty"`
	HTTP2Opts HTTP2Options      `proxy:"h2-opts,omitempty"`
	WSPath    string            `proxy:"ws-path,omitempty"`
	WSHeaders map[string]string `proxy:"ws-headers,omitempty"`

	// SkipCertVerify is forwarded to the TLS/WebSocket configs. ServerName,
	// when non-empty, replaces the server host as the TLS host for the
	// http, h2 and plain transports and is passed as ServerName for ws.
	SkipCertVerify bool   `proxy:"skip-cert-verify,omitempty"`
	ServerName     string `proxy:"servername,omitempty"`

	// Type is not read by the code in this file.
	Type string `proxy:"type,omitempty"`

	// WSOpts holds the nested ws-opts form of the WebSocket settings.
	WSOpts WSOptions `proxy:"ws-opts,omitempty"`
}
    49  
    50  type HTTPOptions struct {
    51  	Method  string              `proxy:"method,omitempty"`
    52  	Path    []string            `proxy:"path,omitempty"`
    53  	Headers map[string][]string `proxy:"headers,omitempty"`
    54  }
    55  
    56  type HTTP2Options struct {
    57  	Host []string `proxy:"host,omitempty"`
    58  	Path string   `proxy:"path,omitempty"`
    59  }
    60  
    61  type GrpcOptions struct {
    62  	GrpcServiceName string `proxy:"grpc-service-name,omitempty"`
    63  }
    64  
    65  type WSOptions struct {
    66  	Path                string            `proxy:"path,omitempty"`
    67  	Headers             map[string]string `proxy:"headers,omitempty"`
    68  	MaxEarlyData        int               `proxy:"max-early-data,omitempty"`
    69  	EarlyDataHeaderName string            `proxy:"early-data-header-name,omitempty"`
    70  }
    71  
    72  // https://github.com/Dreamacro/clash/blob/412b44a98185b2a61500628835afcbd2c115b00e/adapter/outbound/vmess.go#L75
    73  func (v *Vmess) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) {
    74  	var err error
    75  	switch v.option.Network {
    76  	case "ws":
    77  		host, port, _ := net.SplitHostPort(v.addr)
    78  		wsOpts := &vmess.WebsocketConfig{
    79  			Host: host,
    80  			Port: port,
    81  			Path: v.option.WSPath,
    82  		}
    83  
    84  		if len(v.option.WSHeaders) != 0 {
    85  			header := http.Header{}
    86  			for key, value := range v.option.WSHeaders {
    87  				header.Add(key, value)
    88  			}
    89  			wsOpts.Headers = header
    90  		}
    91  
    92  		if v.option.TLS {
    93  			wsOpts.TLS = true
    94  			wsOpts.SessionCache = getClientSessionCache()
    95  			wsOpts.SkipCertVerify = v.option.SkipCertVerify
    96  			wsOpts.ServerName = v.option.ServerName
    97  		}
    98  		c, err = vmess.StreamWebsocketConn(c, wsOpts)
    99  	case "http":
   100  		// readability first, so just copy default TLS logic
   101  		if v.option.TLS {
   102  			host, _, _ := net.SplitHostPort(v.addr)
   103  			tlsOpts := &vmess.TLSConfig{
   104  				Host:           host,
   105  				SkipCertVerify: v.option.SkipCertVerify,
   106  				SessionCache:   getClientSessionCache(),
   107  			}
   108  
   109  			if v.option.ServerName != "" {
   110  				tlsOpts.Host = v.option.ServerName
   111  			}
   112  
   113  			c, err = vmess.StreamTLSConn(c, tlsOpts)
   114  			if err != nil {
   115  				return nil, err
   116  			}
   117  		}
   118  
   119  		host, _, _ := net.SplitHostPort(v.addr)
   120  		httpOpts := &vmess.HTTPConfig{
   121  			Host:    host,
   122  			Method:  v.option.HTTPOpts.Method,
   123  			Path:    v.option.HTTPOpts.Path,
   124  			Headers: v.option.HTTPOpts.Headers,
   125  		}
   126  
   127  		c = vmess.StreamHTTPConn(c, httpOpts)
   128  	case "h2":
   129  		host, _, _ := net.SplitHostPort(v.addr)
   130  		tlsOpts := vmess.TLSConfig{
   131  			Host:           host,
   132  			SkipCertVerify: v.option.SkipCertVerify,
   133  			SessionCache:   getClientSessionCache(),
   134  			NextProtos:     []string{"h2"},
   135  		}
   136  
   137  		if v.option.ServerName != "" {
   138  			tlsOpts.Host = v.option.ServerName
   139  		}
   140  
   141  		c, err = vmess.StreamTLSConn(c, &tlsOpts)
   142  		if err != nil {
   143  			return nil, err
   144  		}
   145  
   146  		h2Opts := &vmess.H2Config{
   147  			Hosts: v.option.HTTP2Opts.Host,
   148  			Path:  v.option.HTTP2Opts.Path,
   149  		}
   150  
   151  		c, err = vmess.StreamH2Conn(c, h2Opts)
   152  	default:
   153  		// handle TLS
   154  		if v.option.TLS {
   155  			host, _, _ := net.SplitHostPort(v.addr)
   156  			tlsOpts := &vmess.TLSConfig{
   157  				Host:           host,
   158  				SkipCertVerify: v.option.SkipCertVerify,
   159  				SessionCache:   getClientSessionCache(),
   160  			}
   161  
   162  			if v.option.ServerName != "" {
   163  				tlsOpts.Host = v.option.ServerName
   164  			}
   165  
   166  			c, err = vmess.StreamTLSConn(c, tlsOpts)
   167  		}
   168  	}
   169  
   170  	if err != nil {
   171  		return nil, err
   172  	}
   173  
   174  	return v.client.StreamConn(c, parseVmessAddr(metadata))
   175  }
   176  
   177  func (v *Vmess) DialContext(ctx context.Context, metadata *C.Metadata) (net.Conn, error) {
   178  	log.I("start dial from", v.addr, "to", metadata.RemoteAddress())
   179  	c, err := dialer.DialContext(ctx, "tcp", v.addr)
   180  	if err != nil {
   181  		return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())
   182  	}
   183  	tcpKeepAlive(c)
   184  	if metadata.Type == C.TEST {
   185  		if tcpconn, ok := c.(*net.TCPConn); ok {
   186  			tcpconn.SetLinger(0)
   187  		}
   188  	}
   189  
   190  	log.I("start StreamConn from", v.addr, "to", metadata.RemoteAddress())
   191  	sc := stats.NewConn(c)
   192  	return v.StreamConn(sc, metadata)
   193  }
   194  
   195  func (v *Vmess) DialUDP(metadata *C.Metadata) (net.PacketConn, error) {
   196  	// vmess use stream-oriented udp, so clash needs a net.UDPAddr
   197  	if !metadata.Resolved() {
   198  		ip, err := resolver.ResolveIP(metadata.Host)
   199  		if err != nil {
   200  			return nil, errors.New("can't resolve ip")
   201  		}
   202  		metadata.DstIP = ip
   203  	}
   204  
   205  	ctx, cancel := context.WithTimeout(context.Background(), tcpTimeout)
   206  	defer cancel()
   207  	c, err := dialer.DialContext(ctx, "tcp", v.addr)
   208  	if err != nil {
   209  		return nil, fmt.Errorf("%s connect error: %s", v.addr, err.Error())
   210  	}
   211  	tcpKeepAlive(c)
   212  	sc := stats.NewConn(c)
   213  	c, err = v.StreamConn(sc, metadata)
   214  	if err != nil {
   215  		return nil, fmt.Errorf("new vmess client error: %v", err)
   216  	}
   217  	return &vmessPacketConn{Conn: c, rAddr: metadata.UDPAddr()}, nil
   218  }
   219  
   220  func (v *Vmess) MarshalJSON() ([]byte, error) {
   221  	return json.Marshal(map[string]string{
   222  		"type": "Trojan",
   223  	})
   224  }
   225  
   226  func NewVmess(option *VmessOption) (*Vmess, error) {
   227  	security := strings.ToLower(option.Cipher)
   228  	client, err := vmess.NewClient(vmess.Config{
   229  		UUID:     option.UUID,
   230  		AlterID:  uint16(option.AlterID),
   231  		Security: security,
   232  		HostName: option.Server,
   233  		Port:     option.Port,
   234  		IsAead:   option.AlterID == 0, // VMess AEAD will be used when alterId is 0
   235  	})
   236  	if err != nil {
   237  		return nil, err
   238  	}
   239  	if option.Network == "h2" && !option.TLS {
   240  		return nil, fmt.Errorf("TLS must be true with h2 network")
   241  	}
   242  
   243  	return &Vmess{
   244  		Base: &Base{
   245  			name: option.Name,
   246  			addr: net.JoinHostPort(option.Server, utils.U16toa(option.Port)),
   247  			udp:  option.UDP,
   248  		},
   249  		client: client,
   250  		option: option,
   251  	}, nil
   252  }
   253  
   254  func parseVmessAddr(metadata *C.Metadata) *vmess.DstAddr {
   255  	var addrType byte
   256  	var addr []byte
   257  	switch metadata.AddrType() {
   258  	case socks5.AtypIPv4:
   259  		addrType = byte(vmess.AtypIPv4)
   260  		addr = make([]byte, net.IPv4len)
   261  		copy(addr[:], metadata.DstIP.To4())
   262  	case socks5.AtypIPv6:
   263  		addrType = byte(vmess.AtypIPv6)
   264  		addr = make([]byte, net.IPv6len)
   265  		copy(addr[:], metadata.DstIP.To16())
   266  	case socks5.AtypDomainName:
   267  		addrType = byte(vmess.AtypDomainName)
   268  		addr = make([]byte, len(metadata.Host)+1)
   269  		addr[0] = byte(len(metadata.Host))
   270  		copy(addr[1:], []byte(metadata.Host))
   271  	}
   272  
   273  	port, _ := strconv.ParseUint(metadata.DstPort, 10, 16)
   274  	return &vmess.DstAddr{
   275  		UDP:      metadata.NetWork == C.UDP,
   276  		AddrType: addrType,
   277  		Addr:     addr,
   278  		Port:     uint(port),
   279  	}
   280  }
   281  
   282  type vmessPacketConn struct {
   283  	net.Conn
   284  	rAddr net.Addr
   285  }
   286  
   287  func (uc *vmessPacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
   288  	return uc.Conn.Write(b)
   289  }
   290  
   291  func (uc *vmessPacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
   292  	n, err := uc.Conn.Read(b)
   293  	return n, uc.rAddr, err
   294  }