github.com/psiphon-labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/quic/gquic-go/server.go (about) 1 package gquic 2 3 import ( 4 "crypto/tls" 5 "errors" 6 "fmt" 7 "io" 8 "net" 9 "sync" 10 "time" 11 12 "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/crypto" 13 "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/handshake" 14 "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/protocol" 15 "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/utils" 16 "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/wire" 17 ) 18 19 // packetHandler handles packets 20 type packetHandler interface { 21 handlePacket(*receivedPacket) 22 io.Closer 23 destroy(error) 24 GetVersion() protocol.VersionNumber 25 GetPerspective() protocol.Perspective 26 } 27 28 type unknownPacketHandler interface { 29 handlePacket(*receivedPacket) 30 closeWithError(error) error 31 } 32 33 type packetHandlerManager interface { 34 Add(protocol.ConnectionID, packetHandler) 35 SetServer(unknownPacketHandler) 36 Remove(protocol.ConnectionID) 37 CloseServer() 38 } 39 40 type quicSession interface { 41 Session 42 handlePacket(*receivedPacket) 43 GetVersion() protocol.VersionNumber 44 run() error 45 destroy(error) 46 closeRemote(error) 47 } 48 49 type sessionRunner interface { 50 onHandshakeComplete(Session) 51 removeConnectionID(protocol.ConnectionID) 52 } 53 54 type runner struct { 55 onHandshakeCompleteImpl func(Session) 56 removeConnectionIDImpl func(protocol.ConnectionID) 57 } 58 59 func (r *runner) onHandshakeComplete(s Session) { r.onHandshakeCompleteImpl(s) } 60 func (r *runner) removeConnectionID(c protocol.ConnectionID) { r.removeConnectionIDImpl(c) } 61 62 var _ sessionRunner = &runner{} 63 64 // A Listener of QUIC 65 type server struct { 66 mutex sync.Mutex 67 68 tlsConf *tls.Config 69 config *Config 70 71 conn net.PacketConn 72 // If the server is started with ListenAddr, we create a packet conn. 73 // If it is started with Listen, we take a packet conn as a parameter. 74 createdPacketConn bool 75 76 supportsTLS bool 77 serverTLS *serverTLS 78 79 certChain crypto.CertChain 80 scfg *handshake.ServerConfig 81 82 sessionHandler packetHandlerManager 83 84 serverError error 85 errorChan chan struct{} 86 closed bool 87 88 sessionQueue chan Session 89 90 sessionRunner sessionRunner 91 // set as a member, so they can be set in the tests 92 newSession func(connection, sessionRunner, protocol.VersionNumber, protocol.ConnectionID, protocol.ConnectionID, *handshake.ServerConfig, *tls.Config, *Config, utils.Logger) (quicSession, error) 93 94 logger utils.Logger 95 } 96 97 var _ Listener = &server{} 98 var _ unknownPacketHandler = &server{} 99 100 // ListenAddr creates a QUIC server listening on a given address. 101 // The tls.Config must not be nil, the quic.Config may be nil. 102 func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) { 103 udpAddr, err := net.ResolveUDPAddr("udp", addr) 104 if err != nil { 105 return nil, err 106 } 107 conn, err := net.ListenUDP("udp", udpAddr) 108 if err != nil { 109 return nil, err 110 } 111 serv, err := listen(conn, tlsConf, config) 112 if err != nil { 113 return nil, err 114 } 115 serv.createdPacketConn = true 116 return serv, nil 117 } 118 119 // Listen listens for QUIC connections on a given net.PacketConn. 120 // The tls.Config must not be nil, the quic.Config may be nil. 121 func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) { 122 return listen(conn, tlsConf, config) 123 } 124 125 func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*server, error) { 126 if tlsConf == nil || (len(tlsConf.Certificates) == 0 && tlsConf.GetCertificate == nil) { 127 return nil, errors.New("quic: neither Certificates nor GetCertificate set in tls.Config") 128 } 129 certChain := crypto.NewCertChain(tlsConf) 130 kex, err := crypto.NewCurve25519KEX() 131 if err != nil { 132 return nil, err 133 } 134 scfg, err := handshake.NewServerConfig(kex, certChain) 135 if err != nil { 136 return nil, err 137 } 138 config = populateServerConfig(config) 139 140 var supportsTLS bool 141 for _, v := range config.Versions { 142 if !protocol.IsValidVersion(v) { 143 return nil, fmt.Errorf("%s is not a valid QUIC version", v) 144 } 145 // check if any of the supported versions supports TLS 146 if v.UsesTLS() { 147 supportsTLS = true 148 break 149 } 150 } 151 152 sessionHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength) 153 if err != nil { 154 return nil, err 155 } 156 s := &server{ 157 conn: conn, 158 tlsConf: tlsConf, 159 config: config, 160 certChain: certChain, 161 scfg: scfg, 162 newSession: newSession, 163 sessionHandler: sessionHandler, 164 sessionQueue: make(chan Session, 5), 165 errorChan: make(chan struct{}), 166 supportsTLS: supportsTLS, 167 logger: utils.DefaultLogger.WithPrefix("server"), 168 } 169 s.setup() 170 if supportsTLS { 171 if err := s.setupTLS(); err != nil { 172 return nil, err 173 } 174 } 175 sessionHandler.SetServer(s) 176 s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) 177 return s, nil 178 } 179 180 func (s *server) setup() { 181 s.sessionRunner = &runner{ 182 onHandshakeCompleteImpl: func(sess Session) { s.sessionQueue <- sess }, 183 removeConnectionIDImpl: s.sessionHandler.Remove, 184 } 185 } 186 187 func (s *server) setupTLS() error { 188 serverTLS, sessionChan, err := newServerTLS(s.conn, s.config, s.sessionRunner, s.tlsConf, s.logger) 189 if err != nil { 190 return err 191 } 192 s.serverTLS = serverTLS 193 // handle TLS connection establishment statelessly 194 go func() { 195 for { 196 select { 197 case <-s.errorChan: 198 return 199 case tlsSession := <-sessionChan: 200 // The connection ID is a randomly chosen value. 201 // It is safe to assume that it doesn't collide with other randomly chosen values. 202 serverSession := newServerSession(tlsSession.sess, s.config, s.logger) 203 s.sessionHandler.Add(tlsSession.connID, serverSession) 204 } 205 } 206 }() 207 return nil 208 } 209 210 var defaultAcceptCookie = func(clientAddr net.Addr, cookie *Cookie) bool { 211 if cookie == nil { 212 return false 213 } 214 if time.Now().After(cookie.SentTime.Add(protocol.CookieExpiryTime)) { 215 return false 216 } 217 var sourceAddr string 218 if udpAddr, ok := clientAddr.(*net.UDPAddr); ok { 219 sourceAddr = udpAddr.IP.String() 220 } else { 221 sourceAddr = clientAddr.String() 222 } 223 return sourceAddr == cookie.RemoteAddr 224 } 225 226 // populateServerConfig populates fields in the quic.Config with their default values, if none are set 227 // it may be called with nil 228 func populateServerConfig(config *Config) *Config { 229 if config == nil { 230 config = &Config{} 231 } 232 versions := config.Versions 233 if len(versions) == 0 { 234 versions = protocol.SupportedVersions 235 } 236 237 vsa := defaultAcceptCookie 238 if config.AcceptCookie != nil { 239 vsa = config.AcceptCookie 240 } 241 242 handshakeTimeout := protocol.DefaultHandshakeTimeout 243 if config.HandshakeTimeout != 0 { 244 handshakeTimeout = config.HandshakeTimeout 245 } 246 idleTimeout := protocol.DefaultIdleTimeout 247 if config.IdleTimeout != 0 { 248 idleTimeout = config.IdleTimeout 249 } 250 251 maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow 252 if maxReceiveStreamFlowControlWindow == 0 { 253 maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindowServer 254 } 255 maxReceiveConnectionFlowControlWindow := config.MaxReceiveConnectionFlowControlWindow 256 if maxReceiveConnectionFlowControlWindow == 0 { 257 maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowServer 258 } 259 maxIncomingStreams := config.MaxIncomingStreams 260 if maxIncomingStreams == 0 { 261 maxIncomingStreams = protocol.DefaultMaxIncomingStreams 262 } else if maxIncomingStreams < 0 { 263 maxIncomingStreams = 0 264 } 265 maxIncomingUniStreams := config.MaxIncomingUniStreams 266 if maxIncomingUniStreams == 0 { 267 maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams 268 } else if maxIncomingUniStreams < 0 { 269 maxIncomingUniStreams = 0 270 } 271 connIDLen := config.ConnectionIDLength 272 if connIDLen == 0 { 273 connIDLen = protocol.DefaultConnectionIDLength 274 } 275 for _, v := range versions { 276 if v == protocol.Version44 { 277 connIDLen = protocol.ConnectionIDLenGQUIC 278 } 279 } 280 281 return &Config{ 282 Versions: versions, 283 HandshakeTimeout: handshakeTimeout, 284 IdleTimeout: idleTimeout, 285 AcceptCookie: vsa, 286 KeepAlive: config.KeepAlive, 287 MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow, 288 MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow, 289 MaxIncomingStreams: maxIncomingStreams, 290 MaxIncomingUniStreams: maxIncomingUniStreams, 291 ConnectionIDLength: connIDLen, 292 } 293 } 294 295 // Accept returns newly openend sessions 296 func (s *server) Accept() (Session, error) { 297 var sess Session 298 select { 299 case sess = <-s.sessionQueue: 300 return sess, nil 301 case <-s.errorChan: 302 return nil, s.serverError 303 } 304 } 305 306 // Close the server 307 func (s *server) Close() error { 308 s.mutex.Lock() 309 defer s.mutex.Unlock() 310 if s.closed { 311 return nil 312 } 313 return s.closeWithMutex() 314 } 315 316 func (s *server) closeWithMutex() error { 317 s.sessionHandler.CloseServer() 318 if s.serverError == nil { 319 s.serverError = errors.New("server closed") 320 } 321 var err error 322 // If the server was started with ListenAddr, we created the packet conn. 323 // We need to close it in order to make the go routine reading from that conn return. 324 if s.createdPacketConn { 325 err = s.conn.Close() 326 } 327 s.closed = true 328 close(s.errorChan) 329 return err 330 } 331 332 func (s *server) closeWithError(e error) error { 333 s.mutex.Lock() 334 defer s.mutex.Unlock() 335 if s.closed { 336 return nil 337 } 338 s.serverError = e 339 return s.closeWithMutex() 340 } 341 342 // Addr returns the server's network address 343 func (s *server) Addr() net.Addr { 344 return s.conn.LocalAddr() 345 } 346 347 func (s *server) handlePacket(p *receivedPacket) { 348 if err := s.handlePacketImpl(p); err != nil { 349 s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err) 350 } 351 } 352 353 func (s *server) handlePacketImpl(p *receivedPacket) error { 354 hdr := p.header 355 356 if hdr.VersionFlag || hdr.IsLongHeader { 357 // send a Version Negotiation Packet if the client is speaking a different protocol version 358 if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { 359 return s.sendVersionNegotiationPacket(p) 360 } 361 } 362 if hdr.Type == protocol.PacketTypeInitial && hdr.Version.UsesTLS() { 363 go s.serverTLS.HandleInitial(p) 364 return nil 365 } 366 367 // TODO(#943): send Stateless Reset, if this an IETF QUIC packet 368 if !hdr.VersionFlag && !hdr.Version.UsesIETFHeaderFormat() { 369 _, err := s.conn.WriteTo(wire.WritePublicReset(hdr.DestConnectionID, 0, 0), p.remoteAddr) 370 return err 371 } 372 373 // This is (potentially) a Client Hello. 374 // Make sure it has the minimum required size before spending any more ressources on it. 375 if len(p.data) < protocol.MinClientHelloSize { 376 return errors.New("dropping small packet for unknown connection") 377 } 378 379 var destConnID, srcConnID protocol.ConnectionID 380 if hdr.Version.UsesIETFHeaderFormat() { 381 srcConnID = hdr.DestConnectionID 382 } else { 383 destConnID = hdr.DestConnectionID 384 srcConnID = hdr.DestConnectionID 385 } 386 s.logger.Infof("Serving new connection: %s, version %s from %v", hdr.DestConnectionID, hdr.Version, p.remoteAddr) 387 sess, err := s.newSession( 388 &conn{pconn: s.conn, currentAddr: p.remoteAddr}, 389 s.sessionRunner, 390 hdr.Version, 391 destConnID, 392 srcConnID, 393 s.scfg, 394 s.tlsConf, 395 s.config, 396 s.logger, 397 ) 398 if err != nil { 399 return err 400 } 401 s.sessionHandler.Add(hdr.DestConnectionID, newServerSession(sess, s.config, s.logger)) 402 go sess.run() 403 sess.handlePacket(p) 404 return nil 405 } 406 407 func (s *server) sendVersionNegotiationPacket(p *receivedPacket) error { 408 hdr := p.header 409 s.logger.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version) 410 411 var data []byte 412 if hdr.IsPublicHeader { 413 data = wire.ComposeGQUICVersionNegotiation(hdr.DestConnectionID, s.config.Versions) 414 } else { 415 var err error 416 data, err = wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions) 417 if err != nil { 418 return err 419 } 420 } 421 _, err := s.conn.WriteTo(data, p.remoteAddr) 422 return err 423 }