github.com/sagernet/sing-box@v1.2.7/outbound/hysteria.go (about)

     1  //go:build with_quic
     2  
     3  package outbound
     4  
     5  import (
     6  	"context"
     7  	"net"
     8  	"sync"
     9  
    10  	"github.com/sagernet/quic-go"
    11  	"github.com/sagernet/quic-go/congestion"
    12  	"github.com/sagernet/sing-box/adapter"
    13  	"github.com/sagernet/sing-box/common/dialer"
    14  	"github.com/sagernet/sing-box/common/tls"
    15  	C "github.com/sagernet/sing-box/constant"
    16  	"github.com/sagernet/sing-box/log"
    17  	"github.com/sagernet/sing-box/option"
    18  	"github.com/sagernet/sing-box/transport/hysteria"
    19  	"github.com/sagernet/sing/common"
    20  	"github.com/sagernet/sing/common/bufio"
    21  	E "github.com/sagernet/sing/common/exceptions"
    22  	M "github.com/sagernet/sing/common/metadata"
    23  	N "github.com/sagernet/sing/common/network"
    24  )
    25  
    26  var (
    27  	_ adapter.Outbound                = (*Hysteria)(nil)
    28  	_ adapter.InterfaceUpdateListener = (*Hysteria)(nil)
    29  )
    30  
    31  type Hysteria struct {
    32  	myOutboundAdapter
    33  	ctx          context.Context
    34  	dialer       N.Dialer
    35  	serverAddr   M.Socksaddr
    36  	tlsConfig    *tls.STDConfig
    37  	quicConfig   *quic.Config
    38  	authKey      []byte
    39  	xplusKey     []byte
    40  	sendBPS      uint64
    41  	recvBPS      uint64
    42  	connAccess   sync.Mutex
    43  	conn         quic.Connection
    44  	rawConn      net.Conn
    45  	udpAccess    sync.RWMutex
    46  	udpSessions  map[uint32]chan *hysteria.UDPMessage
    47  	udpDefragger hysteria.Defragger
    48  }
    49  
    50  func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaOutboundOptions) (*Hysteria, error) {
    51  	options.UDPFragmentDefault = true
    52  	if options.TLS == nil || !options.TLS.Enabled {
    53  		return nil, C.ErrTLSRequired
    54  	}
    55  	abstractTLSConfig, err := tls.NewClient(router, options.Server, common.PtrValueOrDefault(options.TLS))
    56  	if err != nil {
    57  		return nil, err
    58  	}
    59  	tlsConfig, err := abstractTLSConfig.Config()
    60  	if err != nil {
    61  		return nil, err
    62  	}
    63  	tlsConfig.MinVersion = tls.VersionTLS13
    64  	if len(tlsConfig.NextProtos) == 0 {
    65  		tlsConfig.NextProtos = []string{hysteria.DefaultALPN}
    66  	}
    67  	quicConfig := &quic.Config{
    68  		InitialStreamReceiveWindow:     options.ReceiveWindowConn,
    69  		MaxStreamReceiveWindow:         options.ReceiveWindowConn,
    70  		InitialConnectionReceiveWindow: options.ReceiveWindow,
    71  		MaxConnectionReceiveWindow:     options.ReceiveWindow,
    72  		KeepAlivePeriod:                hysteria.KeepAlivePeriod,
    73  		DisablePathMTUDiscovery:        options.DisableMTUDiscovery,
    74  		EnableDatagrams:                true,
    75  	}
    76  	if options.ReceiveWindowConn == 0 {
    77  		quicConfig.InitialStreamReceiveWindow = hysteria.DefaultStreamReceiveWindow
    78  		quicConfig.MaxStreamReceiveWindow = hysteria.DefaultStreamReceiveWindow
    79  	}
    80  	if options.ReceiveWindow == 0 {
    81  		quicConfig.InitialConnectionReceiveWindow = hysteria.DefaultConnectionReceiveWindow
    82  		quicConfig.MaxConnectionReceiveWindow = hysteria.DefaultConnectionReceiveWindow
    83  	}
    84  	if quicConfig.MaxIncomingStreams == 0 {
    85  		quicConfig.MaxIncomingStreams = hysteria.DefaultMaxIncomingStreams
    86  	}
    87  	var auth []byte
    88  	if len(options.Auth) > 0 {
    89  		auth = options.Auth
    90  	} else {
    91  		auth = []byte(options.AuthString)
    92  	}
    93  	var xplus []byte
    94  	if options.Obfs != "" {
    95  		xplus = []byte(options.Obfs)
    96  	}
    97  	var up, down uint64
    98  	if len(options.Up) > 0 {
    99  		up = hysteria.StringToBps(options.Up)
   100  		if up == 0 {
   101  			return nil, E.New("invalid up speed format: ", options.Up)
   102  		}
   103  	} else {
   104  		up = uint64(options.UpMbps) * hysteria.MbpsToBps
   105  	}
   106  	if len(options.Down) > 0 {
   107  		down = hysteria.StringToBps(options.Down)
   108  		if down == 0 {
   109  			return nil, E.New("invalid down speed format: ", options.Down)
   110  		}
   111  	} else {
   112  		down = uint64(options.DownMbps) * hysteria.MbpsToBps
   113  	}
   114  	if up < hysteria.MinSpeedBPS {
   115  		return nil, E.New("invalid up speed")
   116  	}
   117  	if down < hysteria.MinSpeedBPS {
   118  		return nil, E.New("invalid down speed")
   119  	}
   120  	return &Hysteria{
   121  		myOutboundAdapter: myOutboundAdapter{
   122  			protocol: C.TypeHysteria,
   123  			network:  options.Network.Build(),
   124  			router:   router,
   125  			logger:   logger,
   126  			tag:      tag,
   127  		},
   128  		ctx:        ctx,
   129  		dialer:     dialer.New(router, options.DialerOptions),
   130  		serverAddr: options.ServerOptions.Build(),
   131  		tlsConfig:  tlsConfig,
   132  		quicConfig: quicConfig,
   133  		authKey:    auth,
   134  		xplusKey:   xplus,
   135  		sendBPS:    up,
   136  		recvBPS:    down,
   137  	}, nil
   138  }
   139  
   140  func (h *Hysteria) offer(ctx context.Context) (quic.Connection, error) {
   141  	conn := h.conn
   142  	if conn != nil && !common.Done(conn.Context()) {
   143  		return conn, nil
   144  	}
   145  	h.connAccess.Lock()
   146  	defer h.connAccess.Unlock()
   147  	h.udpAccess.Lock()
   148  	defer h.udpAccess.Unlock()
   149  	conn = h.conn
   150  	if conn != nil && !common.Done(conn.Context()) {
   151  		return conn, nil
   152  	}
   153  	conn, err := h.offerNew(ctx)
   154  	if err != nil {
   155  		return nil, err
   156  	}
   157  	if common.Contains(h.network, N.NetworkUDP) {
   158  		for _, session := range h.udpSessions {
   159  			close(session)
   160  		}
   161  		h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage)
   162  		h.udpDefragger = hysteria.Defragger{}
   163  		go h.udpRecvLoop(conn)
   164  	}
   165  	return conn, nil
   166  }
   167  
   168  func (h *Hysteria) offerNew(ctx context.Context) (quic.Connection, error) {
   169  	udpConn, err := h.dialer.DialContext(h.ctx, "udp", h.serverAddr)
   170  	if err != nil {
   171  		return nil, err
   172  	}
   173  	var packetConn net.PacketConn
   174  	packetConn = bufio.NewUnbindPacketConn(udpConn)
   175  	if h.xplusKey != nil {
   176  		packetConn = hysteria.NewXPlusPacketConn(packetConn, h.xplusKey)
   177  	}
   178  	packetConn = &hysteria.PacketConnWrapper{PacketConn: packetConn}
   179  	quicConn, err := quic.Dial(packetConn, udpConn.RemoteAddr(), h.serverAddr.AddrString(), h.tlsConfig, h.quicConfig)
   180  	if err != nil {
   181  		packetConn.Close()
   182  		return nil, err
   183  	}
   184  	controlStream, err := quicConn.OpenStreamSync(ctx)
   185  	if err != nil {
   186  		packetConn.Close()
   187  		return nil, err
   188  	}
   189  	err = hysteria.WriteClientHello(controlStream, hysteria.ClientHello{
   190  		SendBPS: h.sendBPS,
   191  		RecvBPS: h.recvBPS,
   192  		Auth:    h.authKey,
   193  	})
   194  	if err != nil {
   195  		packetConn.Close()
   196  		return nil, err
   197  	}
   198  	serverHello, err := hysteria.ReadServerHello(controlStream)
   199  	if err != nil {
   200  		packetConn.Close()
   201  		return nil, err
   202  	}
   203  	if !serverHello.OK {
   204  		packetConn.Close()
   205  		return nil, E.New("remote error: ", serverHello.Message)
   206  	}
   207  	quicConn.SetCongestionControl(hysteria.NewBrutalSender(congestion.ByteCount(serverHello.RecvBPS)))
   208  	h.conn = quicConn
   209  	h.rawConn = udpConn
   210  	return quicConn, nil
   211  }
   212  
   213  func (h *Hysteria) udpRecvLoop(conn quic.Connection) {
   214  	for {
   215  		packet, err := conn.ReceiveMessage()
   216  		if err != nil {
   217  			return
   218  		}
   219  		message, err := hysteria.ParseUDPMessage(packet)
   220  		if err != nil {
   221  			h.logger.Error("parse udp message: ", err)
   222  			continue
   223  		}
   224  		dfMsg := h.udpDefragger.Feed(message)
   225  		if dfMsg == nil {
   226  			continue
   227  		}
   228  		h.udpAccess.RLock()
   229  		ch, ok := h.udpSessions[dfMsg.SessionID]
   230  		if ok {
   231  			select {
   232  			case ch <- dfMsg:
   233  				// OK
   234  			default:
   235  				// Silently drop the message when the channel is full
   236  			}
   237  		}
   238  		h.udpAccess.RUnlock()
   239  	}
   240  }
   241  
   242  func (h *Hysteria) InterfaceUpdated() error {
   243  	h.Close()
   244  	return nil
   245  }
   246  
   247  func (h *Hysteria) Close() error {
   248  	h.connAccess.Lock()
   249  	defer h.connAccess.Unlock()
   250  	h.udpAccess.Lock()
   251  	defer h.udpAccess.Unlock()
   252  	if h.conn != nil {
   253  		h.conn.CloseWithError(0, "")
   254  		h.rawConn.Close()
   255  	}
   256  	for _, session := range h.udpSessions {
   257  		close(session)
   258  	}
   259  	h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage)
   260  	return nil
   261  }
   262  
   263  func (h *Hysteria) open(ctx context.Context) (quic.Connection, quic.Stream, error) {
   264  	conn, err := h.offer(ctx)
   265  	if err != nil {
   266  		return nil, nil, err
   267  	}
   268  	stream, err := conn.OpenStream()
   269  	if err != nil {
   270  		return nil, nil, err
   271  	}
   272  	return conn, &hysteria.StreamWrapper{Stream: stream}, nil
   273  }
   274  
   275  func (h *Hysteria) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
   276  	switch N.NetworkName(network) {
   277  	case N.NetworkTCP:
   278  		h.logger.InfoContext(ctx, "outbound connection to ", destination)
   279  		_, stream, err := h.open(ctx)
   280  		if err != nil {
   281  			return nil, err
   282  		}
   283  		err = hysteria.WriteClientRequest(stream, hysteria.ClientRequest{
   284  			Host: destination.AddrString(),
   285  			Port: destination.Port,
   286  		})
   287  		if err != nil {
   288  			stream.Close()
   289  			return nil, err
   290  		}
   291  		return hysteria.NewConn(stream, destination, true), nil
   292  	case N.NetworkUDP:
   293  		conn, err := h.ListenPacket(ctx, destination)
   294  		if err != nil {
   295  			return nil, err
   296  		}
   297  		return conn.(*hysteria.PacketConn), nil
   298  	default:
   299  		return nil, E.New("unsupported network: ", network)
   300  	}
   301  }
   302  
   303  func (h *Hysteria) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
   304  	h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
   305  	conn, stream, err := h.open(ctx)
   306  	if err != nil {
   307  		return nil, err
   308  	}
   309  	err = hysteria.WriteClientRequest(stream, hysteria.ClientRequest{
   310  		UDP:  true,
   311  		Host: destination.AddrString(),
   312  		Port: destination.Port,
   313  	})
   314  	if err != nil {
   315  		stream.Close()
   316  		return nil, err
   317  	}
   318  	var response *hysteria.ServerResponse
   319  	response, err = hysteria.ReadServerResponse(stream)
   320  	if err != nil {
   321  		stream.Close()
   322  		return nil, err
   323  	}
   324  	if !response.OK {
   325  		stream.Close()
   326  		return nil, E.New("remote error: ", response.Message)
   327  	}
   328  	h.udpAccess.Lock()
   329  	nCh := make(chan *hysteria.UDPMessage, 1024)
   330  	h.udpSessions[response.UDPSessionID] = nCh
   331  	h.udpAccess.Unlock()
   332  	packetConn := hysteria.NewPacketConn(conn, stream, response.UDPSessionID, destination, nCh, common.Closer(func() error {
   333  		h.udpAccess.Lock()
   334  		if ch, ok := h.udpSessions[response.UDPSessionID]; ok {
   335  			close(ch)
   336  			delete(h.udpSessions, response.UDPSessionID)
   337  		}
   338  		h.udpAccess.Unlock()
   339  		return nil
   340  	}))
   341  	go packetConn.Hold()
   342  	return packetConn, nil
   343  }
   344  
   345  func (h *Hysteria) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
   346  	return NewConnection(ctx, h, conn, metadata)
   347  }
   348  
   349  func (h *Hysteria) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
   350  	return NewPacketConnection(ctx, h, conn, metadata)
   351  }