github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/transport/hysteria2/server.go (about)

     1  package hysteria2
     2  
     3  import (
     4  	"context"
     5  	"github.com/sagernet/quic-go"
     6  	"io"
     7  	"net"
     8  	"net/http"
     9  	"os"
    10  	"runtime"
    11  	"strings"
    12  	"sync"
    13  
    14  	"github.com/inazumav/sing-box/common/baderror"
    15  	"github.com/inazumav/sing-box/common/qtls"
    16  	"github.com/inazumav/sing-box/common/tls"
    17  	"github.com/inazumav/sing-box/transport/hysteria2/congestion"
    18  	"github.com/inazumav/sing-box/transport/hysteria2/internal/protocol"
    19  	tuicCongestion "github.com/inazumav/sing-box/transport/tuic/congestion"
    20  	"github.com/sagernet/quic-go/http3"
    21  	"github.com/sagernet/sing/common"
    22  	"github.com/sagernet/sing/common/auth"
    23  	E "github.com/sagernet/sing/common/exceptions"
    24  	"github.com/sagernet/sing/common/logger"
    25  	M "github.com/sagernet/sing/common/metadata"
    26  	N "github.com/sagernet/sing/common/network"
    27  )
    28  
// ServerOptions carries the settings consumed by NewServer.
type ServerOptions struct {
	// Context is stored on the Server; it is passed to the listener's Accept
	// call and watched for cancellation by every authenticated session.
	Context context.Context
	// Logger receives listener, stream and connection diagnostics.
	Logger logger.Logger
	// SendBPS, when greater than zero, caps the send rate requested by a
	// client during authentication.
	// NOTE(review): unit is presumably bytes per second — confirm.
	SendBPS uint64
	// ReceiveBPS is reported to clients as Rx in the authentication response.
	ReceiveBPS uint64
	// IgnoreClientBandwidth makes the server disregard the client's requested
	// rate and use the BBR sender; it is reported to clients as RxAuto.
	IgnoreClientBandwidth bool
	// SalamanderPassword, when non-empty, causes Start to wrap the packet
	// connection with NewSalamanderConn.
	SalamanderPassword string
	// TLSConfig is handed to qtls when configuring HTTP/3 and listening.
	TLSConfig tls.ServerConfig
	// Users lists the accepted accounts; NewServer rejects an empty list.
	Users []User
	// UDPDisabled turns off QUIC datagrams and the per-session message loop.
	UDPDisabled bool
	// Handler receives the proxied connections.
	Handler ServerHandler
	// MasqueradeHandler serves every HTTP request that is not a successful
	// authentication. NewServer substitutes http.NotFoundHandler when nil.
	MasqueradeHandler http.Handler
}
    42  
// User is a single account accepted by the server.
type User struct {
	// Name, when non-empty, is attached to the connection context of streams
	// opened by this user.
	Name string
	// Password is the value matched against the client's authentication
	// request; it is also the key under which NewServer indexes the user.
	Password string
}
    47  
// ServerHandler is the callback set required by the server: it combines
// N.TCPConnectionHandler and N.UDPConnectionHandler.
type ServerHandler interface {
	N.TCPConnectionHandler
	N.UDPConnectionHandler
}
    52  
// Server accepts QUIC connections, authenticates them over HTTP/3 and
// dispatches hijacked streams to the configured handler. Build one with
// NewServer and run it with Start.
type Server struct {
	ctx                   context.Context
	logger                logger.Logger
	sendBPS               uint64 // optional cap on the client-requested send rate
	receiveBPS            uint64 // advertised to clients as Rx
	ignoreClientBandwidth bool   // use BBR instead of the client-requested rate
	salamanderPassword    string // non-empty enables the Salamander packet wrapper
	tlsConfig             tls.ServerConfig
	quicConfig            *quic.Config
	userMap               map[string]User // users keyed by password
	udpDisabled           bool
	handler               ServerHandler
	masqueradeHandler     http.Handler // serves all non-authentication traffic
	quicListener          io.Closer    // set by Start, released by Close
}
    68  
    69  func NewServer(options ServerOptions) (*Server, error) {
    70  	quicConfig := &quic.Config{
    71  		DisablePathMTUDiscovery:        !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"),
    72  		MaxDatagramFrameSize:           1400,
    73  		EnableDatagrams:                !options.UDPDisabled,
    74  		MaxIncomingStreams:             1 << 60,
    75  		InitialStreamReceiveWindow:     defaultStreamReceiveWindow,
    76  		MaxStreamReceiveWindow:         defaultStreamReceiveWindow,
    77  		InitialConnectionReceiveWindow: defaultConnReceiveWindow,
    78  		MaxConnectionReceiveWindow:     defaultConnReceiveWindow,
    79  		MaxIdleTimeout:                 defaultMaxIdleTimeout,
    80  		KeepAlivePeriod:                defaultKeepAlivePeriod,
    81  	}
    82  	if len(options.Users) == 0 {
    83  		return nil, E.New("missing users")
    84  	}
    85  	userMap := make(map[string]User)
    86  	for _, user := range options.Users {
    87  		userMap[user.Password] = user
    88  	}
    89  	if options.MasqueradeHandler == nil {
    90  		options.MasqueradeHandler = http.NotFoundHandler()
    91  	}
    92  	return &Server{
    93  		ctx:                   options.Context,
    94  		logger:                options.Logger,
    95  		sendBPS:               options.SendBPS,
    96  		receiveBPS:            options.ReceiveBPS,
    97  		ignoreClientBandwidth: options.IgnoreClientBandwidth,
    98  		salamanderPassword:    options.SalamanderPassword,
    99  		tlsConfig:             options.TLSConfig,
   100  		quicConfig:            quicConfig,
   101  		userMap:               userMap,
   102  		udpDisabled:           options.UDPDisabled,
   103  		handler:               options.Handler,
   104  		masqueradeHandler:     options.MasqueradeHandler,
   105  	}, nil
   106  }
   107  
   108  func (s *Server) Start(conn net.PacketConn) error {
   109  	if s.salamanderPassword != "" {
   110  		conn = NewSalamanderConn(conn, []byte(s.salamanderPassword))
   111  	}
   112  	err := qtls.ConfigureHTTP3(s.tlsConfig)
   113  	if err != nil {
   114  		return err
   115  	}
   116  	listener, err := qtls.Listen(conn, s.tlsConfig, s.quicConfig)
   117  	if err != nil {
   118  		return err
   119  	}
   120  	s.quicListener = listener
   121  	go s.loopConnections(listener)
   122  	return nil
   123  }
   124  
   125  func (s *Server) Close() error {
   126  	return common.Close(
   127  		s.quicListener,
   128  	)
   129  }
   130  
   131  func (s *Server) loopConnections(listener qtls.QUICListener) {
   132  	for {
   133  		connection, err := listener.Accept(s.ctx)
   134  		if err != nil {
   135  			if strings.Contains(err.Error(), "server closed") {
   136  				s.logger.Debug(E.Cause(err, "listener closed"))
   137  			} else {
   138  				s.logger.Error(E.Cause(err, "listener closed"))
   139  			}
   140  			return
   141  		}
   142  		go s.handleConnection(connection)
   143  	}
   144  }
   145  
   146  func (s *Server) handleConnection(connection quic.Connection) {
   147  	session := &serverSession{
   148  		Server:     s,
   149  		ctx:        s.ctx,
   150  		quicConn:   connection,
   151  		source:     M.SocksaddrFromNet(connection.RemoteAddr()),
   152  		connDone:   make(chan struct{}),
   153  		udpConnMap: make(map[uint32]*udpPacketConn),
   154  	}
   155  	httpServer := http3.Server{
   156  		Handler:        session,
   157  		StreamHijacker: session.handleStream0,
   158  	}
   159  	_ = httpServer.ServeQUICConn(connection)
   160  	_ = connection.CloseWithError(0, "")
   161  }
   162  
// serverSession holds the state of one accepted QUIC connection. It embeds
// the owning Server and acts as both the HTTP/3 request handler and the
// stream hijacker for that connection.
type serverSession struct {
	*Server
	ctx           context.Context
	quicConn      quic.Connection
	source        M.Socksaddr   // remote address of the QUIC connection
	connAccess    sync.Mutex    // held by closeWithError while closing
	connDone      chan struct{} // closed once by closeWithError
	connErr       error         // error recorded by closeWithError
	authenticated bool          // set after a successful authentication request
	authUser      *User         // user matched during authentication
	// NOTE(review): udpAccess and udpConnMap are not used in this file;
	// presumably they belong to the UDP message loop — confirm.
	udpAccess  sync.RWMutex
	udpConnMap map[uint32]*udpPacketConn
}
   176  
   177  func (s *serverSession) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   178  	if r.Method == http.MethodPost && r.Host == protocol.URLHost && r.URL.Path == protocol.URLPath {
   179  		if s.authenticated {
   180  			protocol.AuthResponseToHeader(w.Header(), protocol.AuthResponse{
   181  				UDPEnabled: !s.udpDisabled,
   182  				Rx:         s.receiveBPS,
   183  				RxAuto:     s.ignoreClientBandwidth,
   184  			})
   185  			w.WriteHeader(protocol.StatusAuthOK)
   186  			return
   187  		}
   188  		request := protocol.AuthRequestFromHeader(r.Header)
   189  		user, loaded := s.userMap[request.Auth]
   190  		if !loaded {
   191  			s.masqueradeHandler.ServeHTTP(w, r)
   192  			return
   193  		}
   194  		s.authUser = &user
   195  		s.authenticated = true
   196  		if !s.ignoreClientBandwidth && request.Rx > 0 {
   197  			var sendBps uint64
   198  			if s.sendBPS > 0 && s.sendBPS < request.Rx {
   199  				sendBps = s.sendBPS
   200  			} else {
   201  				sendBps = request.Rx
   202  			}
   203  			s.quicConn.SetCongestionControl(congestion.NewBrutalSender(sendBps))
   204  		} else {
   205  			s.quicConn.SetCongestionControl(tuicCongestion.NewBBRSender(
   206  				tuicCongestion.DefaultClock{},
   207  				tuicCongestion.GetInitialPacketSize(s.quicConn.RemoteAddr()),
   208  				tuicCongestion.InitialCongestionWindow*tuicCongestion.InitialMaxDatagramSize,
   209  				tuicCongestion.DefaultBBRMaxCongestionWindow*tuicCongestion.InitialMaxDatagramSize,
   210  			))
   211  		}
   212  		protocol.AuthResponseToHeader(w.Header(), protocol.AuthResponse{
   213  			UDPEnabled: !s.udpDisabled,
   214  			Rx:         s.receiveBPS,
   215  			RxAuto:     s.ignoreClientBandwidth,
   216  		})
   217  		w.WriteHeader(protocol.StatusAuthOK)
   218  		if s.ctx.Done() != nil {
   219  			go func() {
   220  				select {
   221  				case <-s.ctx.Done():
   222  					s.closeWithError(s.ctx.Err())
   223  				case <-s.connDone:
   224  				}
   225  			}()
   226  		}
   227  		if !s.udpDisabled {
   228  			go s.loopMessages()
   229  		}
   230  	} else {
   231  		s.masqueradeHandler.ServeHTTP(w, r)
   232  	}
   233  }
   234  
   235  func (s *serverSession) handleStream0(frameType http3.FrameType, connection quic.Connection, stream quic.Stream, err error) (bool, error) {
   236  	if !s.authenticated || err != nil {
   237  		return false, nil
   238  	}
   239  	if frameType != protocol.FrameTypeTCPRequest {
   240  		return false, nil
   241  	}
   242  	go func() {
   243  		hErr := s.handleStream(stream)
   244  		if hErr != nil {
   245  			stream.CancelRead(0)
   246  			stream.Close()
   247  			s.logger.Error(E.Cause(hErr, "handle stream request"))
   248  		}
   249  	}()
   250  	return true, nil
   251  }
   252  
   253  func (s *serverSession) handleStream(stream quic.Stream) error {
   254  	destinationString, err := protocol.ReadTCPRequest(stream)
   255  	if err != nil {
   256  		return E.New("read TCP request")
   257  	}
   258  	var conn net.Conn = &serverConn{
   259  		Stream: stream,
   260  	}
   261  	ctx := s.ctx
   262  	if s.authUser.Name != "" {
   263  		ctx = auth.ContextWithUser(s.ctx, s.authUser.Name)
   264  	}
   265  	_ = s.handler.NewConnection(ctx, conn, M.Metadata{
   266  		Source:      s.source,
   267  		Destination: M.ParseSocksaddr(destinationString),
   268  	})
   269  	return nil
   270  }
   271  
   272  func (s *serverSession) closeWithError(err error) {
   273  	s.connAccess.Lock()
   274  	defer s.connAccess.Unlock()
   275  	select {
   276  	case <-s.connDone:
   277  		return
   278  	default:
   279  		s.connErr = err
   280  		close(s.connDone)
   281  	}
   282  	if E.IsClosedOrCanceled(err) {
   283  		s.logger.Debug(E.Cause(err, "connection failed"))
   284  	} else {
   285  		s.logger.Error(E.Cause(err, "connection failed"))
   286  	}
   287  	_ = s.quicConn.CloseWithError(0, "")
   288  }
   289  
// serverConn wraps a QUIC stream for use as a connection by the handler.
type serverConn struct {
	quic.Stream
	// responseWritten records that the TCP response header has already been
	// sent, either by the first Write or by HandshakeFailure.
	responseWritten bool
}
   294  
   295  func (c *serverConn) HandshakeFailure(err error) error {
   296  	if c.responseWritten {
   297  		return os.ErrClosed
   298  	}
   299  	c.responseWritten = true
   300  	buffer := protocol.WriteTCPResponse(false, err.Error(), nil)
   301  	defer buffer.Release()
   302  	return common.Error(c.Stream.Write(buffer.Bytes()))
   303  }
   304  
   305  func (c *serverConn) Read(p []byte) (n int, err error) {
   306  	n, err = c.Stream.Read(p)
   307  	return n, baderror.WrapQUIC(err)
   308  }
   309  
   310  func (c *serverConn) Write(p []byte) (n int, err error) {
   311  	if !c.responseWritten {
   312  		c.responseWritten = true
   313  		buffer := protocol.WriteTCPResponse(true, "", p)
   314  		defer buffer.Release()
   315  		_, err = c.Stream.Write(buffer.Bytes())
   316  		if err != nil {
   317  			return 0, baderror.WrapQUIC(err)
   318  		}
   319  		return len(p), nil
   320  	}
   321  	n, err = c.Stream.Write(p)
   322  	return n, baderror.WrapQUIC(err)
   323  }
   324  
   325  func (c *serverConn) LocalAddr() net.Addr {
   326  	return M.Socksaddr{}
   327  }
   328  
   329  func (c *serverConn) RemoteAddr() net.Addr {
   330  	return M.Socksaddr{}
   331  }
   332  
   333  func (c *serverConn) Close() error {
   334  	c.Stream.CancelRead(0)
   335  	return c.Stream.Close()
   336  }