github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/inbound/hysteria.go (about)

     1  //go:build with_quic
     2  
     3  package inbound
     4  
     5  import (
     6  	"context"
     7  	"github.com/sagernet/quic-go"
     8  	"sync"
     9  
    10  	"github.com/inazumav/sing-box/adapter"
    11  	"github.com/inazumav/sing-box/common/qtls"
    12  	"github.com/inazumav/sing-box/common/tls"
    13  	C "github.com/inazumav/sing-box/constant"
    14  	"github.com/inazumav/sing-box/log"
    15  	"github.com/inazumav/sing-box/option"
    16  	"github.com/inazumav/sing-box/transport/hysteria"
    17  	"github.com/sagernet/quic-go/congestion"
    18  	"github.com/sagernet/sing/common"
    19  	"github.com/sagernet/sing/common/auth"
    20  	E "github.com/sagernet/sing/common/exceptions"
    21  	M "github.com/sagernet/sing/common/metadata"
    22  	N "github.com/sagernet/sing/common/network"
    23  )
    24  
// Compile-time check that *Hysteria satisfies adapter.Inbound.
var _ adapter.Inbound = (*Hysteria)(nil)

// Hysteria is the QUIC-based Hysteria inbound. It listens on UDP, performs the
// Hysteria client/server hello exchange on the first stream of every QUIC
// connection, and then routes each further stream as either a TCP-like
// connection or a UDP session.
type Hysteria struct {
	myInboundAdapter
	// quicConfig is the QUIC configuration handed to the listener in Start.
	quicConfig   *quic.Config
	// tlsConfig is the server TLS configuration; it is started in Start and
	// closed in Close.
	tlsConfig    tls.ServerConfig
	// users is consulted by accept to authenticate clients; when it is empty
	// accept does not check credentials at all.
	users        map[string]string // map[password]name
	// xplusKey is the obfuscation key; when non-empty Start wraps the UDP
	// socket with the XPlus packet conn.
	xplusKey     []byte
	// sendBPS and recvBPS cap the rates negotiated with each client in accept.
	sendBPS      uint64
	recvBPS      uint64
	// listener is set by Start and closed by Close.
	listener     qtls.QUICListener
	// udpAccess guards udpSessionId and udpSessions.
	udpAccess    sync.RWMutex
	// udpSessionId is the id that will be assigned to the next UDP session.
	udpSessionId uint32
	// udpSessions maps a UDP session id to the channel that delivers its
	// reassembled inbound messages.
	udpSessions  map[uint32]chan *hysteria.UDPMessage
	// udpDefragger reassembles fragmented UDP messages received as datagrams.
	udpDefragger hysteria.Defragger
}
    41  
    42  func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaInboundOptions) (*Hysteria, error) {
    43  	options.UDPFragmentDefault = true
    44  	quicConfig := &quic.Config{
    45  		InitialStreamReceiveWindow:     options.ReceiveWindowConn,
    46  		MaxStreamReceiveWindow:         options.ReceiveWindowConn,
    47  		InitialConnectionReceiveWindow: options.ReceiveWindowClient,
    48  		MaxConnectionReceiveWindow:     options.ReceiveWindowClient,
    49  		MaxIncomingStreams:             int64(options.MaxConnClient),
    50  		KeepAlivePeriod:                hysteria.KeepAlivePeriod,
    51  		DisablePathMTUDiscovery:        options.DisableMTUDiscovery || !(C.IsLinux || C.IsWindows),
    52  		EnableDatagrams:                true,
    53  	}
    54  	if options.ReceiveWindowConn == 0 {
    55  		quicConfig.InitialStreamReceiveWindow = hysteria.DefaultStreamReceiveWindow
    56  		quicConfig.MaxStreamReceiveWindow = hysteria.DefaultStreamReceiveWindow
    57  	}
    58  	if options.ReceiveWindowClient == 0 {
    59  		quicConfig.InitialConnectionReceiveWindow = hysteria.DefaultConnectionReceiveWindow
    60  		quicConfig.MaxConnectionReceiveWindow = hysteria.DefaultConnectionReceiveWindow
    61  	}
    62  	if quicConfig.MaxIncomingStreams == 0 {
    63  		quicConfig.MaxIncomingStreams = hysteria.DefaultMaxIncomingStreams
    64  	}
    65  	users := make(map[string]string, len(options.Users))
    66  	for _, u := range options.Users {
    67  		if len(u.Auth) > 0 {
    68  			users[string(u.Auth)] = u.Name
    69  		} else {
    70  			users[u.AuthString] = u.Name
    71  		}
    72  	}
    73  	var xplus []byte
    74  	if options.Obfs != "" {
    75  		xplus = []byte(options.Obfs)
    76  	}
    77  	var up, down uint64
    78  	if len(options.Up) > 0 {
    79  		up = hysteria.StringToBps(options.Up)
    80  		if up == 0 {
    81  			return nil, E.New("invalid up speed format: ", options.Up)
    82  		}
    83  	} else {
    84  		up = uint64(options.UpMbps) * hysteria.MbpsToBps
    85  	}
    86  	if len(options.Down) > 0 {
    87  		down = hysteria.StringToBps(options.Down)
    88  		if down == 0 {
    89  			return nil, E.New("invalid down speed format: ", options.Down)
    90  		}
    91  	} else {
    92  		down = uint64(options.DownMbps) * hysteria.MbpsToBps
    93  	}
    94  	if up < hysteria.MinSpeedBPS {
    95  		return nil, E.New("invalid up speed")
    96  	}
    97  	if down < hysteria.MinSpeedBPS {
    98  		return nil, E.New("invalid down speed")
    99  	}
   100  	inbound := &Hysteria{
   101  		myInboundAdapter: myInboundAdapter{
   102  			protocol:      C.TypeHysteria,
   103  			network:       []string{N.NetworkUDP},
   104  			ctx:           ctx,
   105  			router:        router,
   106  			logger:        logger,
   107  			tag:           tag,
   108  			listenOptions: options.ListenOptions,
   109  		},
   110  		quicConfig:  quicConfig,
   111  		users:       users,
   112  		xplusKey:    xplus,
   113  		sendBPS:     up,
   114  		recvBPS:     down,
   115  		udpSessions: make(map[uint32]chan *hysteria.UDPMessage),
   116  	}
   117  	if options.TLS == nil || !options.TLS.Enabled {
   118  		return nil, C.ErrTLSRequired
   119  	}
   120  	if len(options.TLS.ALPN) == 0 {
   121  		options.TLS.ALPN = []string{hysteria.DefaultALPN}
   122  	}
   123  	tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
   124  	if err != nil {
   125  		return nil, err
   126  	}
   127  	inbound.tlsConfig = tlsConfig
   128  	return inbound, nil
   129  }
   130  
   131  func (h *Hysteria) Start() error {
   132  	packetConn, err := h.myInboundAdapter.ListenUDP()
   133  	if err != nil {
   134  		return err
   135  	}
   136  	if len(h.xplusKey) > 0 {
   137  		packetConn = hysteria.NewXPlusPacketConn(packetConn, h.xplusKey)
   138  		packetConn = &hysteria.PacketConnWrapper{PacketConn: packetConn}
   139  	}
   140  	err = h.tlsConfig.Start()
   141  	if err != nil {
   142  		return err
   143  	}
   144  	listener, err := qtls.Listen(packetConn, h.tlsConfig, h.quicConfig)
   145  	if err != nil {
   146  		return err
   147  	}
   148  	h.listener = listener
   149  	h.logger.Info("udp server started at ", listener.Addr())
   150  	go h.acceptLoop()
   151  	return nil
   152  }
   153  
   154  func (h *Hysteria) acceptLoop() {
   155  	for {
   156  		ctx := log.ContextWithNewID(h.ctx)
   157  		conn, err := h.listener.Accept(ctx)
   158  		if err != nil {
   159  			return
   160  		}
   161  		go func() {
   162  			hErr := h.accept(ctx, conn)
   163  			if hErr != nil {
   164  				conn.CloseWithError(0, "")
   165  				NewError(h.logger, ctx, E.Cause(hErr, "process connection from ", conn.RemoteAddr()))
   166  			}
   167  		}()
   168  	}
   169  }
   170  
   171  func (h *Hysteria) accept(ctx context.Context, conn quic.Connection) error {
   172  	controlStream, err := conn.AcceptStream(ctx)
   173  	if err != nil {
   174  		return err
   175  	}
   176  	clientHello, err := hysteria.ReadClientHello(controlStream)
   177  	if err != nil {
   178  		return err
   179  	}
   180  	if len(h.users) > 0 {
   181  		var user string
   182  		if u, ok := h.users[string(clientHello.Auth)]; ok {
   183  			user = u
   184  			if user == "" {
   185  				user = string(clientHello.Auth)
   186  			}
   187  		} else {
   188  			err = hysteria.WriteServerHello(controlStream, hysteria.ServerHello{
   189  				Message: "wrong password",
   190  			})
   191  			return E.Errors(E.New("wrong password: ", string(clientHello.Auth)), err)
   192  		}
   193  		ctx = auth.ContextWithUser(ctx, user)
   194  		h.logger.InfoContext(ctx, "[", user, "] inbound connection from ", conn.RemoteAddr())
   195  	} else {
   196  		h.logger.InfoContext(ctx, "inbound connection from ", conn.RemoteAddr())
   197  	}
   198  	h.logger.DebugContext(ctx, "peer send speed: ", clientHello.SendBPS/1024/1024, " MBps, peer recv speed: ", clientHello.RecvBPS/1024/1024, " MBps")
   199  	if clientHello.SendBPS == 0 || clientHello.RecvBPS == 0 {
   200  		return E.New("invalid rate from client")
   201  	}
   202  	serverSendBPS, serverRecvBPS := clientHello.RecvBPS, clientHello.SendBPS
   203  	if h.sendBPS > 0 && serverSendBPS > h.sendBPS {
   204  		serverSendBPS = h.sendBPS
   205  	}
   206  	if h.recvBPS > 0 && serverRecvBPS > h.recvBPS {
   207  		serverRecvBPS = h.recvBPS
   208  	}
   209  	err = hysteria.WriteServerHello(controlStream, hysteria.ServerHello{
   210  		OK:      true,
   211  		SendBPS: serverSendBPS,
   212  		RecvBPS: serverRecvBPS,
   213  	})
   214  	if err != nil {
   215  		return err
   216  	}
   217  	conn.SetCongestionControl(hysteria.NewBrutalSender(congestion.ByteCount(serverSendBPS)))
   218  	go h.udpRecvLoop(conn)
   219  	for {
   220  		var stream quic.Stream
   221  		stream, err = conn.AcceptStream(ctx)
   222  		if err != nil {
   223  			return err
   224  		}
   225  		go func() {
   226  			hErr := h.acceptStream(ctx, conn /*&hysteria.StreamWrapper{Stream: stream}*/, stream)
   227  			if hErr != nil {
   228  				stream.Close()
   229  				NewError(h.logger, ctx, E.Cause(hErr, "process stream from ", conn.RemoteAddr()))
   230  			}
   231  		}()
   232  	}
   233  }
   234  
   235  func (h *Hysteria) udpRecvLoop(conn quic.Connection) {
   236  	for {
   237  		packet, err := conn.ReceiveMessage(h.ctx)
   238  		if err != nil {
   239  			return
   240  		}
   241  		message, err := hysteria.ParseUDPMessage(packet)
   242  		if err != nil {
   243  			h.logger.Error("parse udp message: ", err)
   244  			continue
   245  		}
   246  		dfMsg := h.udpDefragger.Feed(message)
   247  		if dfMsg == nil {
   248  			continue
   249  		}
   250  		h.udpAccess.RLock()
   251  		ch, ok := h.udpSessions[dfMsg.SessionID]
   252  		if ok {
   253  			select {
   254  			case ch <- dfMsg:
   255  				// OK
   256  			default:
   257  				// Silently drop the message when the channel is full
   258  			}
   259  		}
   260  		h.udpAccess.RUnlock()
   261  	}
   262  }
   263  
   264  func (h *Hysteria) acceptStream(ctx context.Context, conn quic.Connection, stream quic.Stream) error {
   265  	request, err := hysteria.ReadClientRequest(stream)
   266  	if err != nil {
   267  		return err
   268  	}
   269  	var metadata adapter.InboundContext
   270  	metadata.Inbound = h.tag
   271  	metadata.InboundType = C.TypeHysteria
   272  	metadata.InboundOptions = h.listenOptions.InboundOptions
   273  	metadata.Source = M.SocksaddrFromNet(conn.RemoteAddr()).Unwrap()
   274  	metadata.OriginDestination = M.SocksaddrFromNet(conn.LocalAddr()).Unwrap()
   275  	metadata.Destination = M.ParseSocksaddrHostPort(request.Host, request.Port).Unwrap()
   276  	metadata.User, _ = auth.UserFromContext[string](ctx)
   277  
   278  	if !request.UDP {
   279  		err = hysteria.WriteServerResponse(stream, hysteria.ServerResponse{
   280  			OK: true,
   281  		})
   282  		if err != nil {
   283  			return err
   284  		}
   285  		h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
   286  		return h.router.RouteConnection(ctx, hysteria.NewConn(stream, metadata.Destination, false), metadata)
   287  	} else {
   288  		h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)
   289  		var id uint32
   290  		h.udpAccess.Lock()
   291  		id = h.udpSessionId
   292  		nCh := make(chan *hysteria.UDPMessage, 1024)
   293  		h.udpSessions[id] = nCh
   294  		h.udpSessionId += 1
   295  		h.udpAccess.Unlock()
   296  		err = hysteria.WriteServerResponse(stream, hysteria.ServerResponse{
   297  			OK:           true,
   298  			UDPSessionID: id,
   299  		})
   300  		if err != nil {
   301  			return err
   302  		}
   303  		packetConn := hysteria.NewPacketConn(conn, stream, id, metadata.Destination, nCh, common.Closer(func() error {
   304  			h.udpAccess.Lock()
   305  			if ch, ok := h.udpSessions[id]; ok {
   306  				close(ch)
   307  				delete(h.udpSessions, id)
   308  			}
   309  			h.udpAccess.Unlock()
   310  			return nil
   311  		}))
   312  		go packetConn.Hold()
   313  		return h.router.RoutePacketConnection(ctx, packetConn, metadata)
   314  	}
   315  }
   316  
   317  func (h *Hysteria) AddUsers(users []option.HysteriaUser) error {
   318  	for _, u := range users {
   319  		if len(u.Auth) > 0 {
   320  			h.users[string(u.Auth)] = u.Name
   321  		} else {
   322  			h.users[u.AuthString] = u.Name
   323  		}
   324  	}
   325  	return nil
   326  }
   327  
   328  func (h *Hysteria) DelUsers(names []string) error {
   329  	for _, n := range names {
   330  		delete(h.users, n)
   331  	}
   332  	return nil
   333  }
   334  
   335  func (h *Hysteria) Close() error {
   336  	h.udpAccess.Lock()
   337  	for _, session := range h.udpSessions {
   338  		close(session)
   339  	}
   340  	h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage)
   341  	h.udpAccess.Unlock()
   342  	return common.Close(
   343  		&h.myInboundAdapter,
   344  		h.listener,
   345  		h.tlsConfig,
   346  	)
   347  }