github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/transport/hysteria2/server.go (about) 1 package hysteria2 2 3 import ( 4 "context" 5 "github.com/sagernet/quic-go" 6 "io" 7 "net" 8 "net/http" 9 "os" 10 "runtime" 11 "strings" 12 "sync" 13 14 "github.com/inazumav/sing-box/common/baderror" 15 "github.com/inazumav/sing-box/common/qtls" 16 "github.com/inazumav/sing-box/common/tls" 17 "github.com/inazumav/sing-box/transport/hysteria2/congestion" 18 "github.com/inazumav/sing-box/transport/hysteria2/internal/protocol" 19 tuicCongestion "github.com/inazumav/sing-box/transport/tuic/congestion" 20 "github.com/sagernet/quic-go/http3" 21 "github.com/sagernet/sing/common" 22 "github.com/sagernet/sing/common/auth" 23 E "github.com/sagernet/sing/common/exceptions" 24 "github.com/sagernet/sing/common/logger" 25 M "github.com/sagernet/sing/common/metadata" 26 N "github.com/sagernet/sing/common/network" 27 ) 28 29 type ServerOptions struct { 30 Context context.Context 31 Logger logger.Logger 32 SendBPS uint64 33 ReceiveBPS uint64 34 IgnoreClientBandwidth bool 35 SalamanderPassword string 36 TLSConfig tls.ServerConfig 37 Users []User 38 UDPDisabled bool 39 Handler ServerHandler 40 MasqueradeHandler http.Handler 41 } 42 43 type User struct { 44 Name string 45 Password string 46 } 47 48 type ServerHandler interface { 49 N.TCPConnectionHandler 50 N.UDPConnectionHandler 51 } 52 53 type Server struct { 54 ctx context.Context 55 logger logger.Logger 56 sendBPS uint64 57 receiveBPS uint64 58 ignoreClientBandwidth bool 59 salamanderPassword string 60 tlsConfig tls.ServerConfig 61 quicConfig *quic.Config 62 userMap map[string]User 63 udpDisabled bool 64 handler ServerHandler 65 masqueradeHandler http.Handler 66 quicListener io.Closer 67 } 68 69 func NewServer(options ServerOptions) (*Server, error) { 70 quicConfig := &quic.Config{ 71 DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"), 72 MaxDatagramFrameSize: 1400, 73 EnableDatagrams: !options.UDPDisabled, 74 MaxIncomingStreams: 1 << 60, 75 InitialStreamReceiveWindow: defaultStreamReceiveWindow, 76 MaxStreamReceiveWindow: defaultStreamReceiveWindow, 77 InitialConnectionReceiveWindow: defaultConnReceiveWindow, 78 MaxConnectionReceiveWindow: defaultConnReceiveWindow, 79 MaxIdleTimeout: defaultMaxIdleTimeout, 80 KeepAlivePeriod: defaultKeepAlivePeriod, 81 } 82 if len(options.Users) == 0 { 83 return nil, E.New("missing users") 84 } 85 userMap := make(map[string]User) 86 for _, user := range options.Users { 87 userMap[user.Password] = user 88 } 89 if options.MasqueradeHandler == nil { 90 options.MasqueradeHandler = http.NotFoundHandler() 91 } 92 return &Server{ 93 ctx: options.Context, 94 logger: options.Logger, 95 sendBPS: options.SendBPS, 96 receiveBPS: options.ReceiveBPS, 97 ignoreClientBandwidth: options.IgnoreClientBandwidth, 98 salamanderPassword: options.SalamanderPassword, 99 tlsConfig: options.TLSConfig, 100 quicConfig: quicConfig, 101 userMap: userMap, 102 udpDisabled: options.UDPDisabled, 103 handler: options.Handler, 104 masqueradeHandler: options.MasqueradeHandler, 105 }, nil 106 } 107 108 func (s *Server) Start(conn net.PacketConn) error { 109 if s.salamanderPassword != "" { 110 conn = NewSalamanderConn(conn, []byte(s.salamanderPassword)) 111 } 112 err := qtls.ConfigureHTTP3(s.tlsConfig) 113 if err != nil { 114 return err 115 } 116 listener, err := qtls.Listen(conn, s.tlsConfig, s.quicConfig) 117 if err != nil { 118 return err 119 } 120 s.quicListener = listener 121 go s.loopConnections(listener) 122 return nil 123 } 124 125 func (s *Server) Close() error { 126 return common.Close( 127 s.quicListener, 128 ) 129 } 130 131 func (s *Server) loopConnections(listener qtls.QUICListener) { 132 for { 133 connection, err := listener.Accept(s.ctx) 134 if err != nil { 135 if strings.Contains(err.Error(), "server closed") { 136 s.logger.Debug(E.Cause(err, "listener closed")) 137 } else { 138 s.logger.Error(E.Cause(err, "listener closed")) 139 } 140 return 141 } 142 go s.handleConnection(connection) 143 } 144 } 145 146 func (s *Server) handleConnection(connection quic.Connection) { 147 session := &serverSession{ 148 Server: s, 149 ctx: s.ctx, 150 quicConn: connection, 151 source: M.SocksaddrFromNet(connection.RemoteAddr()), 152 connDone: make(chan struct{}), 153 udpConnMap: make(map[uint32]*udpPacketConn), 154 } 155 httpServer := http3.Server{ 156 Handler: session, 157 StreamHijacker: session.handleStream0, 158 } 159 _ = httpServer.ServeQUICConn(connection) 160 _ = connection.CloseWithError(0, "") 161 } 162 163 type serverSession struct { 164 *Server 165 ctx context.Context 166 quicConn quic.Connection 167 source M.Socksaddr 168 connAccess sync.Mutex 169 connDone chan struct{} 170 connErr error 171 authenticated bool 172 authUser *User 173 udpAccess sync.RWMutex 174 udpConnMap map[uint32]*udpPacketConn 175 } 176 177 func (s *serverSession) ServeHTTP(w http.ResponseWriter, r *http.Request) { 178 if r.Method == http.MethodPost && r.Host == protocol.URLHost && r.URL.Path == protocol.URLPath { 179 if s.authenticated { 180 protocol.AuthResponseToHeader(w.Header(), protocol.AuthResponse{ 181 UDPEnabled: !s.udpDisabled, 182 Rx: s.receiveBPS, 183 RxAuto: s.ignoreClientBandwidth, 184 }) 185 w.WriteHeader(protocol.StatusAuthOK) 186 return 187 } 188 request := protocol.AuthRequestFromHeader(r.Header) 189 user, loaded := s.userMap[request.Auth] 190 if !loaded { 191 s.masqueradeHandler.ServeHTTP(w, r) 192 return 193 } 194 s.authUser = &user 195 s.authenticated = true 196 if !s.ignoreClientBandwidth && request.Rx > 0 { 197 var sendBps uint64 198 if s.sendBPS > 0 && s.sendBPS < request.Rx { 199 sendBps = s.sendBPS 200 } else { 201 sendBps = request.Rx 202 } 203 s.quicConn.SetCongestionControl(congestion.NewBrutalSender(sendBps)) 204 } else { 205 s.quicConn.SetCongestionControl(tuicCongestion.NewBBRSender( 206 tuicCongestion.DefaultClock{}, 207 tuicCongestion.GetInitialPacketSize(s.quicConn.RemoteAddr()), 208 tuicCongestion.InitialCongestionWindow*tuicCongestion.InitialMaxDatagramSize, 209 tuicCongestion.DefaultBBRMaxCongestionWindow*tuicCongestion.InitialMaxDatagramSize, 210 )) 211 } 212 protocol.AuthResponseToHeader(w.Header(), protocol.AuthResponse{ 213 UDPEnabled: !s.udpDisabled, 214 Rx: s.receiveBPS, 215 RxAuto: s.ignoreClientBandwidth, 216 }) 217 w.WriteHeader(protocol.StatusAuthOK) 218 if s.ctx.Done() != nil { 219 go func() { 220 select { 221 case <-s.ctx.Done(): 222 s.closeWithError(s.ctx.Err()) 223 case <-s.connDone: 224 } 225 }() 226 } 227 if !s.udpDisabled { 228 go s.loopMessages() 229 } 230 } else { 231 s.masqueradeHandler.ServeHTTP(w, r) 232 } 233 } 234 235 func (s *serverSession) handleStream0(frameType http3.FrameType, connection quic.Connection, stream quic.Stream, err error) (bool, error) { 236 if !s.authenticated || err != nil { 237 return false, nil 238 } 239 if frameType != protocol.FrameTypeTCPRequest { 240 return false, nil 241 } 242 go func() { 243 hErr := s.handleStream(stream) 244 if hErr != nil { 245 stream.CancelRead(0) 246 stream.Close() 247 s.logger.Error(E.Cause(hErr, "handle stream request")) 248 } 249 }() 250 return true, nil 251 } 252 253 func (s *serverSession) handleStream(stream quic.Stream) error { 254 destinationString, err := protocol.ReadTCPRequest(stream) 255 if err != nil { 256 return E.New("read TCP request") 257 } 258 var conn net.Conn = &serverConn{ 259 Stream: stream, 260 } 261 ctx := s.ctx 262 if s.authUser.Name != "" { 263 ctx = auth.ContextWithUser(s.ctx, s.authUser.Name) 264 } 265 _ = s.handler.NewConnection(ctx, conn, M.Metadata{ 266 Source: s.source, 267 Destination: M.ParseSocksaddr(destinationString), 268 }) 269 return nil 270 } 271 272 func (s *serverSession) closeWithError(err error) { 273 s.connAccess.Lock() 274 defer s.connAccess.Unlock() 275 select { 276 case <-s.connDone: 277 return 278 default: 279 s.connErr = err 280 close(s.connDone) 281 } 282 if E.IsClosedOrCanceled(err) { 283 s.logger.Debug(E.Cause(err, "connection failed")) 284 } else { 285 s.logger.Error(E.Cause(err, "connection failed")) 286 } 287 _ = s.quicConn.CloseWithError(0, "") 288 } 289 290 type serverConn struct { 291 quic.Stream 292 responseWritten bool 293 } 294 295 func (c *serverConn) HandshakeFailure(err error) error { 296 if c.responseWritten { 297 return os.ErrClosed 298 } 299 c.responseWritten = true 300 buffer := protocol.WriteTCPResponse(false, err.Error(), nil) 301 defer buffer.Release() 302 return common.Error(c.Stream.Write(buffer.Bytes())) 303 } 304 305 func (c *serverConn) Read(p []byte) (n int, err error) { 306 n, err = c.Stream.Read(p) 307 return n, baderror.WrapQUIC(err) 308 } 309 310 func (c *serverConn) Write(p []byte) (n int, err error) { 311 if !c.responseWritten { 312 c.responseWritten = true 313 buffer := protocol.WriteTCPResponse(true, "", p) 314 defer buffer.Release() 315 _, err = c.Stream.Write(buffer.Bytes()) 316 if err != nil { 317 return 0, baderror.WrapQUIC(err) 318 } 319 return len(p), nil 320 } 321 n, err = c.Stream.Write(p) 322 return n, baderror.WrapQUIC(err) 323 } 324 325 func (c *serverConn) LocalAddr() net.Addr { 326 return M.Socksaddr{} 327 } 328 329 func (c *serverConn) RemoteAddr() net.Addr { 330 return M.Socksaddr{} 331 } 332 333 func (c *serverConn) Close() error { 334 c.Stream.CancelRead(0) 335 return c.Stream.Close() 336 }