github.com/sagernet/sing-box@v1.2.7/inbound/hysteria.go (about)

     1  //go:build with_quic
     2  
     3  package inbound
     4  
     5  import (
     6  	"context"
     7  	"sync"
     8  
     9  	"github.com/sagernet/quic-go"
    10  	"github.com/sagernet/quic-go/congestion"
    11  	"github.com/sagernet/sing-box/adapter"
    12  	"github.com/sagernet/sing-box/common/tls"
    13  	C "github.com/sagernet/sing-box/constant"
    14  	"github.com/sagernet/sing-box/log"
    15  	"github.com/sagernet/sing-box/option"
    16  	"github.com/sagernet/sing-box/transport/hysteria"
    17  	"github.com/sagernet/sing/common"
    18  	"github.com/sagernet/sing/common/auth"
    19  	E "github.com/sagernet/sing/common/exceptions"
    20  	F "github.com/sagernet/sing/common/format"
    21  	M "github.com/sagernet/sing/common/metadata"
    22  	N "github.com/sagernet/sing/common/network"
    23  
    24  	"golang.org/x/exp/slices"
    25  )
    26  
    27  var _ adapter.Inbound = (*Hysteria)(nil)
    28  
    29  type Hysteria struct {
    30  	myInboundAdapter
    31  	quicConfig   *quic.Config
    32  	tlsConfig    tls.ServerConfig
    33  	authKey      []string
    34  	authUser     []string
    35  	xplusKey     []byte
    36  	sendBPS      uint64
    37  	recvBPS      uint64
    38  	listener     quic.Listener
    39  	udpAccess    sync.RWMutex
    40  	udpSessionId uint32
    41  	udpSessions  map[uint32]chan *hysteria.UDPMessage
    42  	udpDefragger hysteria.Defragger
    43  }
    44  
    45  func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaInboundOptions) (*Hysteria, error) {
    46  	options.UDPFragmentDefault = true
    47  	quicConfig := &quic.Config{
    48  		InitialStreamReceiveWindow:     options.ReceiveWindowConn,
    49  		MaxStreamReceiveWindow:         options.ReceiveWindowConn,
    50  		InitialConnectionReceiveWindow: options.ReceiveWindowClient,
    51  		MaxConnectionReceiveWindow:     options.ReceiveWindowClient,
    52  		MaxIncomingStreams:             int64(options.MaxConnClient),
    53  		KeepAlivePeriod:                hysteria.KeepAlivePeriod,
    54  		DisablePathMTUDiscovery:        options.DisableMTUDiscovery || !(C.IsLinux || C.IsWindows),
    55  		EnableDatagrams:                true,
    56  	}
    57  	if options.ReceiveWindowConn == 0 {
    58  		quicConfig.InitialStreamReceiveWindow = hysteria.DefaultStreamReceiveWindow
    59  		quicConfig.MaxStreamReceiveWindow = hysteria.DefaultStreamReceiveWindow
    60  	}
    61  	if options.ReceiveWindowClient == 0 {
    62  		quicConfig.InitialConnectionReceiveWindow = hysteria.DefaultConnectionReceiveWindow
    63  		quicConfig.MaxConnectionReceiveWindow = hysteria.DefaultConnectionReceiveWindow
    64  	}
    65  	if quicConfig.MaxIncomingStreams == 0 {
    66  		quicConfig.MaxIncomingStreams = hysteria.DefaultMaxIncomingStreams
    67  	}
    68  	authKey := common.Map(options.Users, func(it option.HysteriaUser) string {
    69  		if len(it.Auth) > 0 {
    70  			return string(it.Auth)
    71  		} else {
    72  			return it.AuthString
    73  		}
    74  	})
    75  	authUser := common.Map(options.Users, func(it option.HysteriaUser) string {
    76  		return it.Name
    77  	})
    78  	var xplus []byte
    79  	if options.Obfs != "" {
    80  		xplus = []byte(options.Obfs)
    81  	}
    82  	var up, down uint64
    83  	if len(options.Up) > 0 {
    84  		up = hysteria.StringToBps(options.Up)
    85  		if up == 0 {
    86  			return nil, E.New("invalid up speed format: ", options.Up)
    87  		}
    88  	} else {
    89  		up = uint64(options.UpMbps) * hysteria.MbpsToBps
    90  	}
    91  	if len(options.Down) > 0 {
    92  		down = hysteria.StringToBps(options.Down)
    93  		if down == 0 {
    94  			return nil, E.New("invalid down speed format: ", options.Down)
    95  		}
    96  	} else {
    97  		down = uint64(options.DownMbps) * hysteria.MbpsToBps
    98  	}
    99  	if up < hysteria.MinSpeedBPS {
   100  		return nil, E.New("invalid up speed")
   101  	}
   102  	if down < hysteria.MinSpeedBPS {
   103  		return nil, E.New("invalid down speed")
   104  	}
   105  	inbound := &Hysteria{
   106  		myInboundAdapter: myInboundAdapter{
   107  			protocol:      C.TypeHysteria,
   108  			network:       []string{N.NetworkUDP},
   109  			ctx:           ctx,
   110  			router:        router,
   111  			logger:        logger,
   112  			tag:           tag,
   113  			listenOptions: options.ListenOptions,
   114  		},
   115  		quicConfig:  quicConfig,
   116  		authKey:     authKey,
   117  		authUser:    authUser,
   118  		xplusKey:    xplus,
   119  		sendBPS:     up,
   120  		recvBPS:     down,
   121  		udpSessions: make(map[uint32]chan *hysteria.UDPMessage),
   122  	}
   123  	if options.TLS == nil || !options.TLS.Enabled {
   124  		return nil, C.ErrTLSRequired
   125  	}
   126  	if len(options.TLS.ALPN) == 0 {
   127  		options.TLS.ALPN = []string{hysteria.DefaultALPN}
   128  	}
   129  	tlsConfig, err := tls.NewServer(ctx, router, logger, common.PtrValueOrDefault(options.TLS))
   130  	if err != nil {
   131  		return nil, err
   132  	}
   133  	inbound.tlsConfig = tlsConfig
   134  	return inbound, nil
   135  }
   136  
   137  func (h *Hysteria) Start() error {
   138  	packetConn, err := h.myInboundAdapter.ListenUDP()
   139  	if err != nil {
   140  		return err
   141  	}
   142  	if len(h.xplusKey) > 0 {
   143  		packetConn = hysteria.NewXPlusPacketConn(packetConn, h.xplusKey)
   144  		packetConn = &hysteria.PacketConnWrapper{PacketConn: packetConn}
   145  	}
   146  	err = h.tlsConfig.Start()
   147  	if err != nil {
   148  		return err
   149  	}
   150  	rawConfig, err := h.tlsConfig.Config()
   151  	if err != nil {
   152  		return err
   153  	}
   154  	listener, err := quic.Listen(packetConn, rawConfig, h.quicConfig)
   155  	if err != nil {
   156  		return err
   157  	}
   158  	h.listener = listener
   159  	h.logger.Info("udp server started at ", listener.Addr())
   160  	go h.acceptLoop()
   161  	return nil
   162  }
   163  
   164  func (h *Hysteria) acceptLoop() {
   165  	for {
   166  		ctx := log.ContextWithNewID(h.ctx)
   167  		conn, err := h.listener.Accept(ctx)
   168  		if err != nil {
   169  			return
   170  		}
   171  		go func() {
   172  			hErr := h.accept(ctx, conn)
   173  			if hErr != nil {
   174  				conn.CloseWithError(0, "")
   175  				NewError(h.logger, ctx, E.Cause(hErr, "process connection from ", conn.RemoteAddr()))
   176  			}
   177  		}()
   178  	}
   179  }
   180  
   181  func (h *Hysteria) accept(ctx context.Context, conn quic.Connection) error {
   182  	controlStream, err := conn.AcceptStream(ctx)
   183  	if err != nil {
   184  		return err
   185  	}
   186  	clientHello, err := hysteria.ReadClientHello(controlStream)
   187  	if err != nil {
   188  		return err
   189  	}
   190  	if len(h.authKey) > 0 {
   191  		userIndex := slices.Index(h.authKey, string(clientHello.Auth))
   192  		if userIndex == -1 {
   193  			err = hysteria.WriteServerHello(controlStream, hysteria.ServerHello{
   194  				Message: "wrong password",
   195  			})
   196  			return E.Errors(E.New("wrong password: ", string(clientHello.Auth)), err)
   197  		}
   198  		user := h.authUser[userIndex]
   199  		if user == "" {
   200  			user = F.ToString(userIndex)
   201  		} else {
   202  			ctx = auth.ContextWithUser(ctx, user)
   203  		}
   204  		h.logger.InfoContext(ctx, "[", user, "] inbound connection from ", conn.RemoteAddr())
   205  	} else {
   206  		h.logger.InfoContext(ctx, "inbound connection from ", conn.RemoteAddr())
   207  	}
   208  	h.logger.DebugContext(ctx, "peer send speed: ", clientHello.SendBPS/1024/1024, " MBps, peer recv speed: ", clientHello.RecvBPS/1024/1024, " MBps")
   209  	if clientHello.SendBPS == 0 || clientHello.RecvBPS == 0 {
   210  		return E.New("invalid rate from client")
   211  	}
   212  	serverSendBPS, serverRecvBPS := clientHello.RecvBPS, clientHello.SendBPS
   213  	if h.sendBPS > 0 && serverSendBPS > h.sendBPS {
   214  		serverSendBPS = h.sendBPS
   215  	}
   216  	if h.recvBPS > 0 && serverRecvBPS > h.recvBPS {
   217  		serverRecvBPS = h.recvBPS
   218  	}
   219  	err = hysteria.WriteServerHello(controlStream, hysteria.ServerHello{
   220  		OK:      true,
   221  		SendBPS: serverSendBPS,
   222  		RecvBPS: serverRecvBPS,
   223  	})
   224  	if err != nil {
   225  		return err
   226  	}
   227  	conn.SetCongestionControl(hysteria.NewBrutalSender(congestion.ByteCount(serverSendBPS)))
   228  	go h.udpRecvLoop(conn)
   229  	for {
   230  		var stream quic.Stream
   231  		stream, err = conn.AcceptStream(ctx)
   232  		if err != nil {
   233  			return err
   234  		}
   235  		go func() {
   236  			hErr := h.acceptStream(ctx, conn /*&hysteria.StreamWrapper{Stream: stream}*/, stream)
   237  			if hErr != nil {
   238  				stream.Close()
   239  				NewError(h.logger, ctx, E.Cause(hErr, "process stream from ", conn.RemoteAddr()))
   240  			}
   241  		}()
   242  	}
   243  }
   244  
   245  func (h *Hysteria) udpRecvLoop(conn quic.Connection) {
   246  	for {
   247  		packet, err := conn.ReceiveMessage()
   248  		if err != nil {
   249  			return
   250  		}
   251  		message, err := hysteria.ParseUDPMessage(packet)
   252  		if err != nil {
   253  			h.logger.Error("parse udp message: ", err)
   254  			continue
   255  		}
   256  		dfMsg := h.udpDefragger.Feed(message)
   257  		if dfMsg == nil {
   258  			continue
   259  		}
   260  		h.udpAccess.RLock()
   261  		ch, ok := h.udpSessions[dfMsg.SessionID]
   262  		if ok {
   263  			select {
   264  			case ch <- dfMsg:
   265  				// OK
   266  			default:
   267  				// Silently drop the message when the channel is full
   268  			}
   269  		}
   270  		h.udpAccess.RUnlock()
   271  	}
   272  }
   273  
   274  func (h *Hysteria) acceptStream(ctx context.Context, conn quic.Connection, stream quic.Stream) error {
   275  	request, err := hysteria.ReadClientRequest(stream)
   276  	if err != nil {
   277  		return err
   278  	}
   279  	var metadata adapter.InboundContext
   280  	metadata.Inbound = h.tag
   281  	metadata.InboundType = C.TypeHysteria
   282  	metadata.InboundOptions = h.listenOptions.InboundOptions
   283  	metadata.Source = M.SocksaddrFromNet(conn.RemoteAddr()).Unwrap()
   284  	metadata.OriginDestination = M.SocksaddrFromNet(conn.LocalAddr()).Unwrap()
   285  	metadata.Destination = M.ParseSocksaddrHostPort(request.Host, request.Port).Unwrap()
   286  
   287  	if !request.UDP {
   288  		err = hysteria.WriteServerResponse(stream, hysteria.ServerResponse{
   289  			OK: true,
   290  		})
   291  		if err != nil {
   292  			return err
   293  		}
   294  		h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
   295  		return h.router.RouteConnection(ctx, hysteria.NewConn(stream, metadata.Destination, false), metadata)
   296  	} else {
   297  		h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)
   298  		var id uint32
   299  		h.udpAccess.Lock()
   300  		id = h.udpSessionId
   301  		nCh := make(chan *hysteria.UDPMessage, 1024)
   302  		h.udpSessions[id] = nCh
   303  		h.udpSessionId += 1
   304  		h.udpAccess.Unlock()
   305  		err = hysteria.WriteServerResponse(stream, hysteria.ServerResponse{
   306  			OK:           true,
   307  			UDPSessionID: id,
   308  		})
   309  		if err != nil {
   310  			return err
   311  		}
   312  		packetConn := hysteria.NewPacketConn(conn, stream, id, metadata.Destination, nCh, common.Closer(func() error {
   313  			h.udpAccess.Lock()
   314  			if ch, ok := h.udpSessions[id]; ok {
   315  				close(ch)
   316  				delete(h.udpSessions, id)
   317  			}
   318  			h.udpAccess.Unlock()
   319  			return nil
   320  		}))
   321  		go packetConn.Hold()
   322  		return h.router.RoutePacketConnection(ctx, packetConn, metadata)
   323  	}
   324  }
   325  
   326  func (h *Hysteria) Close() error {
   327  	h.udpAccess.Lock()
   328  	for _, session := range h.udpSessions {
   329  		close(session)
   330  	}
   331  	h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage)
   332  	h.udpAccess.Unlock()
   333  	return common.Close(
   334  		&h.myInboundAdapter,
   335  		h.listener,
   336  		h.tlsConfig,
   337  	)
   338  }