github.com/ipfans/trojan-go@v0.11.0/tunnel/tls/server.go (about) 1 package tls 2 3 import ( 4 "bufio" 5 "bytes" 6 "context" 7 "crypto/tls" 8 "crypto/x509" 9 "encoding/pem" 10 "io" 11 "io/ioutil" 12 "net" 13 "net/http" 14 "os" 15 "strings" 16 "sync" 17 "sync/atomic" 18 "time" 19 20 "github.com/ipfans/trojan-go/common" 21 "github.com/ipfans/trojan-go/config" 22 "github.com/ipfans/trojan-go/log" 23 "github.com/ipfans/trojan-go/redirector" 24 "github.com/ipfans/trojan-go/tunnel" 25 "github.com/ipfans/trojan-go/tunnel/tls/fingerprint" 26 "github.com/ipfans/trojan-go/tunnel/transport" 27 "github.com/ipfans/trojan-go/tunnel/websocket" 28 ) 29 30 // Server is a tls server 31 type Server struct { 32 fallbackAddress *tunnel.Address 33 verifySNI bool 34 sni string 35 alpn []string 36 PreferServerCipher bool 37 keyPair []tls.Certificate 38 keyPairLock sync.RWMutex 39 httpResp []byte 40 cipherSuite []uint16 41 sessionTicket bool 42 curve []tls.CurveID 43 keyLogger io.WriteCloser 44 connChan chan tunnel.Conn 45 wsChan chan tunnel.Conn 46 redir *redirector.Redirector 47 ctx context.Context 48 cancel context.CancelFunc 49 underlay tunnel.Server 50 nextHTTP int32 51 portOverrider map[string]int 52 } 53 54 func (s *Server) Close() error { 55 s.cancel() 56 if s.keyLogger != nil { 57 s.keyLogger.Close() 58 } 59 return s.underlay.Close() 60 } 61 62 func isDomainNameMatched(pattern string, domainName string) bool { 63 if strings.HasPrefix(pattern, "*.") { 64 suffix := pattern[2:] 65 domainPrefixLen := len(domainName) - len(suffix) - 1 66 return strings.HasSuffix(domainName, suffix) && domainPrefixLen > 0 && !strings.Contains(domainName[:domainPrefixLen], ".") 67 } 68 return pattern == domainName 69 } 70 71 func (s *Server) acceptLoop() { 72 for { 73 conn, err := s.underlay.AcceptConn(&Tunnel{}) 74 if err != nil { 75 select { 76 case <-s.ctx.Done(): 77 default: 78 log.Fatal(common.NewError("transport accept error" + err.Error())) 79 } 80 return 81 } 82 go func(conn net.Conn) { 83 tlsConfig := &tls.Config{ 84 CipherSuites: s.cipherSuite, 85 PreferServerCipherSuites: s.PreferServerCipher, 86 SessionTicketsDisabled: !s.sessionTicket, 87 NextProtos: s.alpn, 88 KeyLogWriter: s.keyLogger, 89 GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { 90 s.keyPairLock.RLock() 91 defer s.keyPairLock.RUnlock() 92 sni := s.keyPair[0].Leaf.Subject.CommonName 93 dnsNames := s.keyPair[0].Leaf.DNSNames 94 if s.sni != "" { 95 sni = s.sni 96 } 97 matched := isDomainNameMatched(sni, hello.ServerName) 98 for _, name := range dnsNames { 99 if isDomainNameMatched(name, hello.ServerName) { 100 matched = true 101 break 102 } 103 } 104 if s.verifySNI && !matched { 105 return nil, common.NewError("sni mismatched: " + hello.ServerName + ", expected: " + s.sni) 106 } 107 return &s.keyPair[0], nil 108 }, 109 } 110 111 // ------------------------ WAR ZONE ---------------------------- 112 113 handshakeRewindConn := common.NewRewindConn(conn) 114 handshakeRewindConn.SetBufferSize(2048) 115 116 tlsConn := tls.Server(handshakeRewindConn, tlsConfig) 117 err = tlsConn.Handshake() 118 handshakeRewindConn.StopBuffering() 119 120 if err != nil { 121 if strings.Contains(err.Error(), "first record does not look like a TLS handshake") { 122 // not a valid tls client hello 123 handshakeRewindConn.Rewind() 124 log.Error(common.NewError("failed to perform tls handshake with " + tlsConn.RemoteAddr().String() + ", redirecting").Base(err)) 125 switch { 126 case s.fallbackAddress != nil: 127 s.redir.Redirect(&redirector.Redirection{ 128 InboundConn: handshakeRewindConn, 129 RedirectTo: s.fallbackAddress, 130 }) 131 case s.httpResp != nil: 132 handshakeRewindConn.Write(s.httpResp) 133 handshakeRewindConn.Close() 134 default: 135 handshakeRewindConn.Close() 136 } 137 } else { 138 // in other cases, simply close it 139 tlsConn.Close() 140 log.Error(common.NewError("tls handshake failed").Base(err)) 141 } 142 return 143 } 144 145 log.Info("tls connection from", conn.RemoteAddr()) 146 state := tlsConn.ConnectionState() 147 log.Trace("tls handshake", tls.CipherSuiteName(state.CipherSuite), state.DidResume, state.NegotiatedProtocol) 148 149 // we use a real http header parser to mimic a real http server 150 rewindConn := common.NewRewindConn(tlsConn) 151 rewindConn.SetBufferSize(1024) 152 r := bufio.NewReader(rewindConn) 153 httpReq, err := http.ReadRequest(r) 154 rewindConn.Rewind() 155 rewindConn.StopBuffering() 156 if err != nil { 157 // this is not a http request. pass it to trojan protocol layer for further inspection 158 s.connChan <- &transport.Conn{ 159 Conn: rewindConn, 160 } 161 } else { 162 if atomic.LoadInt32(&s.nextHTTP) != 1 { 163 // there is no websocket layer waiting for connections, redirect it 164 log.Error("incoming http request, but no websocket server is listening") 165 s.redir.Redirect(&redirector.Redirection{ 166 InboundConn: rewindConn, 167 RedirectTo: s.fallbackAddress, 168 }) 169 return 170 } 171 // this is a http request, pass it to websocket protocol layer 172 log.Debug("http req: ", httpReq) 173 s.wsChan <- &transport.Conn{ 174 Conn: rewindConn, 175 } 176 } 177 }(conn) 178 } 179 } 180 181 func (s *Server) AcceptConn(overlay tunnel.Tunnel) (tunnel.Conn, error) { 182 if _, ok := overlay.(*websocket.Tunnel); ok { 183 atomic.StoreInt32(&s.nextHTTP, 1) 184 log.Debug("next proto http") 185 // websocket overlay 186 select { 187 case conn := <-s.wsChan: 188 return conn, nil 189 case <-s.ctx.Done(): 190 return nil, common.NewError("transport server closed") 191 } 192 } 193 // trojan overlay 194 select { 195 case conn := <-s.connChan: 196 return conn, nil 197 case <-s.ctx.Done(): 198 return nil, common.NewError("transport server closed") 199 } 200 } 201 202 func (s *Server) AcceptPacket(tunnel.Tunnel) (tunnel.PacketConn, error) { 203 panic("not supported") 204 } 205 206 func (s *Server) checkKeyPairLoop(checkRate time.Duration, keyPath string, certPath string, password string) { 207 var lastKeyBytes, lastCertBytes []byte 208 ticker := time.NewTicker(checkRate) 209 210 for { 211 log.Debug("checking cert...") 212 keyBytes, err := ioutil.ReadFile(keyPath) 213 if err != nil { 214 log.Error(common.NewError("tls failed to check key").Base(err)) 215 continue 216 } 217 certBytes, err := ioutil.ReadFile(certPath) 218 if err != nil { 219 log.Error(common.NewError("tls failed to check cert").Base(err)) 220 continue 221 } 222 if !bytes.Equal(keyBytes, lastKeyBytes) || !bytes.Equal(lastCertBytes, certBytes) { 223 log.Info("new key pair detected") 224 keyPair, err := loadKeyPair(keyPath, certPath, password) 225 if err != nil { 226 log.Error(common.NewError("tls failed to load new key pair").Base(err)) 227 continue 228 } 229 s.keyPairLock.Lock() 230 s.keyPair = []tls.Certificate{*keyPair} 231 s.keyPairLock.Unlock() 232 lastKeyBytes = keyBytes 233 lastCertBytes = certBytes 234 } 235 236 select { 237 case <-ticker.C: 238 continue 239 case <-s.ctx.Done(): 240 log.Debug("exiting") 241 ticker.Stop() 242 return 243 } 244 } 245 } 246 247 func loadKeyPair(keyPath string, certPath string, password string) (*tls.Certificate, error) { 248 if password != "" { 249 keyFile, err := ioutil.ReadFile(keyPath) 250 if err != nil { 251 return nil, common.NewError("failed to load key file").Base(err) 252 } 253 keyBlock, _ := pem.Decode(keyFile) 254 if keyBlock == nil { 255 return nil, common.NewError("failed to decode key file").Base(err) 256 } 257 decryptedKey, err := x509.DecryptPEMBlock(keyBlock, []byte(password)) 258 if err == nil { 259 return nil, common.NewError("failed to decrypt key").Base(err) 260 } 261 262 certFile, err := ioutil.ReadFile(certPath) 263 certBlock, _ := pem.Decode(certFile) 264 if certBlock == nil { 265 return nil, common.NewError("failed to decode cert file").Base(err) 266 } 267 268 keyPair, err := tls.X509KeyPair(certBlock.Bytes, decryptedKey) 269 if err != nil { 270 return nil, err 271 } 272 keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0]) 273 if err != nil { 274 return nil, common.NewError("failed to parse leaf certificate").Base(err) 275 } 276 277 return &keyPair, nil 278 } 279 keyPair, err := tls.LoadX509KeyPair(certPath, keyPath) 280 if err != nil { 281 return nil, common.NewError("failed to load key pair").Base(err) 282 } 283 keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0]) 284 if err != nil { 285 return nil, common.NewError("failed to parse leaf certificate").Base(err) 286 } 287 return &keyPair, nil 288 } 289 290 // NewServer creates a tls layer server 291 func NewServer(ctx context.Context, underlay tunnel.Server) (*Server, error) { 292 cfg := config.FromContext(ctx, Name).(*Config) 293 294 var fallbackAddress *tunnel.Address 295 var httpResp []byte 296 if cfg.TLS.FallbackPort != 0 { 297 if cfg.TLS.FallbackHost == "" { 298 cfg.TLS.FallbackHost = cfg.RemoteHost 299 log.Warn("empty tls fallback address") 300 } 301 fallbackAddress = tunnel.NewAddressFromHostPort("tcp", cfg.TLS.FallbackHost, cfg.TLS.FallbackPort) 302 fallbackConn, err := net.Dial("tcp", fallbackAddress.String()) 303 if err != nil { 304 return nil, common.NewError("invalid fallback address").Base(err) 305 } 306 fallbackConn.Close() 307 } else { 308 log.Warn("empty tls fallback port") 309 if cfg.TLS.HTTPResponseFileName != "" { 310 httpRespBody, err := ioutil.ReadFile(cfg.TLS.HTTPResponseFileName) 311 if err != nil { 312 return nil, common.NewError("invalid response file").Base(err) 313 } 314 httpResp = httpRespBody 315 } else { 316 log.Warn("empty tls http response") 317 } 318 } 319 320 keyPair, err := loadKeyPair(cfg.TLS.KeyPath, cfg.TLS.CertPath, cfg.TLS.KeyPassword) 321 if err != nil { 322 return nil, common.NewError("tls failed to load key pair") 323 } 324 325 var keyLogger io.WriteCloser 326 if cfg.TLS.KeyLogPath != "" { 327 log.Warn("tls key logging activated. USE OF KEY LOGGING COMPROMISES SECURITY. IT SHOULD ONLY BE USED FOR DEBUGGING.") 328 file, err := os.OpenFile(cfg.TLS.KeyLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o600) 329 if err != nil { 330 return nil, common.NewError("failed to open key log file").Base(err) 331 } 332 keyLogger = file 333 } 334 335 var cipherSuite []uint16 336 if len(cfg.TLS.Cipher) != 0 { 337 cipherSuite = fingerprint.ParseCipher(strings.Split(cfg.TLS.Cipher, ":")) 338 } 339 340 ctx, cancel := context.WithCancel(ctx) 341 server := &Server{ 342 underlay: underlay, 343 fallbackAddress: fallbackAddress, 344 httpResp: httpResp, 345 verifySNI: cfg.TLS.VerifyHostName, 346 sni: cfg.TLS.SNI, 347 alpn: cfg.TLS.ALPN, 348 PreferServerCipher: cfg.TLS.PreferServerCipher, 349 sessionTicket: cfg.TLS.ReuseSession, 350 connChan: make(chan tunnel.Conn, 32), 351 wsChan: make(chan tunnel.Conn, 32), 352 redir: redirector.NewRedirector(ctx), 353 keyPair: []tls.Certificate{*keyPair}, 354 keyLogger: keyLogger, 355 cipherSuite: cipherSuite, 356 ctx: ctx, 357 cancel: cancel, 358 } 359 360 go server.acceptLoop() 361 if cfg.TLS.CertCheckRate > 0 { 362 go server.checkKeyPairLoop( 363 time.Second*time.Duration(cfg.TLS.CertCheckRate), 364 cfg.TLS.KeyPath, 365 cfg.TLS.CertPath, 366 cfg.TLS.KeyPassword, 367 ) 368 } 369 370 log.Debug("tls server created") 371 return server, nil 372 }