github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/transport/tuic/server.go (about)

     1  //go:build with_quic
     2  
     3  package tuic
     4  
     5  import (
     6  	"bytes"
     7  	"context"
     8  	"encoding/binary"
     9  	"github.com/inazumav/sing-box/option"
    10  	"github.com/sagernet/quic-go"
    11  	"io"
    12  	"net"
    13  	"runtime"
    14  	"strings"
    15  	"sync"
    16  	"time"
    17  
    18  	"github.com/inazumav/sing-box/common/baderror"
    19  	"github.com/inazumav/sing-box/common/qtls"
    20  	"github.com/inazumav/sing-box/common/tls"
    21  	"github.com/sagernet/sing/common"
    22  	"github.com/sagernet/sing/common/auth"
    23  	"github.com/sagernet/sing/common/buf"
    24  	"github.com/sagernet/sing/common/bufio"
    25  	E "github.com/sagernet/sing/common/exceptions"
    26  	"github.com/sagernet/sing/common/logger"
    27  	M "github.com/sagernet/sing/common/metadata"
    28  	N "github.com/sagernet/sing/common/network"
    29  
    30  	"github.com/gofrs/uuid/v5"
    31  )
    32  
    33  type ServerOptions struct {
    34  	Context           context.Context
    35  	Logger            logger.Logger
    36  	TLSConfig         tls.ServerConfig
    37  	Users             []User
    38  	CongestionControl string
    39  	AuthTimeout       time.Duration
    40  	ZeroRTTHandshake  bool
    41  	Heartbeat         time.Duration
    42  	Handler           ServerHandler
    43  }
    44  
    45  type User struct {
    46  	Name     string
    47  	UUID     uuid.UUID
    48  	Password string
    49  }
    50  
    51  type ServerHandler interface {
    52  	N.TCPConnectionHandler
    53  	N.UDPConnectionHandler
    54  }
    55  
    56  type Server struct {
    57  	ctx               context.Context
    58  	logger            logger.Logger
    59  	tlsConfig         tls.ServerConfig
    60  	heartbeat         time.Duration
    61  	quicConfig        *quic.Config
    62  	userMap           map[uuid.UUID]User
    63  	congestionControl string
    64  	authTimeout       time.Duration
    65  	handler           ServerHandler
    66  
    67  	quicListener io.Closer
    68  }
    69  
    70  func NewServer(options ServerOptions) (*Server, error) {
    71  	if options.AuthTimeout == 0 {
    72  		options.AuthTimeout = 3 * time.Second
    73  	}
    74  	if options.Heartbeat == 0 {
    75  		options.Heartbeat = 10 * time.Second
    76  	}
    77  	quicConfig := &quic.Config{
    78  		DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"),
    79  		MaxDatagramFrameSize:    1400,
    80  		EnableDatagrams:         true,
    81  		Allow0RTT:               options.ZeroRTTHandshake,
    82  		MaxIncomingStreams:      1 << 60,
    83  		MaxIncomingUniStreams:   1 << 60,
    84  	}
    85  	switch options.CongestionControl {
    86  	case "":
    87  		options.CongestionControl = "cubic"
    88  	case "cubic", "new_reno", "bbr":
    89  	default:
    90  		return nil, E.New("unknown congestion control algorithm: ", options.CongestionControl)
    91  	}
    92  	if len(options.Users) == 0 {
    93  		return nil, E.New("missing users")
    94  	}
    95  	userMap := make(map[uuid.UUID]User)
    96  	for _, user := range options.Users {
    97  		userMap[user.UUID] = user
    98  	}
    99  	return &Server{
   100  		ctx:               options.Context,
   101  		logger:            options.Logger,
   102  		tlsConfig:         options.TLSConfig,
   103  		heartbeat:         options.Heartbeat,
   104  		quicConfig:        quicConfig,
   105  		userMap:           userMap,
   106  		congestionControl: options.CongestionControl,
   107  		authTimeout:       options.AuthTimeout,
   108  		handler:           options.Handler,
   109  	}, nil
   110  }
   111  
   112  func (s *Server) Start(conn net.PacketConn) error {
   113  	if !s.quicConfig.Allow0RTT {
   114  		listener, err := qtls.Listen(conn, s.tlsConfig, s.quicConfig)
   115  		if err != nil {
   116  			return err
   117  		}
   118  		s.quicListener = listener
   119  		go func() {
   120  			for {
   121  				connection, hErr := listener.Accept(s.ctx)
   122  				if hErr != nil {
   123  					if strings.Contains(hErr.Error(), "server closed") {
   124  						s.logger.Debug(E.Cause(hErr, "listener closed"))
   125  					} else {
   126  						s.logger.Error(E.Cause(hErr, "listener closed"))
   127  					}
   128  					return
   129  				}
   130  				go s.handleConnection(connection)
   131  			}
   132  		}()
   133  	} else {
   134  		listener, err := qtls.ListenEarly(conn, s.tlsConfig, s.quicConfig)
   135  		if err != nil {
   136  			return err
   137  		}
   138  		s.quicListener = listener
   139  		go func() {
   140  			for {
   141  				connection, hErr := listener.Accept(s.ctx)
   142  				if hErr != nil {
   143  					if strings.Contains(hErr.Error(), "server closed") {
   144  						s.logger.Debug(E.Cause(hErr, "listener closed"))
   145  					} else {
   146  						s.logger.Error(E.Cause(hErr, "listener closed"))
   147  					}
   148  					return
   149  				}
   150  				go s.handleConnection(connection)
   151  			}
   152  		}()
   153  	}
   154  	return nil
   155  }
   156  
   157  func (s *Server) Close() error {
   158  	return common.Close(
   159  		s.quicListener,
   160  	)
   161  }
   162  
   163  func (s *Server) AddUsers(users []option.TUICUser) error {
   164  	for _, u := range users {
   165  		uuid, err := uuid.FromString(u.UUID)
   166  		if err != nil {
   167  			return E.Cause(err, "invalid uuid for user ", u.UUID)
   168  		}
   169  		s.userMap[uuid] = User{
   170  			Name:     u.Name,
   171  			UUID:     uuid,
   172  			Password: u.Password,
   173  		}
   174  	}
   175  	return nil
   176  }
   177  
   178  func (s *Server) DelUsers(uuids []string) error {
   179  	for _, u := range uuids {
   180  		ud, err := uuid.FromString(u)
   181  		if err != nil {
   182  			return E.Cause(err, "invalid uuid for user ", ud)
   183  		}
   184  		delete(s.userMap, ud)
   185  	}
   186  	return nil
   187  }
   188  
   189  func (s *Server) handleConnection(connection quic.Connection) {
   190  	setCongestion(s.ctx, connection, s.congestionControl)
   191  	session := &serverSession{
   192  		Server:     s,
   193  		ctx:        s.ctx,
   194  		quicConn:   connection,
   195  		source:     M.SocksaddrFromNet(connection.RemoteAddr()),
   196  		connDone:   make(chan struct{}),
   197  		authDone:   make(chan struct{}),
   198  		udpConnMap: make(map[uint16]*udpPacketConn),
   199  	}
   200  	session.handle()
   201  }
   202  
   203  type serverSession struct {
   204  	*Server
   205  	ctx        context.Context
   206  	quicConn   quic.Connection
   207  	source     M.Socksaddr
   208  	connAccess sync.Mutex
   209  	connDone   chan struct{}
   210  	connErr    error
   211  	authDone   chan struct{}
   212  	authUser   *User
   213  	udpAccess  sync.RWMutex
   214  	udpConnMap map[uint16]*udpPacketConn
   215  }
   216  
   217  func (s *serverSession) handle() {
   218  	if s.ctx.Done() != nil {
   219  		go func() {
   220  			select {
   221  			case <-s.ctx.Done():
   222  				s.closeWithError(s.ctx.Err())
   223  			case <-s.connDone:
   224  			}
   225  		}()
   226  	}
   227  	go s.loopUniStreams()
   228  	go s.loopStreams()
   229  	go s.loopMessages()
   230  	go s.handleAuthTimeout()
   231  	go s.loopHeartbeats()
   232  }
   233  
   234  func (s *serverSession) loopUniStreams() {
   235  	for {
   236  		uniStream, err := s.quicConn.AcceptUniStream(s.ctx)
   237  		if err != nil {
   238  			return
   239  		}
   240  		go func() {
   241  			err = s.handleUniStream(uniStream)
   242  			if err != nil {
   243  				s.closeWithError(E.Cause(err, "handle uni stream"))
   244  			}
   245  		}()
   246  	}
   247  }
   248  
   249  func (s *serverSession) handleUniStream(stream quic.ReceiveStream) error {
   250  	defer stream.CancelRead(0)
   251  	buffer := buf.New()
   252  	defer buffer.Release()
   253  	_, err := buffer.ReadAtLeastFrom(stream, 2)
   254  	if err != nil {
   255  		return E.Cause(err, "read request")
   256  	}
   257  	version := buffer.Byte(0)
   258  	if version != Version {
   259  		return E.New("unknown version ", buffer.Byte(0))
   260  	}
   261  	command := buffer.Byte(1)
   262  	switch command {
   263  	case CommandAuthenticate:
   264  		select {
   265  		case <-s.authDone:
   266  			return E.New("authentication: multiple authentication requests")
   267  		default:
   268  		}
   269  		if buffer.Len() < AuthenticateLen {
   270  			_, err = buffer.ReadFullFrom(stream, AuthenticateLen-buffer.Len())
   271  			if err != nil {
   272  				return E.Cause(err, "authentication: read request")
   273  			}
   274  		}
   275  		userUUID := uuid.FromBytesOrNil(buffer.Range(2, 2+16))
   276  		user, loaded := s.userMap[userUUID]
   277  		if !loaded {
   278  			return E.New("authentication: unknown user ", userUUID)
   279  		}
   280  		handshakeState := s.quicConn.ConnectionState()
   281  		tuicToken, err := handshakeState.ExportKeyingMaterial(string(user.UUID[:]), []byte(user.Password), 32)
   282  		if err != nil {
   283  			return E.Cause(err, "authentication: export keying material")
   284  		}
   285  		if !bytes.Equal(tuicToken, buffer.Range(2+16, 2+16+32)) {
   286  			return E.New("authentication: token mismatch")
   287  		}
   288  		s.authUser = &user
   289  		close(s.authDone)
   290  		return nil
   291  	case CommandPacket:
   292  		select {
   293  		case <-s.connDone:
   294  			return s.connErr
   295  		case <-s.authDone:
   296  		}
   297  		message := udpMessagePool.Get().(*udpMessage)
   298  		err = readUDPMessage(message, io.MultiReader(bytes.NewReader(buffer.From(2)), stream))
   299  		if err != nil {
   300  			message.release()
   301  			return err
   302  		}
   303  		s.handleUDPMessage(message, true)
   304  		return nil
   305  	case CommandDissociate:
   306  		select {
   307  		case <-s.connDone:
   308  			return s.connErr
   309  		case <-s.authDone:
   310  		}
   311  		if buffer.Len() > 4 {
   312  			return E.New("invalid dissociate message")
   313  		}
   314  		var sessionID uint16
   315  		err = binary.Read(io.MultiReader(bytes.NewReader(buffer.From(2)), stream), binary.BigEndian, &sessionID)
   316  		if err != nil {
   317  			return err
   318  		}
   319  		s.udpAccess.RLock()
   320  		udpConn, loaded := s.udpConnMap[sessionID]
   321  		s.udpAccess.RUnlock()
   322  		if loaded {
   323  			udpConn.closeWithError(E.New("remote closed"))
   324  			s.udpAccess.Lock()
   325  			delete(s.udpConnMap, sessionID)
   326  			s.udpAccess.Unlock()
   327  		}
   328  		return nil
   329  	default:
   330  		return E.New("unknown command ", command)
   331  	}
   332  }
   333  
   334  func (s *serverSession) handleAuthTimeout() {
   335  	select {
   336  	case <-s.connDone:
   337  	case <-s.authDone:
   338  	case <-time.After(s.authTimeout):
   339  		s.closeWithError(E.New("authentication timeout"))
   340  	}
   341  }
   342  
   343  func (s *serverSession) loopStreams() {
   344  	for {
   345  		stream, err := s.quicConn.AcceptStream(s.ctx)
   346  		if err != nil {
   347  			return
   348  		}
   349  		go func() {
   350  			err = s.handleStream(stream)
   351  			if err != nil {
   352  				stream.CancelRead(0)
   353  				stream.Close()
   354  				s.logger.Error(E.Cause(err, "handle stream request"))
   355  			}
   356  		}()
   357  	}
   358  }
   359  
   360  func (s *serverSession) handleStream(stream quic.Stream) error {
   361  	buffer := buf.NewSize(2 + M.MaxSocksaddrLength)
   362  	defer buffer.Release()
   363  	_, err := buffer.ReadAtLeastFrom(stream, 2)
   364  	if err != nil {
   365  		return E.Cause(err, "read request")
   366  	}
   367  	version, _ := buffer.ReadByte()
   368  	if version != Version {
   369  		return E.New("unknown version ", buffer.Byte(0))
   370  	}
   371  	command, _ := buffer.ReadByte()
   372  	if command != CommandConnect {
   373  		return E.New("unsupported stream command ", command)
   374  	}
   375  	destination, err := addressSerializer.ReadAddrPort(io.MultiReader(buffer, stream))
   376  	if err != nil {
   377  		return E.Cause(err, "read request destination")
   378  	}
   379  	select {
   380  	case <-s.connDone:
   381  		return s.connErr
   382  	case <-s.authDone:
   383  	}
   384  	var conn net.Conn = &serverConn{
   385  		Stream:      stream,
   386  		destination: destination,
   387  	}
   388  	if buffer.IsEmpty() {
   389  		buffer.Release()
   390  	} else {
   391  		conn = bufio.NewCachedConn(conn, buffer)
   392  	}
   393  	ctx := s.ctx
   394  	if s.authUser.Name != "" {
   395  		ctx = auth.ContextWithUser(s.ctx, s.authUser.Name)
   396  	}
   397  	_ = s.handler.NewConnection(ctx, conn, M.Metadata{
   398  		Source:      s.source,
   399  		Destination: destination,
   400  	})
   401  	return nil
   402  }
   403  
   404  func (s *serverSession) loopHeartbeats() {
   405  	ticker := time.NewTicker(s.heartbeat)
   406  	defer ticker.Stop()
   407  	for {
   408  		select {
   409  		case <-s.connDone:
   410  			return
   411  		case <-ticker.C:
   412  			err := s.quicConn.SendMessage([]byte{Version, CommandHeartbeat})
   413  			if err != nil {
   414  				s.closeWithError(E.Cause(err, "send heartbeat"))
   415  			}
   416  		}
   417  	}
   418  }
   419  
   420  func (s *serverSession) closeWithError(err error) {
   421  	s.connAccess.Lock()
   422  	defer s.connAccess.Unlock()
   423  	select {
   424  	case <-s.connDone:
   425  		return
   426  	default:
   427  		s.connErr = err
   428  		close(s.connDone)
   429  	}
   430  	if E.IsClosedOrCanceled(err) {
   431  		s.logger.Debug(E.Cause(err, "connection failed"))
   432  	} else {
   433  		s.logger.Error(E.Cause(err, "connection failed"))
   434  	}
   435  	_ = s.quicConn.CloseWithError(0, "")
   436  }
   437  
   438  type serverConn struct {
   439  	quic.Stream
   440  	destination M.Socksaddr
   441  }
   442  
   443  func (c *serverConn) Read(p []byte) (n int, err error) {
   444  	n, err = c.Stream.Read(p)
   445  	return n, baderror.WrapQUIC(err)
   446  }
   447  
   448  func (c *serverConn) Write(p []byte) (n int, err error) {
   449  	n, err = c.Stream.Write(p)
   450  	return n, baderror.WrapQUIC(err)
   451  }
   452  
   453  func (c *serverConn) LocalAddr() net.Addr {
   454  	return c.destination
   455  }
   456  
   457  func (c *serverConn) RemoteAddr() net.Addr {
   458  	return M.Socksaddr{}
   459  }
   460  
   461  func (c *serverConn) Close() error {
   462  	c.Stream.CancelRead(0)
   463  	return c.Stream.Close()
   464  }