github.com/sagernet/sing-box@v1.2.7/outbound/hysteria.go (about) 1 //go:build with_quic 2 3 package outbound 4 5 import ( 6 "context" 7 "net" 8 "sync" 9 10 "github.com/sagernet/quic-go" 11 "github.com/sagernet/quic-go/congestion" 12 "github.com/sagernet/sing-box/adapter" 13 "github.com/sagernet/sing-box/common/dialer" 14 "github.com/sagernet/sing-box/common/tls" 15 C "github.com/sagernet/sing-box/constant" 16 "github.com/sagernet/sing-box/log" 17 "github.com/sagernet/sing-box/option" 18 "github.com/sagernet/sing-box/transport/hysteria" 19 "github.com/sagernet/sing/common" 20 "github.com/sagernet/sing/common/bufio" 21 E "github.com/sagernet/sing/common/exceptions" 22 M "github.com/sagernet/sing/common/metadata" 23 N "github.com/sagernet/sing/common/network" 24 ) 25 26 var ( 27 _ adapter.Outbound = (*Hysteria)(nil) 28 _ adapter.InterfaceUpdateListener = (*Hysteria)(nil) 29 ) 30 31 type Hysteria struct { 32 myOutboundAdapter 33 ctx context.Context 34 dialer N.Dialer 35 serverAddr M.Socksaddr 36 tlsConfig *tls.STDConfig 37 quicConfig *quic.Config 38 authKey []byte 39 xplusKey []byte 40 sendBPS uint64 41 recvBPS uint64 42 connAccess sync.Mutex 43 conn quic.Connection 44 rawConn net.Conn 45 udpAccess sync.RWMutex 46 udpSessions map[uint32]chan *hysteria.UDPMessage 47 udpDefragger hysteria.Defragger 48 } 49 50 func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaOutboundOptions) (*Hysteria, error) { 51 options.UDPFragmentDefault = true 52 if options.TLS == nil || !options.TLS.Enabled { 53 return nil, C.ErrTLSRequired 54 } 55 abstractTLSConfig, err := tls.NewClient(router, options.Server, common.PtrValueOrDefault(options.TLS)) 56 if err != nil { 57 return nil, err 58 } 59 tlsConfig, err := abstractTLSConfig.Config() 60 if err != nil { 61 return nil, err 62 } 63 tlsConfig.MinVersion = tls.VersionTLS13 64 if len(tlsConfig.NextProtos) == 0 { 65 tlsConfig.NextProtos = []string{hysteria.DefaultALPN} 66 } 67 quicConfig := &quic.Config{ 68 InitialStreamReceiveWindow: options.ReceiveWindowConn, 69 MaxStreamReceiveWindow: options.ReceiveWindowConn, 70 InitialConnectionReceiveWindow: options.ReceiveWindow, 71 MaxConnectionReceiveWindow: options.ReceiveWindow, 72 KeepAlivePeriod: hysteria.KeepAlivePeriod, 73 DisablePathMTUDiscovery: options.DisableMTUDiscovery, 74 EnableDatagrams: true, 75 } 76 if options.ReceiveWindowConn == 0 { 77 quicConfig.InitialStreamReceiveWindow = hysteria.DefaultStreamReceiveWindow 78 quicConfig.MaxStreamReceiveWindow = hysteria.DefaultStreamReceiveWindow 79 } 80 if options.ReceiveWindow == 0 { 81 quicConfig.InitialConnectionReceiveWindow = hysteria.DefaultConnectionReceiveWindow 82 quicConfig.MaxConnectionReceiveWindow = hysteria.DefaultConnectionReceiveWindow 83 } 84 if quicConfig.MaxIncomingStreams == 0 { 85 quicConfig.MaxIncomingStreams = hysteria.DefaultMaxIncomingStreams 86 } 87 var auth []byte 88 if len(options.Auth) > 0 { 89 auth = options.Auth 90 } else { 91 auth = []byte(options.AuthString) 92 } 93 var xplus []byte 94 if options.Obfs != "" { 95 xplus = []byte(options.Obfs) 96 } 97 var up, down uint64 98 if len(options.Up) > 0 { 99 up = hysteria.StringToBps(options.Up) 100 if up == 0 { 101 return nil, E.New("invalid up speed format: ", options.Up) 102 } 103 } else { 104 up = uint64(options.UpMbps) * hysteria.MbpsToBps 105 } 106 if len(options.Down) > 0 { 107 down = hysteria.StringToBps(options.Down) 108 if down == 0 { 109 return nil, E.New("invalid down speed format: ", options.Down) 110 } 111 } else { 112 down = uint64(options.DownMbps) * hysteria.MbpsToBps 113 } 114 if up < hysteria.MinSpeedBPS { 115 return nil, E.New("invalid up speed") 116 } 117 if down < hysteria.MinSpeedBPS { 118 return nil, E.New("invalid down speed") 119 } 120 return &Hysteria{ 121 myOutboundAdapter: myOutboundAdapter{ 122 protocol: C.TypeHysteria, 123 network: options.Network.Build(), 124 router: router, 125 logger: logger, 126 tag: tag, 127 }, 128 ctx: ctx, 129 dialer: dialer.New(router, options.DialerOptions), 130 serverAddr: options.ServerOptions.Build(), 131 tlsConfig: tlsConfig, 132 quicConfig: quicConfig, 133 authKey: auth, 134 xplusKey: xplus, 135 sendBPS: up, 136 recvBPS: down, 137 }, nil 138 } 139 140 func (h *Hysteria) offer(ctx context.Context) (quic.Connection, error) { 141 conn := h.conn 142 if conn != nil && !common.Done(conn.Context()) { 143 return conn, nil 144 } 145 h.connAccess.Lock() 146 defer h.connAccess.Unlock() 147 h.udpAccess.Lock() 148 defer h.udpAccess.Unlock() 149 conn = h.conn 150 if conn != nil && !common.Done(conn.Context()) { 151 return conn, nil 152 } 153 conn, err := h.offerNew(ctx) 154 if err != nil { 155 return nil, err 156 } 157 if common.Contains(h.network, N.NetworkUDP) { 158 for _, session := range h.udpSessions { 159 close(session) 160 } 161 h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage) 162 h.udpDefragger = hysteria.Defragger{} 163 go h.udpRecvLoop(conn) 164 } 165 return conn, nil 166 } 167 168 func (h *Hysteria) offerNew(ctx context.Context) (quic.Connection, error) { 169 udpConn, err := h.dialer.DialContext(h.ctx, "udp", h.serverAddr) 170 if err != nil { 171 return nil, err 172 } 173 var packetConn net.PacketConn 174 packetConn = bufio.NewUnbindPacketConn(udpConn) 175 if h.xplusKey != nil { 176 packetConn = hysteria.NewXPlusPacketConn(packetConn, h.xplusKey) 177 } 178 packetConn = &hysteria.PacketConnWrapper{PacketConn: packetConn} 179 quicConn, err := quic.Dial(packetConn, udpConn.RemoteAddr(), h.serverAddr.AddrString(), h.tlsConfig, h.quicConfig) 180 if err != nil { 181 packetConn.Close() 182 return nil, err 183 } 184 controlStream, err := quicConn.OpenStreamSync(ctx) 185 if err != nil { 186 packetConn.Close() 187 return nil, err 188 } 189 err = hysteria.WriteClientHello(controlStream, hysteria.ClientHello{ 190 SendBPS: h.sendBPS, 191 RecvBPS: h.recvBPS, 192 Auth: h.authKey, 193 }) 194 if err != nil { 195 packetConn.Close() 196 return nil, err 197 } 198 serverHello, err := hysteria.ReadServerHello(controlStream) 199 if err != nil { 200 packetConn.Close() 201 return nil, err 202 } 203 if !serverHello.OK { 204 packetConn.Close() 205 return nil, E.New("remote error: ", serverHello.Message) 206 } 207 quicConn.SetCongestionControl(hysteria.NewBrutalSender(congestion.ByteCount(serverHello.RecvBPS))) 208 h.conn = quicConn 209 h.rawConn = udpConn 210 return quicConn, nil 211 } 212 213 func (h *Hysteria) udpRecvLoop(conn quic.Connection) { 214 for { 215 packet, err := conn.ReceiveMessage() 216 if err != nil { 217 return 218 } 219 message, err := hysteria.ParseUDPMessage(packet) 220 if err != nil { 221 h.logger.Error("parse udp message: ", err) 222 continue 223 } 224 dfMsg := h.udpDefragger.Feed(message) 225 if dfMsg == nil { 226 continue 227 } 228 h.udpAccess.RLock() 229 ch, ok := h.udpSessions[dfMsg.SessionID] 230 if ok { 231 select { 232 case ch <- dfMsg: 233 // OK 234 default: 235 // Silently drop the message when the channel is full 236 } 237 } 238 h.udpAccess.RUnlock() 239 } 240 } 241 242 func (h *Hysteria) InterfaceUpdated() error { 243 h.Close() 244 return nil 245 } 246 247 func (h *Hysteria) Close() error { 248 h.connAccess.Lock() 249 defer h.connAccess.Unlock() 250 h.udpAccess.Lock() 251 defer h.udpAccess.Unlock() 252 if h.conn != nil { 253 h.conn.CloseWithError(0, "") 254 h.rawConn.Close() 255 } 256 for _, session := range h.udpSessions { 257 close(session) 258 } 259 h.udpSessions = make(map[uint32]chan *hysteria.UDPMessage) 260 return nil 261 } 262 263 func (h *Hysteria) open(ctx context.Context) (quic.Connection, quic.Stream, error) { 264 conn, err := h.offer(ctx) 265 if err != nil { 266 return nil, nil, err 267 } 268 stream, err := conn.OpenStream() 269 if err != nil { 270 return nil, nil, err 271 } 272 return conn, &hysteria.StreamWrapper{Stream: stream}, nil 273 } 274 275 func (h *Hysteria) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { 276 switch N.NetworkName(network) { 277 case N.NetworkTCP: 278 h.logger.InfoContext(ctx, "outbound connection to ", destination) 279 _, stream, err := h.open(ctx) 280 if err != nil { 281 return nil, err 282 } 283 err = hysteria.WriteClientRequest(stream, hysteria.ClientRequest{ 284 Host: destination.AddrString(), 285 Port: destination.Port, 286 }) 287 if err != nil { 288 stream.Close() 289 return nil, err 290 } 291 return hysteria.NewConn(stream, destination, true), nil 292 case N.NetworkUDP: 293 conn, err := h.ListenPacket(ctx, destination) 294 if err != nil { 295 return nil, err 296 } 297 return conn.(*hysteria.PacketConn), nil 298 default: 299 return nil, E.New("unsupported network: ", network) 300 } 301 } 302 303 func (h *Hysteria) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { 304 h.logger.InfoContext(ctx, "outbound packet connection to ", destination) 305 conn, stream, err := h.open(ctx) 306 if err != nil { 307 return nil, err 308 } 309 err = hysteria.WriteClientRequest(stream, hysteria.ClientRequest{ 310 UDP: true, 311 Host: destination.AddrString(), 312 Port: destination.Port, 313 }) 314 if err != nil { 315 stream.Close() 316 return nil, err 317 } 318 var response *hysteria.ServerResponse 319 response, err = hysteria.ReadServerResponse(stream) 320 if err != nil { 321 stream.Close() 322 return nil, err 323 } 324 if !response.OK { 325 stream.Close() 326 return nil, E.New("remote error: ", response.Message) 327 } 328 h.udpAccess.Lock() 329 nCh := make(chan *hysteria.UDPMessage, 1024) 330 h.udpSessions[response.UDPSessionID] = nCh 331 h.udpAccess.Unlock() 332 packetConn := hysteria.NewPacketConn(conn, stream, response.UDPSessionID, destination, nCh, common.Closer(func() error { 333 h.udpAccess.Lock() 334 if ch, ok := h.udpSessions[response.UDPSessionID]; ok { 335 close(ch) 336 delete(h.udpSessions, response.UDPSessionID) 337 } 338 h.udpAccess.Unlock() 339 return nil 340 })) 341 go packetConn.Hold() 342 return packetConn, nil 343 } 344 345 func (h *Hysteria) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { 346 return NewConnection(ctx, h, conn, metadata) 347 } 348 349 func (h *Hysteria) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { 350 return NewPacketConnection(ctx, h, conn, metadata) 351 }