github.com/laof/lite-speed-test@v0.0.0-20230930011949-1f39b7037845/outbound/trojan.go (about)

     1  package outbound
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"encoding/json"
     7  	"fmt"
     8  	"net"
     9  	"net/http"
    10  	"strconv"
    11  
    12  	C "github.com/laof/lite-speed-test/constant"
    13  	"github.com/laof/lite-speed-test/stats"
    14  	"github.com/laof/lite-speed-test/transport/dialer"
    15  	"github.com/laof/lite-speed-test/transport/gun"
    16  	"github.com/laof/lite-speed-test/transport/trojan"
    17  	"golang.org/x/net/http2"
    18  )
    19  
// Trojan is an outbound adapter that tunnels TCP connections and UDP packets
// through a Trojan server.
type Trojan struct {
	*Base
	// instance is the Trojan protocol client; this file uses it for stream
	// setup (StreamConn / StreamWebsocketConn), header writing (WriteHeader)
	// and UDP packet connections (PacketConn).
	instance *trojan.Trojan
	// option is the configuration passed to NewTrojan.
	option   *TrojanOption

	// for gun mux
	// These are meant for the "grpc" network. In this file the only code that
	// assigns them (in NewTrojan) is commented out, so as far as this file is
	// concerned they stay nil and the gun branches in DialContext/DialUDP are
	// not taken.
	gunTLSConfig *tls.Config
	gunConfig    *gun.Config
	transport    *http2.Transport
}
    30  
// TrojanOption is the user-facing configuration of a Trojan outbound, decoded
// via the `proxy` struct tags.
type TrojanOption struct {
	// Name is stored as the adapter name in Base.
	Name           string      `proxy:"name,omitempty"`
	// Server and Port are joined with net.JoinHostPort into the dial address.
	Server         string      `proxy:"server"`
	Port           int         `proxy:"port"`
	// Password is passed to the Trojan client.
	Password       string      `proxy:"password"`
	// ALPN is passed to the Trojan client's option.
	ALPN           []string    `proxy:"alpn,omitempty"`
	// SNI, when non-empty, replaces Server as the TLS server name and as the
	// WebSocket Host.
	SNI            string      `proxy:"sni,omitempty"`
	// SkipCertVerify is passed to the Trojan client's option.
	SkipCertVerify bool        `proxy:"skip-cert-verify,omitempty"`
	// UDP is stored in Base.udp (presumably gates UDP relaying — confirm in Base).
	UDP            bool        `proxy:"udp,omitempty"`
	// Remarks is not read anywhere in this file.
	Remarks        string      `proxy:"remarks,omitempty"`
	// Network "ws" selects the WebSocket transport; any other value uses the
	// plain Trojan stream ("grpc" setup is currently commented out).
	Network        string      `proxy:"network,omitempty"`
	// GrpcOpts is only referenced by the commented-out grpc setup in NewTrojan.
	GrpcOpts       GrpcOptions `proxy:"grpc-opts,omitempty"`
	// WSOpts supplies the WebSocket Path and extra Headers when Network is "ws".
	WSOpts         WSOptions   `proxy:"ws-opts,omitempty"`
}
    45  
    46  func (t *Trojan) plainStream(c net.Conn) (net.Conn, error) {
    47  	if t.option.Network == "ws" {
    48  		host, port, _ := net.SplitHostPort(t.addr)
    49  		wsOpts := &trojan.WebsocketOption{
    50  			Host: host,
    51  			Port: port,
    52  			Path: t.option.WSOpts.Path,
    53  		}
    54  
    55  		if t.option.SNI != "" {
    56  			wsOpts.Host = t.option.SNI
    57  		}
    58  
    59  		if len(t.option.WSOpts.Headers) != 0 {
    60  			header := http.Header{}
    61  			for key, value := range t.option.WSOpts.Headers {
    62  				header.Add(key, value)
    63  			}
    64  			wsOpts.Headers = header
    65  		}
    66  
    67  		return t.instance.StreamWebsocketConn(c, wsOpts)
    68  	}
    69  
    70  	return t.instance.StreamConn(c)
    71  }
    72  
    73  func (t *Trojan) StreamConn(c net.Conn, metadata *C.Metadata) (net.Conn, error) {
    74  	c, err := t.plainStream(c)
    75  	if err != nil {
    76  		return nil, fmt.Errorf("%s connect error: %w", t.addr, err)
    77  	}
    78  
    79  	err = t.instance.WriteHeader(c, trojan.CommandTCP, serializesSocksAddr(metadata))
    80  	return c, err
    81  }
    82  
    83  func (t *Trojan) DialContext(ctx context.Context, metadata *C.Metadata) (net.Conn, error) {
    84  	// gun transport
    85  	if t.transport != nil {
    86  		c, err := gun.StreamGunWithTransport(t.transport, t.gunConfig)
    87  		if err != nil {
    88  			return nil, err
    89  		}
    90  
    91  		if err = t.instance.WriteHeader(c, trojan.CommandTCP, serializesSocksAddr(metadata)); err != nil {
    92  			c.Close()
    93  			return nil, err
    94  		}
    95  		sc := stats.NewStatsConn(c)
    96  		return t.StreamConn(sc, metadata)
    97  	}
    98  
    99  	c, err := dialer.DialContext(ctx, "tcp", t.addr)
   100  	if err != nil {
   101  		return nil, fmt.Errorf("%s connect error: %w", t.addr, err)
   102  	}
   103  	tcpKeepAlive(c)
   104  	sc := stats.NewStatsConn(c)
   105  	return t.StreamConn(sc, metadata)
   106  }
   107  
   108  // TODO: grpc transport
   109  func (t *Trojan) DialUDP(metadata *C.Metadata) (_ net.PacketConn, err error) {
   110  	var c net.Conn
   111  
   112  	// grpc transport
   113  	if t.transport != nil {
   114  		c, err = gun.StreamGunWithTransport(t.transport, t.gunConfig)
   115  		if err != nil {
   116  			return nil, fmt.Errorf("%s connect error: %w", t.addr, err)
   117  		}
   118  	} else {
   119  		ctx, cancel := context.WithTimeout(context.Background(), tcpTimeout)
   120  		defer cancel()
   121  		c, err := dialer.DialContext(ctx, "tcp", t.addr)
   122  		if err != nil {
   123  			return nil, fmt.Errorf("%s connect error: %w", t.addr, err)
   124  		}
   125  		tcpKeepAlive(c)
   126  		c, err = t.instance.StreamConn(c)
   127  		if err != nil {
   128  			return nil, fmt.Errorf("%s connect error: %w", t.addr, err)
   129  		}
   130  	}
   131  
   132  	err = t.instance.WriteHeader(c, trojan.CommandUDP, serializesSocksAddr(metadata))
   133  	if err != nil {
   134  		return nil, err
   135  	}
   136  
   137  	pc := t.instance.PacketConn(c)
   138  	return pc, err
   139  }
   140  
   141  func (t *Trojan) MarshalJSON() ([]byte, error) {
   142  	return json.Marshal(map[string]string{
   143  		"type": "Trojan",
   144  	})
   145  }
   146  
   147  func NewTrojan(option *TrojanOption) (*Trojan, error) {
   148  	addr := net.JoinHostPort(option.Server, strconv.Itoa(option.Port))
   149  
   150  	tOption := &trojan.Option{
   151  		Password:           option.Password,
   152  		ALPN:               option.ALPN,
   153  		ServerName:         option.Server,
   154  		SkipCertVerify:     option.SkipCertVerify,
   155  		ClientSessionCache: getClientSessionCache(),
   156  	}
   157  
   158  	if option.SNI != "" {
   159  		tOption.ServerName = option.SNI
   160  	}
   161  
   162  	t := &Trojan{
   163  		Base: &Base{
   164  			name: option.Name,
   165  			addr: addr,
   166  			udp:  option.UDP,
   167  		},
   168  		instance: trojan.New(tOption),
   169  		option:   option,
   170  	}
   171  
   172  	// if option.Network == "grpc" {
   173  	// 	dialFn := func(network, addr string) (net.Conn, error) {
   174  	// 		c, err := dialer.DialContext(context.Background(), "tcp", t.addr)
   175  	// 		if err != nil {
   176  	// 			return nil, fmt.Errorf("%s connect error: %s", t.addr, err.Error())
   177  	// 		}
   178  	// 		tcpKeepAlive(c)
   179  	// 		return c, nil
   180  	// 	}
   181  
   182  	// 	tlsConfig := &tls.Config{
   183  	// 		NextProtos:         option.ALPN,
   184  	// 		MinVersion:         tls.VersionTLS12,
   185  	// 		InsecureSkipVerify: tOption.SkipCertVerify,
   186  	// 		ServerName:         tOption.ServerName,
   187  	// 	}
   188  	// 	t.transport = gun.NewHTTP2Client(dialFn, tlsConfig)
   189  	// 	t.gunTLSConfig = tlsConfig
   190  	// 	t.gunConfig = &gun.Config{
   191  	// 		ServiceName: option.GrpcOpts.GrpcServiceName,
   192  	// 		Host:        tOption.ServerName,
   193  	// 	}
   194  	// }
   195  
   196  	return t, nil
   197  }