github.com/sagernet/sing-box@v1.2.7/inbound/naive.go (about) 1 package inbound 2 3 import ( 4 "context" 5 "encoding/base64" 6 "encoding/binary" 7 "io" 8 "math/rand" 9 "net" 10 "net/http" 11 "os" 12 "strings" 13 "time" 14 15 "github.com/sagernet/sing-box/adapter" 16 "github.com/sagernet/sing-box/common/tls" 17 C "github.com/sagernet/sing-box/constant" 18 "github.com/sagernet/sing-box/include" 19 "github.com/sagernet/sing-box/log" 20 "github.com/sagernet/sing-box/option" 21 "github.com/sagernet/sing/common" 22 "github.com/sagernet/sing/common/auth" 23 "github.com/sagernet/sing/common/buf" 24 E "github.com/sagernet/sing/common/exceptions" 25 M "github.com/sagernet/sing/common/metadata" 26 N "github.com/sagernet/sing/common/network" 27 "github.com/sagernet/sing/common/rw" 28 sHttp "github.com/sagernet/sing/protocol/http" 29 ) 30 31 var _ adapter.Inbound = (*Naive)(nil) 32 33 type Naive struct { 34 myInboundAdapter 35 authenticator auth.Authenticator 36 tlsConfig tls.ServerConfig 37 httpServer *http.Server 38 h3Server any 39 } 40 41 func NewNaive(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.NaiveInboundOptions) (*Naive, error) { 42 inbound := &Naive{ 43 myInboundAdapter: myInboundAdapter{ 44 protocol: C.TypeNaive, 45 network: options.Network.Build(), 46 ctx: ctx, 47 router: router, 48 logger: logger, 49 tag: tag, 50 listenOptions: options.ListenOptions, 51 }, 52 authenticator: auth.NewAuthenticator(options.Users), 53 } 54 if common.Contains(inbound.network, N.NetworkUDP) { 55 if options.TLS == nil || !options.TLS.Enabled { 56 return nil, E.New("TLS is required for QUIC server") 57 } 58 } 59 if len(options.Users) == 0 { 60 return nil, E.New("missing users") 61 } 62 if options.TLS != nil { 63 tlsConfig, err := tls.NewServer(ctx, router, logger, common.PtrValueOrDefault(options.TLS)) 64 if err != nil { 65 return nil, err 66 } 67 inbound.tlsConfig = tlsConfig 68 } 69 return inbound, nil 70 } 71 72 func (n *Naive) Start() error { 73 var tlsConfig *tls.STDConfig 74 if n.tlsConfig != nil { 75 err := n.tlsConfig.Start() 76 if err != nil { 77 return E.Cause(err, "create TLS config") 78 } 79 tlsConfig, err = n.tlsConfig.Config() 80 if err != nil { 81 return err 82 } 83 } 84 85 if common.Contains(n.network, N.NetworkTCP) { 86 tcpListener, err := n.ListenTCP() 87 if err != nil { 88 return err 89 } 90 n.httpServer = &http.Server{ 91 Handler: n, 92 TLSConfig: tlsConfig, 93 BaseContext: func(listener net.Listener) context.Context { 94 return n.ctx 95 }, 96 } 97 go func() { 98 var sErr error 99 if tlsConfig != nil { 100 sErr = n.httpServer.ServeTLS(tcpListener, "", "") 101 } else { 102 sErr = n.httpServer.Serve(tcpListener) 103 } 104 if sErr != nil && !E.IsClosedOrCanceled(sErr) { 105 n.logger.Error("http server serve error: ", sErr) 106 } 107 }() 108 } 109 110 if common.Contains(n.network, N.NetworkUDP) { 111 err := n.configureHTTP3Listener() 112 if !include.WithQUIC && len(n.network) > 1 { 113 log.Warn(E.Cause(err, "naive http3 disabled")) 114 } else if err != nil { 115 return err 116 } 117 } 118 119 return nil 120 } 121 122 func (n *Naive) Close() error { 123 return common.Close( 124 &n.myInboundAdapter, 125 common.PtrOrNil(n.httpServer), 126 n.h3Server, 127 n.tlsConfig, 128 ) 129 } 130 131 func (n *Naive) ServeHTTP(writer http.ResponseWriter, request *http.Request) { 132 ctx := log.ContextWithNewID(request.Context()) 133 if request.Method != "CONNECT" { 134 rejectHTTP(writer, http.StatusBadRequest) 135 n.badRequest(ctx, request, E.New("not CONNECT request")) 136 return 137 } else if request.Header.Get("Padding") == "" { 138 rejectHTTP(writer, http.StatusBadRequest) 139 n.badRequest(ctx, request, E.New("missing naive padding")) 140 return 141 } 142 var authOk bool 143 var userName string 144 authorization := request.Header.Get("Proxy-Authorization") 145 if strings.HasPrefix(authorization, "BASIC ") || strings.HasPrefix(authorization, "Basic ") { 146 userPassword, _ := base64.URLEncoding.DecodeString(authorization[6:]) 147 userPswdArr := strings.SplitN(string(userPassword), ":", 2) 148 userName = userPswdArr[0] 149 authOk = n.authenticator.Verify(userPswdArr[0], userPswdArr[1]) 150 } 151 if !authOk { 152 rejectHTTP(writer, http.StatusProxyAuthRequired) 153 n.badRequest(ctx, request, E.New("authorization failed")) 154 return 155 } 156 writer.Header().Set("Padding", generateNaivePaddingHeader()) 157 writer.WriteHeader(http.StatusOK) 158 writer.(http.Flusher).Flush() 159 160 hostPort := request.URL.Host 161 if hostPort == "" { 162 hostPort = request.Host 163 } 164 source := sHttp.SourceAddress(request) 165 destination := M.ParseSocksaddr(hostPort) 166 167 if hijacker, isHijacker := writer.(http.Hijacker); isHijacker { 168 conn, _, err := hijacker.Hijack() 169 if err != nil { 170 n.badRequest(ctx, request, E.New("hijack failed")) 171 return 172 } 173 n.newConnection(ctx, &naiveH1Conn{Conn: conn}, userName, source, destination) 174 } else { 175 n.newConnection(ctx, &naiveH2Conn{reader: request.Body, writer: writer, flusher: writer.(http.Flusher)}, userName, source, destination) 176 } 177 } 178 179 func (n *Naive) newConnection(ctx context.Context, conn net.Conn, userName string, source, destination M.Socksaddr) { 180 if userName != "" { 181 n.logger.InfoContext(ctx, "[", userName, "] inbound connection from ", source) 182 n.logger.InfoContext(ctx, "[", userName, "] inbound connection to ", destination) 183 } else { 184 n.logger.InfoContext(ctx, "inbound connection from ", source) 185 n.logger.InfoContext(ctx, "inbound connection to ", destination) 186 } 187 hErr := n.router.RouteConnection(ctx, conn, n.createMetadata(conn, adapter.InboundContext{ 188 Source: source, 189 Destination: destination, 190 User: userName, 191 })) 192 if hErr != nil { 193 conn.Close() 194 n.NewError(ctx, E.Cause(hErr, "process connection from ", source)) 195 } 196 } 197 198 func (n *Naive) badRequest(ctx context.Context, request *http.Request, err error) { 199 n.NewError(ctx, E.Cause(err, "process connection from ", request.RemoteAddr)) 200 } 201 202 func rejectHTTP(writer http.ResponseWriter, statusCode int) { 203 hijacker, ok := writer.(http.Hijacker) 204 if !ok { 205 writer.WriteHeader(statusCode) 206 return 207 } 208 conn, _, err := hijacker.Hijack() 209 if err != nil { 210 writer.WriteHeader(statusCode) 211 return 212 } 213 if tcpConn, isTCP := common.Cast[*net.TCPConn](conn); isTCP { 214 tcpConn.SetLinger(0) 215 } 216 conn.Close() 217 } 218 219 func generateNaivePaddingHeader() string { 220 paddingLen := rand.Intn(32) + 30 221 padding := make([]byte, paddingLen) 222 bits := rand.Uint64() 223 for i := 0; i < 16; i++ { 224 // Codes that won't be Huffman coded. 225 padding[i] = "!#$()+<>?@[]^`{}"[bits&15] 226 bits >>= 4 227 } 228 for i := 16; i < paddingLen; i++ { 229 padding[i] = '~' 230 } 231 return string(padding) 232 } 233 234 const kFirstPaddings = 8 235 236 type naiveH1Conn struct { 237 net.Conn 238 readPadding int 239 writePadding int 240 readRemaining int 241 paddingRemaining int 242 } 243 244 func (c *naiveH1Conn) Read(p []byte) (n int, err error) { 245 n, err = c.read(p) 246 return n, wrapHttpError(err) 247 } 248 249 func (c *naiveH1Conn) read(p []byte) (n int, err error) { 250 if c.readRemaining > 0 { 251 if len(p) > c.readRemaining { 252 p = p[:c.readRemaining] 253 } 254 n, err = c.Conn.Read(p) 255 if err != nil { 256 return 257 } 258 c.readRemaining -= n 259 return 260 } 261 if c.paddingRemaining > 0 { 262 err = rw.SkipN(c.Conn, c.paddingRemaining) 263 if err != nil { 264 return 265 } 266 c.paddingRemaining = 0 267 } 268 if c.readPadding < kFirstPaddings { 269 var paddingHdr []byte 270 if len(p) >= 3 { 271 paddingHdr = p[:3] 272 } else { 273 _paddingHdr := make([]byte, 3) 274 defer common.KeepAlive(_paddingHdr) 275 paddingHdr = common.Dup(_paddingHdr) 276 } 277 _, err = io.ReadFull(c.Conn, paddingHdr) 278 if err != nil { 279 return 280 } 281 originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2])) 282 paddingSize := int(paddingHdr[2]) 283 if len(p) > originalDataSize { 284 p = p[:originalDataSize] 285 } 286 n, err = c.Conn.Read(p) 287 if err != nil { 288 return 289 } 290 c.readPadding++ 291 c.readRemaining = originalDataSize - n 292 c.paddingRemaining = paddingSize 293 return 294 } 295 return c.Conn.Read(p) 296 } 297 298 func (c *naiveH1Conn) Write(p []byte) (n int, err error) { 299 for pLen := len(p); pLen > 0; { 300 var data []byte 301 if pLen > 65535 { 302 data = p[:65535] 303 p = p[65535:] 304 pLen -= 65535 305 } else { 306 data = p 307 pLen = 0 308 } 309 var writeN int 310 writeN, err = c.write(data) 311 n += writeN 312 if err != nil { 313 break 314 } 315 } 316 return n, wrapHttpError(err) 317 } 318 319 func (c *naiveH1Conn) write(p []byte) (n int, err error) { 320 if c.writePadding < kFirstPaddings { 321 paddingSize := rand.Intn(256) 322 323 _buffer := buf.StackNewSize(3 + len(p) + paddingSize) 324 defer common.KeepAlive(_buffer) 325 buffer := common.Dup(_buffer) 326 defer buffer.Release() 327 header := buffer.Extend(3) 328 binary.BigEndian.PutUint16(header, uint16(len(p))) 329 header[2] = byte(paddingSize) 330 331 common.Must1(buffer.Write(p)) 332 _, err = c.Conn.Write(buffer.Bytes()) 333 if err == nil { 334 n = len(p) 335 } 336 c.writePadding++ 337 return 338 } 339 return c.Conn.Write(p) 340 } 341 342 func (c *naiveH1Conn) FrontHeadroom() int { 343 if c.writePadding < kFirstPaddings { 344 return 3 345 } 346 return 0 347 } 348 349 func (c *naiveH1Conn) RearHeadroom() int { 350 if c.writePadding < kFirstPaddings { 351 return 255 352 } 353 return 0 354 } 355 356 func (c *naiveH1Conn) WriterMTU() int { 357 if c.writePadding < kFirstPaddings { 358 return 65535 359 } 360 return 0 361 } 362 363 func (c *naiveH1Conn) WriteBuffer(buffer *buf.Buffer) error { 364 defer buffer.Release() 365 if c.writePadding < kFirstPaddings { 366 bufferLen := buffer.Len() 367 if bufferLen > 65535 { 368 return common.Error(c.Write(buffer.Bytes())) 369 } 370 paddingSize := rand.Intn(256) 371 header := buffer.ExtendHeader(3) 372 binary.BigEndian.PutUint16(header, uint16(bufferLen)) 373 header[2] = byte(paddingSize) 374 buffer.Extend(paddingSize) 375 c.writePadding++ 376 } 377 return wrapHttpError(common.Error(c.Conn.Write(buffer.Bytes()))) 378 } 379 380 // FIXME 381 /*func (c *naiveH1Conn) WriteTo(w io.Writer) (n int64, err error) { 382 if c.readPadding < kFirstPaddings { 383 n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding) 384 } else { 385 n, err = bufio.Copy(w, c.Conn) 386 } 387 return n, wrapHttpError(err) 388 } 389 390 func (c *naiveH1Conn) ReadFrom(r io.Reader) (n int64, err error) { 391 if c.writePadding < kFirstPaddings { 392 n, err = bufio.ReadFromN(c, r, kFirstPaddings-c.writePadding) 393 } else { 394 n, err = bufio.Copy(c.Conn, r) 395 } 396 return n, wrapHttpError(err) 397 } 398 */ 399 400 func (c *naiveH1Conn) Upstream() any { 401 return c.Conn 402 } 403 404 func (c *naiveH1Conn) ReaderReplaceable() bool { 405 return c.readPadding == kFirstPaddings 406 } 407 408 func (c *naiveH1Conn) WriterReplaceable() bool { 409 return c.writePadding == kFirstPaddings 410 } 411 412 type naiveH2Conn struct { 413 reader io.Reader 414 writer io.Writer 415 flusher http.Flusher 416 rAddr net.Addr 417 readPadding int 418 writePadding int 419 readRemaining int 420 paddingRemaining int 421 } 422 423 func (c *naiveH2Conn) Read(p []byte) (n int, err error) { 424 n, err = c.read(p) 425 return n, wrapHttpError(err) 426 } 427 428 func (c *naiveH2Conn) read(p []byte) (n int, err error) { 429 if c.readRemaining > 0 { 430 if len(p) > c.readRemaining { 431 p = p[:c.readRemaining] 432 } 433 n, err = c.reader.Read(p) 434 if err != nil { 435 return 436 } 437 c.readRemaining -= n 438 return 439 } 440 if c.paddingRemaining > 0 { 441 err = rw.SkipN(c.reader, c.paddingRemaining) 442 if err != nil { 443 return 444 } 445 c.paddingRemaining = 0 446 } 447 if c.readPadding < kFirstPaddings { 448 var paddingHdr []byte 449 if len(p) >= 3 { 450 paddingHdr = p[:3] 451 } else { 452 _paddingHdr := make([]byte, 3) 453 defer common.KeepAlive(_paddingHdr) 454 paddingHdr = common.Dup(_paddingHdr) 455 } 456 _, err = io.ReadFull(c.reader, paddingHdr) 457 if err != nil { 458 return 459 } 460 originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2])) 461 paddingSize := int(paddingHdr[2]) 462 if len(p) > originalDataSize { 463 p = p[:originalDataSize] 464 } 465 n, err = c.reader.Read(p) 466 if err != nil { 467 return 468 } 469 c.readPadding++ 470 c.readRemaining = originalDataSize - n 471 c.paddingRemaining = paddingSize 472 return 473 } 474 return c.reader.Read(p) 475 } 476 477 func (c *naiveH2Conn) Write(p []byte) (n int, err error) { 478 for pLen := len(p); pLen > 0; { 479 var data []byte 480 if pLen > 65535 { 481 data = p[:65535] 482 p = p[65535:] 483 pLen -= 65535 484 } else { 485 data = p 486 pLen = 0 487 } 488 var writeN int 489 writeN, err = c.write(data) 490 n += writeN 491 if err != nil { 492 break 493 } 494 } 495 if err == nil { 496 c.flusher.Flush() 497 } 498 return n, wrapHttpError(err) 499 } 500 501 func (c *naiveH2Conn) write(p []byte) (n int, err error) { 502 if c.writePadding < kFirstPaddings { 503 paddingSize := rand.Intn(256) 504 505 _buffer := buf.StackNewSize(3 + len(p) + paddingSize) 506 defer common.KeepAlive(_buffer) 507 buffer := common.Dup(_buffer) 508 defer buffer.Release() 509 header := buffer.Extend(3) 510 binary.BigEndian.PutUint16(header, uint16(len(p))) 511 header[2] = byte(paddingSize) 512 513 common.Must1(buffer.Write(p)) 514 _, err = c.writer.Write(buffer.Bytes()) 515 if err == nil { 516 n = len(p) 517 } 518 c.writePadding++ 519 return 520 } 521 return c.writer.Write(p) 522 } 523 524 func (c *naiveH2Conn) FrontHeadroom() int { 525 if c.writePadding < kFirstPaddings { 526 return 3 527 } 528 return 0 529 } 530 531 func (c *naiveH2Conn) RearHeadroom() int { 532 if c.writePadding < kFirstPaddings { 533 return 255 534 } 535 return 0 536 } 537 538 func (c *naiveH2Conn) WriterMTU() int { 539 if c.writePadding < kFirstPaddings { 540 return 65535 541 } 542 return 0 543 } 544 545 func (c *naiveH2Conn) WriteBuffer(buffer *buf.Buffer) error { 546 defer buffer.Release() 547 if c.writePadding < kFirstPaddings { 548 bufferLen := buffer.Len() 549 if bufferLen > 65535 { 550 return common.Error(c.Write(buffer.Bytes())) 551 } 552 paddingSize := rand.Intn(256) 553 header := buffer.ExtendHeader(3) 554 binary.BigEndian.PutUint16(header, uint16(bufferLen)) 555 header[2] = byte(paddingSize) 556 buffer.Extend(paddingSize) 557 c.writePadding++ 558 } 559 err := common.Error(c.writer.Write(buffer.Bytes())) 560 if err == nil { 561 c.flusher.Flush() 562 } 563 return wrapHttpError(err) 564 } 565 566 // FIXME 567 /*func (c *naiveH2Conn) WriteTo(w io.Writer) (n int64, err error) { 568 if c.readPadding < kFirstPaddings { 569 n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding) 570 } else { 571 n, err = bufio.Copy(w, c.reader) 572 } 573 return n, wrapHttpError(err) 574 } 575 576 func (c *naiveH2Conn) ReadFrom(r io.Reader) (n int64, err error) { 577 if c.writePadding < kFirstPaddings { 578 n, err = bufio.ReadFromN(c, r, kFirstPaddings-c.writePadding) 579 } else { 580 n, err = bufio.Copy(c.writer, r) 581 } 582 return n, wrapHttpError(err) 583 }*/ 584 585 func (c *naiveH2Conn) Close() error { 586 return common.Close( 587 c.reader, 588 c.writer, 589 ) 590 } 591 592 func (c *naiveH2Conn) LocalAddr() net.Addr { 593 return nil 594 } 595 596 func (c *naiveH2Conn) RemoteAddr() net.Addr { 597 return c.rAddr 598 } 599 600 func (c *naiveH2Conn) SetDeadline(t time.Time) error { 601 return os.ErrInvalid 602 } 603 604 func (c *naiveH2Conn) SetReadDeadline(t time.Time) error { 605 return os.ErrInvalid 606 } 607 608 func (c *naiveH2Conn) SetWriteDeadline(t time.Time) error { 609 return os.ErrInvalid 610 } 611 612 func (c *naiveH2Conn) NeedAdditionalReadDeadline() bool { 613 return true 614 } 615 616 func (c *naiveH2Conn) UpstreamReader() any { 617 return c.reader 618 } 619 620 func (c *naiveH2Conn) UpstreamWriter() any { 621 return c.writer 622 } 623 624 func (c *naiveH2Conn) ReaderReplaceable() bool { 625 return c.readPadding == kFirstPaddings 626 } 627 628 func (c *naiveH2Conn) WriterReplaceable() bool { 629 return c.writePadding == kFirstPaddings 630 } 631 632 func wrapHttpError(err error) error { 633 if err == nil { 634 return err 635 } 636 if strings.Contains(err.Error(), "client disconnected") { 637 return net.ErrClosed 638 } 639 if strings.Contains(err.Error(), "body closed by handler") { 640 return net.ErrClosed 641 } 642 if strings.Contains(err.Error(), "canceled with error code 268") { 643 return io.EOF 644 } 645 return err 646 }