github.com/ipfans/trojan-go@v0.11.0/tunnel/tls/client.go (about)

     1  package tls
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"crypto/x509"
     7  	"encoding/pem"
     8  	"io"
     9  	"io/ioutil"
    10  	"strings"
    11  
    12  	utls "github.com/refraction-networking/utls"
    13  
    14  	"github.com/ipfans/trojan-go/common"
    15  	"github.com/ipfans/trojan-go/config"
    16  	"github.com/ipfans/trojan-go/log"
    17  	"github.com/ipfans/trojan-go/tunnel"
    18  	"github.com/ipfans/trojan-go/tunnel/tls/fingerprint"
    19  	"github.com/ipfans/trojan-go/tunnel/transport"
    20  )
    21  
// Client is a tls client.
// It dials a connection through its underlay client and performs a TLS
// client handshake over it, using crypto/tls or, when a fingerprint is
// configured, utls.
type Client struct {
	// verify enables server certificate verification; when false the
	// handshake config sets InsecureSkipVerify.
	verify        bool
	// sni is sent as the TLS ServerName. NewClient falls back to the
	// configured remote host when no SNI is set.
	sni           string
	// ca is the RootCAs pool loaded from the configured cert file; it stays
	// nil when no cert file is configured.
	ca            *x509.CertPool
	// cipher is the CipherSuites list parsed from the colon-separated cipher
	// setting. It is only passed on the crypto/tls path, not to utls.
	cipher        []uint16
	// sessionTicket enables session tickets on the crypto/tls path
	// (SessionTicketsDisabled is set to its negation). NewClient fills it
	// from the ReuseSession setting.
	sessionTicket bool
	// NOTE(review): reuseSession is neither set nor read in this file;
	// sessionTicket carries the ReuseSession setting. Confirm whether it is
	// still used elsewhere in the package.
	reuseSession  bool
	// fingerprint is the configured browser fingerprint name; a non-empty
	// value makes DialConn use utls instead of crypto/tls.
	fingerprint   string
	// helloID is the utls ClientHelloID matching fingerprint (zero value
	// when fingerprint is empty).
	helloID       utls.ClientHelloID
	// keyLogger, when non-nil, is passed as KeyLogWriter to the handshake
	// config and is closed by Close. NOTE(review): NewClient in this file
	// never assigns it.
	keyLogger     io.WriteCloser
	// underlay dials the connection TLS runs over and is closed by Close.
	underlay      tunnel.Client
}
    35  
    36  func (c *Client) Close() error {
    37  	if c.keyLogger != nil {
    38  		c.keyLogger.Close()
    39  	}
    40  	return c.underlay.Close()
    41  }
    42  
    43  func (c *Client) DialPacket(tunnel.Tunnel) (tunnel.PacketConn, error) {
    44  	panic("not supported")
    45  }
    46  
    47  func (c *Client) DialConn(_ *tunnel.Address, overlay tunnel.Tunnel) (tunnel.Conn, error) {
    48  	conn, err := c.underlay.DialConn(nil, &Tunnel{})
    49  	if err != nil {
    50  		return nil, common.NewError("tls failed to dial conn").Base(err)
    51  	}
    52  
    53  	if c.fingerprint != "" {
    54  		// utls fingerprint
    55  		tlsConn := utls.UClient(conn, &utls.Config{
    56  			RootCAs:            c.ca,
    57  			ServerName:         c.sni,
    58  			InsecureSkipVerify: !c.verify,
    59  			KeyLogWriter:       c.keyLogger,
    60  		}, c.helloID)
    61  		if err := tlsConn.Handshake(); err != nil {
    62  			return nil, common.NewError("tls failed to handshake with remote server").Base(err)
    63  		}
    64  		return &transport.Conn{
    65  			Conn: tlsConn,
    66  		}, nil
    67  	}
    68  	// golang default tls library
    69  	tlsConn := tls.Client(conn, &tls.Config{
    70  		InsecureSkipVerify:     !c.verify,
    71  		ServerName:             c.sni,
    72  		RootCAs:                c.ca,
    73  		KeyLogWriter:           c.keyLogger,
    74  		CipherSuites:           c.cipher,
    75  		SessionTicketsDisabled: !c.sessionTicket,
    76  	})
    77  	err = tlsConn.Handshake()
    78  	if err != nil {
    79  		return nil, common.NewError("tls failed to handshake with remote server").Base(err)
    80  	}
    81  	return &transport.Conn{
    82  		Conn: tlsConn,
    83  	}, nil
    84  }
    85  
    86  // NewClient creates a tls client
    87  func NewClient(ctx context.Context, underlay tunnel.Client) (*Client, error) {
    88  	cfg := config.FromContext(ctx, Name).(*Config)
    89  
    90  	helloID := utls.ClientHelloID{}
    91  	if cfg.TLS.Fingerprint != "" {
    92  		switch cfg.TLS.Fingerprint {
    93  		case "firefox":
    94  			helloID = utls.HelloFirefox_Auto
    95  		case "chrome":
    96  			helloID = utls.HelloChrome_Auto
    97  		case "ios":
    98  			helloID = utls.HelloIOS_Auto
    99  		default:
   100  			return nil, common.NewError("invalid fingerprint " + cfg.TLS.Fingerprint)
   101  		}
   102  		log.Info("tls fingerprint", cfg.TLS.Fingerprint, "applied")
   103  	}
   104  
   105  	if cfg.TLS.SNI == "" {
   106  		cfg.TLS.SNI = cfg.RemoteHost
   107  		log.Warn("tls sni is unspecified")
   108  	}
   109  
   110  	client := &Client{
   111  		underlay:      underlay,
   112  		verify:        cfg.TLS.Verify,
   113  		sni:           cfg.TLS.SNI,
   114  		cipher:        fingerprint.ParseCipher(strings.Split(cfg.TLS.Cipher, ":")),
   115  		sessionTicket: cfg.TLS.ReuseSession,
   116  		fingerprint:   cfg.TLS.Fingerprint,
   117  		helloID:       helloID,
   118  	}
   119  
   120  	if cfg.TLS.CertPath != "" {
   121  		caCertByte, err := ioutil.ReadFile(cfg.TLS.CertPath)
   122  		if err != nil {
   123  			return nil, common.NewError("failed to load cert file").Base(err)
   124  		}
   125  		client.ca = x509.NewCertPool()
   126  		ok := client.ca.AppendCertsFromPEM(caCertByte)
   127  		if !ok {
   128  			log.Warn("invalid cert list")
   129  		}
   130  		log.Info("using custom cert")
   131  
   132  		// print cert info
   133  		pemCerts := caCertByte
   134  		for len(pemCerts) > 0 {
   135  			var block *pem.Block
   136  			block, pemCerts = pem.Decode(pemCerts)
   137  			if block == nil {
   138  				break
   139  			}
   140  			if block.Type != "CERTIFICATE" || len(block.Headers) != 0 {
   141  				continue
   142  			}
   143  			cert, err := x509.ParseCertificate(block.Bytes)
   144  			if err != nil {
   145  				continue
   146  			}
   147  			log.Trace("issuer:", cert.Issuer, "subject:", cert.Subject)
   148  		}
   149  	}
   150  
   151  	if cfg.TLS.CertPath == "" {
   152  		log.Info("cert is unspecified, using default ca list")
   153  	}
   154  
   155  	log.Debug("tls client created")
   156  	return client, nil
   157  }