github.com/ipfans/trojan-go@v0.11.0/tunnel/tls/server.go (about)

     1  package tls
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"context"
     7  	"crypto/tls"
     8  	"crypto/x509"
     9  	"encoding/pem"
    10  	"io"
    11  	"io/ioutil"
    12  	"net"
    13  	"net/http"
    14  	"os"
    15  	"strings"
    16  	"sync"
    17  	"sync/atomic"
    18  	"time"
    19  
    20  	"github.com/ipfans/trojan-go/common"
    21  	"github.com/ipfans/trojan-go/config"
    22  	"github.com/ipfans/trojan-go/log"
    23  	"github.com/ipfans/trojan-go/redirector"
    24  	"github.com/ipfans/trojan-go/tunnel"
    25  	"github.com/ipfans/trojan-go/tunnel/tls/fingerprint"
    26  	"github.com/ipfans/trojan-go/tunnel/transport"
    27  	"github.com/ipfans/trojan-go/tunnel/websocket"
    28  )
    29  
    30  // Server is a tls server
    31  type Server struct {
    32  	fallbackAddress    *tunnel.Address
    33  	verifySNI          bool
    34  	sni                string
    35  	alpn               []string
    36  	PreferServerCipher bool
    37  	keyPair            []tls.Certificate
    38  	keyPairLock        sync.RWMutex
    39  	httpResp           []byte
    40  	cipherSuite        []uint16
    41  	sessionTicket      bool
    42  	curve              []tls.CurveID
    43  	keyLogger          io.WriteCloser
    44  	connChan           chan tunnel.Conn
    45  	wsChan             chan tunnel.Conn
    46  	redir              *redirector.Redirector
    47  	ctx                context.Context
    48  	cancel             context.CancelFunc
    49  	underlay           tunnel.Server
    50  	nextHTTP           int32
    51  	portOverrider      map[string]int
    52  }
    53  
    54  func (s *Server) Close() error {
    55  	s.cancel()
    56  	if s.keyLogger != nil {
    57  		s.keyLogger.Close()
    58  	}
    59  	return s.underlay.Close()
    60  }
    61  
    62  func isDomainNameMatched(pattern string, domainName string) bool {
    63  	if strings.HasPrefix(pattern, "*.") {
    64  		suffix := pattern[2:]
    65  		domainPrefixLen := len(domainName) - len(suffix) - 1
    66  		return strings.HasSuffix(domainName, suffix) && domainPrefixLen > 0 && !strings.Contains(domainName[:domainPrefixLen], ".")
    67  	}
    68  	return pattern == domainName
    69  }
    70  
    71  func (s *Server) acceptLoop() {
    72  	for {
    73  		conn, err := s.underlay.AcceptConn(&Tunnel{})
    74  		if err != nil {
    75  			select {
    76  			case <-s.ctx.Done():
    77  			default:
    78  				log.Fatal(common.NewError("transport accept error" + err.Error()))
    79  			}
    80  			return
    81  		}
    82  		go func(conn net.Conn) {
    83  			tlsConfig := &tls.Config{
    84  				CipherSuites:             s.cipherSuite,
    85  				PreferServerCipherSuites: s.PreferServerCipher,
    86  				SessionTicketsDisabled:   !s.sessionTicket,
    87  				NextProtos:               s.alpn,
    88  				KeyLogWriter:             s.keyLogger,
    89  				GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
    90  					s.keyPairLock.RLock()
    91  					defer s.keyPairLock.RUnlock()
    92  					sni := s.keyPair[0].Leaf.Subject.CommonName
    93  					dnsNames := s.keyPair[0].Leaf.DNSNames
    94  					if s.sni != "" {
    95  						sni = s.sni
    96  					}
    97  					matched := isDomainNameMatched(sni, hello.ServerName)
    98  					for _, name := range dnsNames {
    99  						if isDomainNameMatched(name, hello.ServerName) {
   100  							matched = true
   101  							break
   102  						}
   103  					}
   104  					if s.verifySNI && !matched {
   105  						return nil, common.NewError("sni mismatched: " + hello.ServerName + ", expected: " + s.sni)
   106  					}
   107  					return &s.keyPair[0], nil
   108  				},
   109  			}
   110  
   111  			// ------------------------ WAR ZONE ----------------------------
   112  
   113  			handshakeRewindConn := common.NewRewindConn(conn)
   114  			handshakeRewindConn.SetBufferSize(2048)
   115  
   116  			tlsConn := tls.Server(handshakeRewindConn, tlsConfig)
   117  			err = tlsConn.Handshake()
   118  			handshakeRewindConn.StopBuffering()
   119  
   120  			if err != nil {
   121  				if strings.Contains(err.Error(), "first record does not look like a TLS handshake") {
   122  					// not a valid tls client hello
   123  					handshakeRewindConn.Rewind()
   124  					log.Error(common.NewError("failed to perform tls handshake with " + tlsConn.RemoteAddr().String() + ", redirecting").Base(err))
   125  					switch {
   126  					case s.fallbackAddress != nil:
   127  						s.redir.Redirect(&redirector.Redirection{
   128  							InboundConn: handshakeRewindConn,
   129  							RedirectTo:  s.fallbackAddress,
   130  						})
   131  					case s.httpResp != nil:
   132  						handshakeRewindConn.Write(s.httpResp)
   133  						handshakeRewindConn.Close()
   134  					default:
   135  						handshakeRewindConn.Close()
   136  					}
   137  				} else {
   138  					// in other cases, simply close it
   139  					tlsConn.Close()
   140  					log.Error(common.NewError("tls handshake failed").Base(err))
   141  				}
   142  				return
   143  			}
   144  
   145  			log.Info("tls connection from", conn.RemoteAddr())
   146  			state := tlsConn.ConnectionState()
   147  			log.Trace("tls handshake", tls.CipherSuiteName(state.CipherSuite), state.DidResume, state.NegotiatedProtocol)
   148  
   149  			// we use a real http header parser to mimic a real http server
   150  			rewindConn := common.NewRewindConn(tlsConn)
   151  			rewindConn.SetBufferSize(1024)
   152  			r := bufio.NewReader(rewindConn)
   153  			httpReq, err := http.ReadRequest(r)
   154  			rewindConn.Rewind()
   155  			rewindConn.StopBuffering()
   156  			if err != nil {
   157  				// this is not a http request. pass it to trojan protocol layer for further inspection
   158  				s.connChan <- &transport.Conn{
   159  					Conn: rewindConn,
   160  				}
   161  			} else {
   162  				if atomic.LoadInt32(&s.nextHTTP) != 1 {
   163  					// there is no websocket layer waiting for connections, redirect it
   164  					log.Error("incoming http request, but no websocket server is listening")
   165  					s.redir.Redirect(&redirector.Redirection{
   166  						InboundConn: rewindConn,
   167  						RedirectTo:  s.fallbackAddress,
   168  					})
   169  					return
   170  				}
   171  				// this is a http request, pass it to websocket protocol layer
   172  				log.Debug("http req: ", httpReq)
   173  				s.wsChan <- &transport.Conn{
   174  					Conn: rewindConn,
   175  				}
   176  			}
   177  		}(conn)
   178  	}
   179  }
   180  
   181  func (s *Server) AcceptConn(overlay tunnel.Tunnel) (tunnel.Conn, error) {
   182  	if _, ok := overlay.(*websocket.Tunnel); ok {
   183  		atomic.StoreInt32(&s.nextHTTP, 1)
   184  		log.Debug("next proto http")
   185  		// websocket overlay
   186  		select {
   187  		case conn := <-s.wsChan:
   188  			return conn, nil
   189  		case <-s.ctx.Done():
   190  			return nil, common.NewError("transport server closed")
   191  		}
   192  	}
   193  	// trojan overlay
   194  	select {
   195  	case conn := <-s.connChan:
   196  		return conn, nil
   197  	case <-s.ctx.Done():
   198  		return nil, common.NewError("transport server closed")
   199  	}
   200  }
   201  
   202  func (s *Server) AcceptPacket(tunnel.Tunnel) (tunnel.PacketConn, error) {
   203  	panic("not supported")
   204  }
   205  
   206  func (s *Server) checkKeyPairLoop(checkRate time.Duration, keyPath string, certPath string, password string) {
   207  	var lastKeyBytes, lastCertBytes []byte
   208  	ticker := time.NewTicker(checkRate)
   209  
   210  	for {
   211  		log.Debug("checking cert...")
   212  		keyBytes, err := ioutil.ReadFile(keyPath)
   213  		if err != nil {
   214  			log.Error(common.NewError("tls failed to check key").Base(err))
   215  			continue
   216  		}
   217  		certBytes, err := ioutil.ReadFile(certPath)
   218  		if err != nil {
   219  			log.Error(common.NewError("tls failed to check cert").Base(err))
   220  			continue
   221  		}
   222  		if !bytes.Equal(keyBytes, lastKeyBytes) || !bytes.Equal(lastCertBytes, certBytes) {
   223  			log.Info("new key pair detected")
   224  			keyPair, err := loadKeyPair(keyPath, certPath, password)
   225  			if err != nil {
   226  				log.Error(common.NewError("tls failed to load new key pair").Base(err))
   227  				continue
   228  			}
   229  			s.keyPairLock.Lock()
   230  			s.keyPair = []tls.Certificate{*keyPair}
   231  			s.keyPairLock.Unlock()
   232  			lastKeyBytes = keyBytes
   233  			lastCertBytes = certBytes
   234  		}
   235  
   236  		select {
   237  		case <-ticker.C:
   238  			continue
   239  		case <-s.ctx.Done():
   240  			log.Debug("exiting")
   241  			ticker.Stop()
   242  			return
   243  		}
   244  	}
   245  }
   246  
   247  func loadKeyPair(keyPath string, certPath string, password string) (*tls.Certificate, error) {
   248  	if password != "" {
   249  		keyFile, err := ioutil.ReadFile(keyPath)
   250  		if err != nil {
   251  			return nil, common.NewError("failed to load key file").Base(err)
   252  		}
   253  		keyBlock, _ := pem.Decode(keyFile)
   254  		if keyBlock == nil {
   255  			return nil, common.NewError("failed to decode key file").Base(err)
   256  		}
   257  		decryptedKey, err := x509.DecryptPEMBlock(keyBlock, []byte(password))
   258  		if err == nil {
   259  			return nil, common.NewError("failed to decrypt key").Base(err)
   260  		}
   261  
   262  		certFile, err := ioutil.ReadFile(certPath)
   263  		certBlock, _ := pem.Decode(certFile)
   264  		if certBlock == nil {
   265  			return nil, common.NewError("failed to decode cert file").Base(err)
   266  		}
   267  
   268  		keyPair, err := tls.X509KeyPair(certBlock.Bytes, decryptedKey)
   269  		if err != nil {
   270  			return nil, err
   271  		}
   272  		keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0])
   273  		if err != nil {
   274  			return nil, common.NewError("failed to parse leaf certificate").Base(err)
   275  		}
   276  
   277  		return &keyPair, nil
   278  	}
   279  	keyPair, err := tls.LoadX509KeyPair(certPath, keyPath)
   280  	if err != nil {
   281  		return nil, common.NewError("failed to load key pair").Base(err)
   282  	}
   283  	keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0])
   284  	if err != nil {
   285  		return nil, common.NewError("failed to parse leaf certificate").Base(err)
   286  	}
   287  	return &keyPair, nil
   288  }
   289  
   290  // NewServer creates a tls layer server
   291  func NewServer(ctx context.Context, underlay tunnel.Server) (*Server, error) {
   292  	cfg := config.FromContext(ctx, Name).(*Config)
   293  
   294  	var fallbackAddress *tunnel.Address
   295  	var httpResp []byte
   296  	if cfg.TLS.FallbackPort != 0 {
   297  		if cfg.TLS.FallbackHost == "" {
   298  			cfg.TLS.FallbackHost = cfg.RemoteHost
   299  			log.Warn("empty tls fallback address")
   300  		}
   301  		fallbackAddress = tunnel.NewAddressFromHostPort("tcp", cfg.TLS.FallbackHost, cfg.TLS.FallbackPort)
   302  		fallbackConn, err := net.Dial("tcp", fallbackAddress.String())
   303  		if err != nil {
   304  			return nil, common.NewError("invalid fallback address").Base(err)
   305  		}
   306  		fallbackConn.Close()
   307  	} else {
   308  		log.Warn("empty tls fallback port")
   309  		if cfg.TLS.HTTPResponseFileName != "" {
   310  			httpRespBody, err := ioutil.ReadFile(cfg.TLS.HTTPResponseFileName)
   311  			if err != nil {
   312  				return nil, common.NewError("invalid response file").Base(err)
   313  			}
   314  			httpResp = httpRespBody
   315  		} else {
   316  			log.Warn("empty tls http response")
   317  		}
   318  	}
   319  
   320  	keyPair, err := loadKeyPair(cfg.TLS.KeyPath, cfg.TLS.CertPath, cfg.TLS.KeyPassword)
   321  	if err != nil {
   322  		return nil, common.NewError("tls failed to load key pair")
   323  	}
   324  
   325  	var keyLogger io.WriteCloser
   326  	if cfg.TLS.KeyLogPath != "" {
   327  		log.Warn("tls key logging activated. USE OF KEY LOGGING COMPROMISES SECURITY. IT SHOULD ONLY BE USED FOR DEBUGGING.")
   328  		file, err := os.OpenFile(cfg.TLS.KeyLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o600)
   329  		if err != nil {
   330  			return nil, common.NewError("failed to open key log file").Base(err)
   331  		}
   332  		keyLogger = file
   333  	}
   334  
   335  	var cipherSuite []uint16
   336  	if len(cfg.TLS.Cipher) != 0 {
   337  		cipherSuite = fingerprint.ParseCipher(strings.Split(cfg.TLS.Cipher, ":"))
   338  	}
   339  
   340  	ctx, cancel := context.WithCancel(ctx)
   341  	server := &Server{
   342  		underlay:           underlay,
   343  		fallbackAddress:    fallbackAddress,
   344  		httpResp:           httpResp,
   345  		verifySNI:          cfg.TLS.VerifyHostName,
   346  		sni:                cfg.TLS.SNI,
   347  		alpn:               cfg.TLS.ALPN,
   348  		PreferServerCipher: cfg.TLS.PreferServerCipher,
   349  		sessionTicket:      cfg.TLS.ReuseSession,
   350  		connChan:           make(chan tunnel.Conn, 32),
   351  		wsChan:             make(chan tunnel.Conn, 32),
   352  		redir:              redirector.NewRedirector(ctx),
   353  		keyPair:            []tls.Certificate{*keyPair},
   354  		keyLogger:          keyLogger,
   355  		cipherSuite:        cipherSuite,
   356  		ctx:                ctx,
   357  		cancel:             cancel,
   358  	}
   359  
   360  	go server.acceptLoop()
   361  	if cfg.TLS.CertCheckRate > 0 {
   362  		go server.checkKeyPairLoop(
   363  			time.Second*time.Duration(cfg.TLS.CertCheckRate),
   364  			cfg.TLS.KeyPath,
   365  			cfg.TLS.CertPath,
   366  			cfg.TLS.KeyPassword,
   367  		)
   368  	}
   369  
   370  	log.Debug("tls server created")
   371  	return server, nil
   372  }