github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/outbound/hysteria.go (about)

     1  //go:build with_quic
     2  
     3  package outbound
     4  
     5  import (
     6  	"context"
     7  	"github.com/sagernet/quic-go"
     8  	"net"
     9  	"sync"
    10  
    11  	"github.com/inazumav/sing-box/adapter"
    12  	"github.com/inazumav/sing-box/common/dialer"
    13  	"github.com/inazumav/sing-box/common/qtls"
    14  	"github.com/inazumav/sing-box/common/tls"
    15  	C "github.com/inazumav/sing-box/constant"
    16  	"github.com/inazumav/sing-box/log"
    17  	"github.com/inazumav/sing-box/option"
    18  	"github.com/inazumav/sing-box/transport/hysteria"
    19  	"github.com/sagernet/quic-go/congestion"
    20  	"github.com/sagernet/sing/common"
    21  	"github.com/sagernet/sing/common/bufio"
    22  	E "github.com/sagernet/sing/common/exceptions"
    23  	M "github.com/sagernet/sing/common/metadata"
    24  	N "github.com/sagernet/sing/common/network"
    25  )
    26  
    27  var (
    28  	_ adapter.Outbound                = (*Hysteria)(nil)
    29  	_ adapter.InterfaceUpdateListener = (*Hysteria)(nil)
    30  )
    31  
    32  type Hysteria struct {
    33  	myOutboundAdapter
    34  	ctx          context.Context
    35  	dialer       N.Dialer
    36  	serverAddr   M.Socksaddr
    37  	tlsConfig    tls.Config
    38  	quicConfig   *quic.Config
    39  	authKey      []byte
    40  	xplusKey     []byte
    41  	sendBPS      uint64
    42  	recvBPS      uint64
    43  	connAccess   sync.Mutex
    44  	conn         quic.Connection
    45  	rawConn      net.Conn
    46  	udpAccess    sync.RWMutex
    47  	udpSessions  map[uint32]chan *hysteria.UDPMessage
    48  	udpDefragger hysteria.Defragger
    49  }
    50  
    51  func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaOutboundOptions) (*Hysteria, error) {
    52  	options.UDPFragmentDefault = true
    53  	if options.TLS == nil || !options.TLS.Enabled {
    54  		return nil, C.ErrTLSRequired
    55  	}
    56  	tlsConfig, err := tls.NewClient(ctx, options.Server, common.PtrValueOrDefault(options.TLS))
    57  	if err != nil {
    58  		return nil, err
    59  	}
    60  	if len(tlsConfig.NextProtos()) == 0 {
    61  		tlsConfig.SetNextProtos([]string{hysteria.DefaultALPN})
    62  	}
    63  	quicConfig := &quic.Config{
    64  		InitialStreamReceiveWindow:     options.ReceiveWindowConn,
    65  		MaxStreamReceiveWindow:         options.ReceiveWindowConn,
    66  		InitialConnectionReceiveWindow: options.ReceiveWindow,
    67  		MaxConnectionReceiveWindow:     options.ReceiveWindow,
    68  		KeepAlivePeriod:                hysteria.KeepAlivePeriod,
    69  		DisablePathMTUDiscovery:        options.DisableMTUDiscovery,
    70  		EnableDatagrams:                true,
    71  	}
    72  	if options.ReceiveWindowConn == 0 {
    73  		quicConfig.InitialStreamReceiveWindow = hysteria.DefaultStreamReceiveWindow
    74  		quicConfig.MaxStreamReceiveWindow = hysteria.DefaultStreamReceiveWindow
    75  	}
    76  	if options.ReceiveWindow == 0 {
    77  		quicConfig.InitialConnectionReceiveWindow = hysteria.DefaultConnectionReceiveWindow
    78  		quicConfig.MaxConnectionReceiveWindow = hysteria.DefaultConnectionReceiveWindow
    79  	}
    80  	if quicConfig.MaxIncomingStreams == 0 {
    81  		quicConfig.MaxIncomingStreams = hysteria.DefaultMaxIncomingStreams
    82  	}
    83  	var auth []byte
    84  	if len(options.Auth) > 0 {
    85  		auth = options.Auth
    86  	} else {
    87  		auth = []byte(options.AuthString)
    88  	}
    89  	var xplus []byte
    90  	if options.Obfs != "" {
    91  		xplus = []byte(options.Obfs)
    92  	}
    93  	var up, down uint64
    94  	if len(options.Up) > 0 {
    95  		up = hysteria.StringToBps(options.Up)
    96  		if up == 0 {
    97  			return nil, E.New("invalid up speed format: ", options.Up)
    98  		}
    99  	} else {
   100  		up = uint64(options.UpMbps) * hysteria.MbpsToBps
   101  	}
   102  	if len(options.Down) > 0 {
   103  		down = hysteria.StringToBps(options.Down)
   104  		if down == 0 {
   105  			return nil, E.New("invalid down speed format: ", options.Down)
   106  		}
   107  	} else {
   108  		down = uint64(options.DownMbps) * hysteria.MbpsToBps
   109  	}
   110  	if up < hysteria.MinSpeedBPS {
   111  		return nil, E.New("invalid up speed")
   112  	}
   113  	if down < hysteria.MinSpeedBPS {
   114  		return nil, E.New("invalid down speed")
   115  	}
   116  	outboundDialer, err := dialer.New(router, options.DialerOptions)
   117  	if err != nil {
   118  		return nil, err
   119  	}
   120  	return &Hysteria{
   121  		myOutboundAdapter: myOutboundAdapter{
   122  			protocol:     C.TypeHysteria,
   123  			network:      options.Network.Build(),
   124  			router:       router,
   125  			logger:       logger,
   126  			tag:          tag,
   127  			dependencies: withDialerDependency(options.DialerOptions),
   128  		},
   129  		ctx:        ctx,
   130  		dialer:     outboundDialer,
   131  		serverAddr: options.ServerOptions.Build(),
   132  		tlsConfig:  tlsConfig,
   133  		quicConfig: quicConfig,
   134  		authKey:    auth,
   135  		xplusKey:   xplus,
   136  		sendBPS:    up,
   137  		recvBPS:    down,
   138  	}, nil
   139  }
   140  
   141  func (h *Hysteria) offer(ctx context.Context) (quic.Connection, error) {
   142  	conn := h.conn
   143  	if conn != nil && !common.Done(conn.Context()) {
   144  		return conn, nil
   145  	}
   146  	h.connAccess.Lock()
   147  	defer h.connAccess.Unlock()
   148  	h.udpAccess.Lock()
   149  	defer h.udpAccess.Unlock()
   150  	conn = h.conn
   151  	if conn != nil && !common.Done(conn.Context()) {
   152  		return conn, nil
   153  	}
   154  	common.Close(h.rawConn)
   155  	conn, err := h.offerNew(ctx)
   156  	if err != nil {
   157  		return nil, err
   158  	}
   159  	if common.Contains(h.network, N.NetworkUDP) {
   160  		for _, session := range h.udpSessions {
   161  			close(session)
   162  		}
   163  		h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage)
   164  		h.udpDefragger = hysteria.Defragger{}
   165  		go h.udpRecvLoop(conn)
   166  	}
   167  	return conn, nil
   168  }
   169  
   170  func (h *Hysteria) offerNew(ctx context.Context) (quic.Connection, error) {
   171  	udpConn, err := h.dialer.DialContext(h.ctx, "udp", h.serverAddr)
   172  	if err != nil {
   173  		return nil, err
   174  	}
   175  	var packetConn net.PacketConn
   176  	packetConn = bufio.NewUnbindPacketConn(udpConn)
   177  	if h.xplusKey != nil {
   178  		packetConn = hysteria.NewXPlusPacketConn(packetConn, h.xplusKey)
   179  	}
   180  	packetConn = &hysteria.PacketConnWrapper{PacketConn: packetConn}
   181  	quicConn, err := qtls.Dial(h.ctx, packetConn, udpConn.RemoteAddr(), h.tlsConfig, h.quicConfig)
   182  	if err != nil {
   183  		packetConn.Close()
   184  		return nil, err
   185  	}
   186  	controlStream, err := quicConn.OpenStreamSync(ctx)
   187  	if err != nil {
   188  		packetConn.Close()
   189  		return nil, err
   190  	}
   191  	err = hysteria.WriteClientHello(controlStream, hysteria.ClientHello{
   192  		SendBPS: h.sendBPS,
   193  		RecvBPS: h.recvBPS,
   194  		Auth:    h.authKey,
   195  	})
   196  	if err != nil {
   197  		packetConn.Close()
   198  		return nil, err
   199  	}
   200  	serverHello, err := hysteria.ReadServerHello(controlStream)
   201  	if err != nil {
   202  		packetConn.Close()
   203  		return nil, err
   204  	}
   205  	if !serverHello.OK {
   206  		packetConn.Close()
   207  		return nil, E.New("remote error: ", serverHello.Message)
   208  	}
   209  	quicConn.SetCongestionControl(hysteria.NewBrutalSender(congestion.ByteCount(serverHello.RecvBPS)))
   210  	h.conn = quicConn
   211  	h.rawConn = udpConn
   212  	return quicConn, nil
   213  }
   214  
   215  func (h *Hysteria) udpRecvLoop(conn quic.Connection) {
   216  	for {
   217  		packet, err := conn.ReceiveMessage(h.ctx)
   218  		if err != nil {
   219  			return
   220  		}
   221  		message, err := hysteria.ParseUDPMessage(packet)
   222  		if err != nil {
   223  			h.logger.Error("parse udp message: ", err)
   224  			continue
   225  		}
   226  		dfMsg := h.udpDefragger.Feed(message)
   227  		if dfMsg == nil {
   228  			continue
   229  		}
   230  		h.udpAccess.RLock()
   231  		ch, ok := h.udpSessions[dfMsg.SessionID]
   232  		if ok {
   233  			select {
   234  			case ch <- dfMsg:
   235  				// OK
   236  			default:
   237  				// Silently drop the message when the channel is full
   238  			}
   239  		}
   240  		h.udpAccess.RUnlock()
   241  	}
   242  }
   243  
   244  func (h *Hysteria) InterfaceUpdated() {
   245  	h.Close()
   246  	return
   247  }
   248  
   249  func (h *Hysteria) Close() error {
   250  	h.connAccess.Lock()
   251  	defer h.connAccess.Unlock()
   252  	h.udpAccess.Lock()
   253  	defer h.udpAccess.Unlock()
   254  	if h.conn != nil {
   255  		h.conn.CloseWithError(0, "")
   256  		h.rawConn.Close()
   257  	}
   258  	for _, session := range h.udpSessions {
   259  		close(session)
   260  	}
   261  	h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage)
   262  	return nil
   263  }
   264  
   265  func (h *Hysteria) open(ctx context.Context, reconnect bool) (quic.Connection, quic.Stream, error) {
   266  	conn, err := h.offer(ctx)
   267  	if err != nil {
   268  		if nErr, ok := err.(net.Error); ok && !nErr.Temporary() && reconnect {
   269  			return h.open(ctx, false)
   270  		}
   271  		return nil, nil, err
   272  	}
   273  	stream, err := conn.OpenStream()
   274  	if err != nil {
   275  		if nErr, ok := err.(net.Error); ok && !nErr.Temporary() && reconnect {
   276  			return h.open(ctx, false)
   277  		}
   278  		return nil, nil, err
   279  	}
   280  	return conn, &hysteria.StreamWrapper{Stream: stream}, nil
   281  }
   282  
   283  func (h *Hysteria) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
   284  	switch N.NetworkName(network) {
   285  	case N.NetworkTCP:
   286  		h.logger.InfoContext(ctx, "outbound connection to ", destination)
   287  		_, stream, err := h.open(ctx, true)
   288  		if err != nil {
   289  			return nil, err
   290  		}
   291  		err = hysteria.WriteClientRequest(stream, hysteria.ClientRequest{
   292  			Host: destination.AddrString(),
   293  			Port: destination.Port,
   294  		})
   295  		if err != nil {
   296  			stream.Close()
   297  			return nil, err
   298  		}
   299  		return hysteria.NewConn(stream, destination, true), nil
   300  	case N.NetworkUDP:
   301  		conn, err := h.ListenPacket(ctx, destination)
   302  		if err != nil {
   303  			return nil, err
   304  		}
   305  		return conn.(*hysteria.PacketConn), nil
   306  	default:
   307  		return nil, E.New("unsupported network: ", network)
   308  	}
   309  }
   310  
   311  func (h *Hysteria) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
   312  	h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
   313  	conn, stream, err := h.open(ctx, true)
   314  	if err != nil {
   315  		return nil, err
   316  	}
   317  	err = hysteria.WriteClientRequest(stream, hysteria.ClientRequest{
   318  		UDP:  true,
   319  		Host: destination.AddrString(),
   320  		Port: destination.Port,
   321  	})
   322  	if err != nil {
   323  		stream.Close()
   324  		return nil, err
   325  	}
   326  	var response *hysteria.ServerResponse
   327  	response, err = hysteria.ReadServerResponse(stream)
   328  	if err != nil {
   329  		stream.Close()
   330  		return nil, err
   331  	}
   332  	if !response.OK {
   333  		stream.Close()
   334  		return nil, E.New("remote error: ", response.Message)
   335  	}
   336  	h.udpAccess.Lock()
   337  	nCh := make(chan *hysteria.UDPMessage, 1024)
   338  	h.udpSessions[response.UDPSessionID] = nCh
   339  	h.udpAccess.Unlock()
   340  	packetConn := hysteria.NewPacketConn(conn, stream, response.UDPSessionID, destination, nCh, common.Closer(func() error {
   341  		h.udpAccess.Lock()
   342  		if ch, ok := h.udpSessions[response.UDPSessionID]; ok {
   343  			close(ch)
   344  			delete(h.udpSessions, response.UDPSessionID)
   345  		}
   346  		h.udpAccess.Unlock()
   347  		return nil
   348  	}))
   349  	go packetConn.Hold()
   350  	return packetConn, nil
   351  }
   352  
// NewConnection handles an inbound stream connection by delegating to the
// package-level NewConnection helper, passing h as its second argument.
func (h *Hysteria) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
	return NewConnection(ctx, h, conn, metadata)
}
   356  
// NewPacketConnection handles an inbound packet connection by delegating to
// the package-level NewPacketConnection helper, passing h as its second
// argument.
func (h *Hysteria) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
	return NewPacketConnection(ctx, h, conn, metadata)
}