github.com/psiphon-labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/quic/gquic-go/server.go (about)

     1  package gquic
     2  
     3  import (
     4  	"crypto/tls"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/crypto"
    13  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/handshake"
    14  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/protocol"
    15  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/utils"
    16  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/wire"
    17  )
    18  
    19  // packetHandler handles packets
    20  type packetHandler interface {
    21  	handlePacket(*receivedPacket)
    22  	io.Closer
    23  	destroy(error)
    24  	GetVersion() protocol.VersionNumber
    25  	GetPerspective() protocol.Perspective
    26  }
    27  
    28  type unknownPacketHandler interface {
    29  	handlePacket(*receivedPacket)
    30  	closeWithError(error) error
    31  }
    32  
    33  type packetHandlerManager interface {
    34  	Add(protocol.ConnectionID, packetHandler)
    35  	SetServer(unknownPacketHandler)
    36  	Remove(protocol.ConnectionID)
    37  	CloseServer()
    38  }
    39  
    40  type quicSession interface {
    41  	Session
    42  	handlePacket(*receivedPacket)
    43  	GetVersion() protocol.VersionNumber
    44  	run() error
    45  	destroy(error)
    46  	closeRemote(error)
    47  }
    48  
    49  type sessionRunner interface {
    50  	onHandshakeComplete(Session)
    51  	removeConnectionID(protocol.ConnectionID)
    52  }
    53  
    54  type runner struct {
    55  	onHandshakeCompleteImpl func(Session)
    56  	removeConnectionIDImpl  func(protocol.ConnectionID)
    57  }
    58  
    59  func (r *runner) onHandshakeComplete(s Session)              { r.onHandshakeCompleteImpl(s) }
    60  func (r *runner) removeConnectionID(c protocol.ConnectionID) { r.removeConnectionIDImpl(c) }
    61  
    62  var _ sessionRunner = &runner{}
    63  
// A Listener of QUIC
//
// server accepts packets handed to it by the session handler and creates new
// sessions from them. Sessions whose handshake has completed are delivered to
// callers through Accept.
type server struct {
	// mutex is held by Close and closeWithError while they check and update
	// closed / serverError and tear the server down.
	mutex sync.Mutex

	tlsConf *tls.Config
	// config is the result of populateServerConfig, i.e. defaults are filled in.
	config *Config

	conn net.PacketConn
	// If the server is started with ListenAddr, we create a packet conn.
	// If it is started with Listen, we take a packet conn as a parameter.
	// Only a conn we created ourselves is closed when the server is closed.
	createdPacketConn bool

	// supportsTLS is set when at least one configured version uses TLS;
	// serverTLS is only initialized in that case (see setupTLS).
	supportsTLS bool
	serverTLS   *serverTLS

	// Certificate chain and gQUIC server config used for new sessions.
	certChain crypto.CertChain
	scfg      *handshake.ServerConfig

	sessionHandler packetHandlerManager

	// serverError is what Accept returns once errorChan has been closed.
	// errorChan is closed exactly once, when the server shuts down; closed
	// records that this has happened.
	serverError error
	errorChan   chan struct{}
	closed      bool

	// sessionQueue holds sessions whose handshake has completed until Accept
	// picks them up.
	sessionQueue chan Session

	// sessionRunner is handed to every new session (see setup).
	sessionRunner sessionRunner
	// set as a member, so they can be set in the tests
	newSession func(connection, sessionRunner, protocol.VersionNumber, protocol.ConnectionID, protocol.ConnectionID, *handshake.ServerConfig, *tls.Config, *Config, utils.Logger) (quicSession, error)

	logger utils.Logger
}

// Compile-time checks that *server implements both interfaces.
var _ Listener = &server{}
var _ unknownPacketHandler = &server{}
    99  
   100  // ListenAddr creates a QUIC server listening on a given address.
   101  // The tls.Config must not be nil, the quic.Config may be nil.
   102  func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) {
   103  	udpAddr, err := net.ResolveUDPAddr("udp", addr)
   104  	if err != nil {
   105  		return nil, err
   106  	}
   107  	conn, err := net.ListenUDP("udp", udpAddr)
   108  	if err != nil {
   109  		return nil, err
   110  	}
   111  	serv, err := listen(conn, tlsConf, config)
   112  	if err != nil {
   113  		return nil, err
   114  	}
   115  	serv.createdPacketConn = true
   116  	return serv, nil
   117  }
   118  
   119  // Listen listens for QUIC connections on a given net.PacketConn.
   120  // The tls.Config must not be nil, the quic.Config may be nil.
   121  func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) {
   122  	return listen(conn, tlsConf, config)
   123  }
   124  
   125  func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*server, error) {
   126  	if tlsConf == nil || (len(tlsConf.Certificates) == 0 && tlsConf.GetCertificate == nil) {
   127  		return nil, errors.New("quic: neither Certificates nor GetCertificate set in tls.Config")
   128  	}
   129  	certChain := crypto.NewCertChain(tlsConf)
   130  	kex, err := crypto.NewCurve25519KEX()
   131  	if err != nil {
   132  		return nil, err
   133  	}
   134  	scfg, err := handshake.NewServerConfig(kex, certChain)
   135  	if err != nil {
   136  		return nil, err
   137  	}
   138  	config = populateServerConfig(config)
   139  
   140  	var supportsTLS bool
   141  	for _, v := range config.Versions {
   142  		if !protocol.IsValidVersion(v) {
   143  			return nil, fmt.Errorf("%s is not a valid QUIC version", v)
   144  		}
   145  		// check if any of the supported versions supports TLS
   146  		if v.UsesTLS() {
   147  			supportsTLS = true
   148  			break
   149  		}
   150  	}
   151  
   152  	sessionHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength)
   153  	if err != nil {
   154  		return nil, err
   155  	}
   156  	s := &server{
   157  		conn:           conn,
   158  		tlsConf:        tlsConf,
   159  		config:         config,
   160  		certChain:      certChain,
   161  		scfg:           scfg,
   162  		newSession:     newSession,
   163  		sessionHandler: sessionHandler,
   164  		sessionQueue:   make(chan Session, 5),
   165  		errorChan:      make(chan struct{}),
   166  		supportsTLS:    supportsTLS,
   167  		logger:         utils.DefaultLogger.WithPrefix("server"),
   168  	}
   169  	s.setup()
   170  	if supportsTLS {
   171  		if err := s.setupTLS(); err != nil {
   172  			return nil, err
   173  		}
   174  	}
   175  	sessionHandler.SetServer(s)
   176  	s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
   177  	return s, nil
   178  }
   179  
   180  func (s *server) setup() {
   181  	s.sessionRunner = &runner{
   182  		onHandshakeCompleteImpl: func(sess Session) { s.sessionQueue <- sess },
   183  		removeConnectionIDImpl:  s.sessionHandler.Remove,
   184  	}
   185  }
   186  
   187  func (s *server) setupTLS() error {
   188  	serverTLS, sessionChan, err := newServerTLS(s.conn, s.config, s.sessionRunner, s.tlsConf, s.logger)
   189  	if err != nil {
   190  		return err
   191  	}
   192  	s.serverTLS = serverTLS
   193  	// handle TLS connection establishment statelessly
   194  	go func() {
   195  		for {
   196  			select {
   197  			case <-s.errorChan:
   198  				return
   199  			case tlsSession := <-sessionChan:
   200  				// The connection ID is a randomly chosen value.
   201  				// It is safe to assume that it doesn't collide with other randomly chosen values.
   202  				serverSession := newServerSession(tlsSession.sess, s.config, s.logger)
   203  				s.sessionHandler.Add(tlsSession.connID, serverSession)
   204  			}
   205  		}
   206  	}()
   207  	return nil
   208  }
   209  
   210  var defaultAcceptCookie = func(clientAddr net.Addr, cookie *Cookie) bool {
   211  	if cookie == nil {
   212  		return false
   213  	}
   214  	if time.Now().After(cookie.SentTime.Add(protocol.CookieExpiryTime)) {
   215  		return false
   216  	}
   217  	var sourceAddr string
   218  	if udpAddr, ok := clientAddr.(*net.UDPAddr); ok {
   219  		sourceAddr = udpAddr.IP.String()
   220  	} else {
   221  		sourceAddr = clientAddr.String()
   222  	}
   223  	return sourceAddr == cookie.RemoteAddr
   224  }
   225  
   226  // populateServerConfig populates fields in the quic.Config with their default values, if none are set
   227  // it may be called with nil
   228  func populateServerConfig(config *Config) *Config {
   229  	if config == nil {
   230  		config = &Config{}
   231  	}
   232  	versions := config.Versions
   233  	if len(versions) == 0 {
   234  		versions = protocol.SupportedVersions
   235  	}
   236  
   237  	vsa := defaultAcceptCookie
   238  	if config.AcceptCookie != nil {
   239  		vsa = config.AcceptCookie
   240  	}
   241  
   242  	handshakeTimeout := protocol.DefaultHandshakeTimeout
   243  	if config.HandshakeTimeout != 0 {
   244  		handshakeTimeout = config.HandshakeTimeout
   245  	}
   246  	idleTimeout := protocol.DefaultIdleTimeout
   247  	if config.IdleTimeout != 0 {
   248  		idleTimeout = config.IdleTimeout
   249  	}
   250  
   251  	maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
   252  	if maxReceiveStreamFlowControlWindow == 0 {
   253  		maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindowServer
   254  	}
   255  	maxReceiveConnectionFlowControlWindow := config.MaxReceiveConnectionFlowControlWindow
   256  	if maxReceiveConnectionFlowControlWindow == 0 {
   257  		maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowServer
   258  	}
   259  	maxIncomingStreams := config.MaxIncomingStreams
   260  	if maxIncomingStreams == 0 {
   261  		maxIncomingStreams = protocol.DefaultMaxIncomingStreams
   262  	} else if maxIncomingStreams < 0 {
   263  		maxIncomingStreams = 0
   264  	}
   265  	maxIncomingUniStreams := config.MaxIncomingUniStreams
   266  	if maxIncomingUniStreams == 0 {
   267  		maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams
   268  	} else if maxIncomingUniStreams < 0 {
   269  		maxIncomingUniStreams = 0
   270  	}
   271  	connIDLen := config.ConnectionIDLength
   272  	if connIDLen == 0 {
   273  		connIDLen = protocol.DefaultConnectionIDLength
   274  	}
   275  	for _, v := range versions {
   276  		if v == protocol.Version44 {
   277  			connIDLen = protocol.ConnectionIDLenGQUIC
   278  		}
   279  	}
   280  
   281  	return &Config{
   282  		Versions:                              versions,
   283  		HandshakeTimeout:                      handshakeTimeout,
   284  		IdleTimeout:                           idleTimeout,
   285  		AcceptCookie:                          vsa,
   286  		KeepAlive:                             config.KeepAlive,
   287  		MaxReceiveStreamFlowControlWindow:     maxReceiveStreamFlowControlWindow,
   288  		MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
   289  		MaxIncomingStreams:                    maxIncomingStreams,
   290  		MaxIncomingUniStreams:                 maxIncomingUniStreams,
   291  		ConnectionIDLength:                    connIDLen,
   292  	}
   293  }
   294  
   295  // Accept returns newly openend sessions
   296  func (s *server) Accept() (Session, error) {
   297  	var sess Session
   298  	select {
   299  	case sess = <-s.sessionQueue:
   300  		return sess, nil
   301  	case <-s.errorChan:
   302  		return nil, s.serverError
   303  	}
   304  }
   305  
   306  // Close the server
   307  func (s *server) Close() error {
   308  	s.mutex.Lock()
   309  	defer s.mutex.Unlock()
   310  	if s.closed {
   311  		return nil
   312  	}
   313  	return s.closeWithMutex()
   314  }
   315  
   316  func (s *server) closeWithMutex() error {
   317  	s.sessionHandler.CloseServer()
   318  	if s.serverError == nil {
   319  		s.serverError = errors.New("server closed")
   320  	}
   321  	var err error
   322  	// If the server was started with ListenAddr, we created the packet conn.
   323  	// We need to close it in order to make the go routine reading from that conn return.
   324  	if s.createdPacketConn {
   325  		err = s.conn.Close()
   326  	}
   327  	s.closed = true
   328  	close(s.errorChan)
   329  	return err
   330  }
   331  
   332  func (s *server) closeWithError(e error) error {
   333  	s.mutex.Lock()
   334  	defer s.mutex.Unlock()
   335  	if s.closed {
   336  		return nil
   337  	}
   338  	s.serverError = e
   339  	return s.closeWithMutex()
   340  }
   341  
   342  // Addr returns the server's network address
   343  func (s *server) Addr() net.Addr {
   344  	return s.conn.LocalAddr()
   345  }
   346  
   347  func (s *server) handlePacket(p *receivedPacket) {
   348  	if err := s.handlePacketImpl(p); err != nil {
   349  		s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err)
   350  	}
   351  }
   352  
   353  func (s *server) handlePacketImpl(p *receivedPacket) error {
   354  	hdr := p.header
   355  
   356  	if hdr.VersionFlag || hdr.IsLongHeader {
   357  		// send a Version Negotiation Packet if the client is speaking a different protocol version
   358  		if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
   359  			return s.sendVersionNegotiationPacket(p)
   360  		}
   361  	}
   362  	if hdr.Type == protocol.PacketTypeInitial && hdr.Version.UsesTLS() {
   363  		go s.serverTLS.HandleInitial(p)
   364  		return nil
   365  	}
   366  
   367  	// TODO(#943): send Stateless Reset, if this an IETF QUIC packet
   368  	if !hdr.VersionFlag && !hdr.Version.UsesIETFHeaderFormat() {
   369  		_, err := s.conn.WriteTo(wire.WritePublicReset(hdr.DestConnectionID, 0, 0), p.remoteAddr)
   370  		return err
   371  	}
   372  
   373  	// This is (potentially) a Client Hello.
   374  	// Make sure it has the minimum required size before spending any more ressources on it.
   375  	if len(p.data) < protocol.MinClientHelloSize {
   376  		return errors.New("dropping small packet for unknown connection")
   377  	}
   378  
   379  	var destConnID, srcConnID protocol.ConnectionID
   380  	if hdr.Version.UsesIETFHeaderFormat() {
   381  		srcConnID = hdr.DestConnectionID
   382  	} else {
   383  		destConnID = hdr.DestConnectionID
   384  		srcConnID = hdr.DestConnectionID
   385  	}
   386  	s.logger.Infof("Serving new connection: %s, version %s from %v", hdr.DestConnectionID, hdr.Version, p.remoteAddr)
   387  	sess, err := s.newSession(
   388  		&conn{pconn: s.conn, currentAddr: p.remoteAddr},
   389  		s.sessionRunner,
   390  		hdr.Version,
   391  		destConnID,
   392  		srcConnID,
   393  		s.scfg,
   394  		s.tlsConf,
   395  		s.config,
   396  		s.logger,
   397  	)
   398  	if err != nil {
   399  		return err
   400  	}
   401  	s.sessionHandler.Add(hdr.DestConnectionID, newServerSession(sess, s.config, s.logger))
   402  	go sess.run()
   403  	sess.handlePacket(p)
   404  	return nil
   405  }
   406  
   407  func (s *server) sendVersionNegotiationPacket(p *receivedPacket) error {
   408  	hdr := p.header
   409  	s.logger.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version)
   410  
   411  	var data []byte
   412  	if hdr.IsPublicHeader {
   413  		data = wire.ComposeGQUICVersionNegotiation(hdr.DestConnectionID, s.config.Versions)
   414  	} else {
   415  		var err error
   416  		data, err = wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions)
   417  		if err != nil {
   418  			return err
   419  		}
   420  	}
   421  	_, err := s.conn.WriteTo(data, p.remoteAddr)
   422  	return err
   423  }