github.com/sagernet/sing-box@v1.9.0-rc.20/inbound/naive.go (about)

     1  package inbound
     2  
     3  import (
     4  	"context"
     5  	"encoding/binary"
     6  	"io"
     7  	"math/rand"
     8  	"net"
     9  	"net/http"
    10  	"os"
    11  	"strings"
    12  	"time"
    13  
    14  	"github.com/sagernet/sing-box/adapter"
    15  	"github.com/sagernet/sing-box/common/tls"
    16  	"github.com/sagernet/sing-box/common/uot"
    17  	C "github.com/sagernet/sing-box/constant"
    18  	"github.com/sagernet/sing-box/log"
    19  	"github.com/sagernet/sing-box/option"
    20  	"github.com/sagernet/sing/common"
    21  	"github.com/sagernet/sing/common/auth"
    22  	"github.com/sagernet/sing/common/buf"
    23  	E "github.com/sagernet/sing/common/exceptions"
    24  	M "github.com/sagernet/sing/common/metadata"
    25  	N "github.com/sagernet/sing/common/network"
    26  	"github.com/sagernet/sing/common/rw"
    27  	sHttp "github.com/sagernet/sing/protocol/http"
    28  )
    29  
// Compile-time check that *Naive satisfies adapter.Inbound.
var _ adapter.Inbound = (*Naive)(nil)

// Naive is the naive proxy inbound. It accepts authenticated CONNECT tunnels
// over an HTTP server on TCP and, when UDP is enabled, over an HTTP/3 listener.
type Naive struct {
	myInboundAdapter
	authenticator *auth.Authenticator // checks Proxy-Authorization Basic credentials
	tlsConfig     tls.ServerConfig    // nil when no TLS options were configured
	httpServer    *http.Server        // created by Start when TCP is enabled
	h3Server      any                 // NOTE(review): presumably set by configureHTTP3Listener (defined elsewhere) — confirm
}
    39  
    40  func NewNaive(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.NaiveInboundOptions) (*Naive, error) {
    41  	inbound := &Naive{
    42  		myInboundAdapter: myInboundAdapter{
    43  			protocol:      C.TypeNaive,
    44  			network:       options.Network.Build(),
    45  			ctx:           ctx,
    46  			router:        uot.NewRouter(router, logger),
    47  			logger:        logger,
    48  			tag:           tag,
    49  			listenOptions: options.ListenOptions,
    50  		},
    51  		authenticator: auth.NewAuthenticator(options.Users),
    52  	}
    53  	if common.Contains(inbound.network, N.NetworkUDP) {
    54  		if options.TLS == nil || !options.TLS.Enabled {
    55  			return nil, E.New("TLS is required for QUIC server")
    56  		}
    57  	}
    58  	if len(options.Users) == 0 {
    59  		return nil, E.New("missing users")
    60  	}
    61  	if options.TLS != nil {
    62  		tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
    63  		if err != nil {
    64  			return nil, err
    65  		}
    66  		inbound.tlsConfig = tlsConfig
    67  	}
    68  	return inbound, nil
    69  }
    70  
    71  func (n *Naive) Start() error {
    72  	var tlsConfig *tls.STDConfig
    73  	if n.tlsConfig != nil {
    74  		err := n.tlsConfig.Start()
    75  		if err != nil {
    76  			return E.Cause(err, "create TLS config")
    77  		}
    78  		tlsConfig, err = n.tlsConfig.Config()
    79  		if err != nil {
    80  			return err
    81  		}
    82  	}
    83  
    84  	if common.Contains(n.network, N.NetworkTCP) {
    85  		tcpListener, err := n.ListenTCP()
    86  		if err != nil {
    87  			return err
    88  		}
    89  		n.httpServer = &http.Server{
    90  			Handler:   n,
    91  			TLSConfig: tlsConfig,
    92  			BaseContext: func(listener net.Listener) context.Context {
    93  				return n.ctx
    94  			},
    95  		}
    96  		go func() {
    97  			var sErr error
    98  			if tlsConfig != nil {
    99  				sErr = n.httpServer.ServeTLS(tcpListener, "", "")
   100  			} else {
   101  				sErr = n.httpServer.Serve(tcpListener)
   102  			}
   103  			if sErr != nil && !E.IsClosedOrCanceled(sErr) {
   104  				n.logger.Error("http server serve error: ", sErr)
   105  			}
   106  		}()
   107  	}
   108  
   109  	if common.Contains(n.network, N.NetworkUDP) {
   110  		err := n.configureHTTP3Listener()
   111  		if !C.WithQUIC && len(n.network) > 1 {
   112  			n.logger.Warn(E.Cause(err, "naive http3 disabled"))
   113  		} else if err != nil {
   114  			return err
   115  		}
   116  	}
   117  
   118  	return nil
   119  }
   120  
   121  func (n *Naive) Close() error {
   122  	return common.Close(
   123  		&n.myInboundAdapter,
   124  		common.PtrOrNil(n.httpServer),
   125  		n.h3Server,
   126  		n.tlsConfig,
   127  	)
   128  }
   129  
   130  func (n *Naive) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
   131  	ctx := log.ContextWithNewID(request.Context())
   132  	if request.Method != "CONNECT" {
   133  		rejectHTTP(writer, http.StatusBadRequest)
   134  		n.badRequest(ctx, request, E.New("not CONNECT request"))
   135  		return
   136  	} else if request.Header.Get("Padding") == "" {
   137  		rejectHTTP(writer, http.StatusBadRequest)
   138  		n.badRequest(ctx, request, E.New("missing naive padding"))
   139  		return
   140  	}
   141  	userName, password, authOk := sHttp.ParseBasicAuth(request.Header.Get("Proxy-Authorization"))
   142  	if authOk {
   143  		authOk = n.authenticator.Verify(userName, password)
   144  	}
   145  	if !authOk {
   146  		rejectHTTP(writer, http.StatusProxyAuthRequired)
   147  		n.badRequest(ctx, request, E.New("authorization failed"))
   148  		return
   149  	}
   150  	writer.Header().Set("Padding", generateNaivePaddingHeader())
   151  	writer.WriteHeader(http.StatusOK)
   152  	writer.(http.Flusher).Flush()
   153  
   154  	hostPort := request.URL.Host
   155  	if hostPort == "" {
   156  		hostPort = request.Host
   157  	}
   158  	source := sHttp.SourceAddress(request)
   159  	destination := M.ParseSocksaddr(hostPort)
   160  
   161  	if hijacker, isHijacker := writer.(http.Hijacker); isHijacker {
   162  		conn, _, err := hijacker.Hijack()
   163  		if err != nil {
   164  			n.badRequest(ctx, request, E.New("hijack failed"))
   165  			return
   166  		}
   167  		n.newConnection(ctx, &naiveH1Conn{Conn: conn}, userName, source, destination)
   168  	} else {
   169  		n.newConnection(ctx, &naiveH2Conn{reader: request.Body, writer: writer, flusher: writer.(http.Flusher)}, userName, source, destination)
   170  	}
   171  }
   172  
   173  func (n *Naive) newConnection(ctx context.Context, conn net.Conn, userName string, source, destination M.Socksaddr) {
   174  	if userName != "" {
   175  		n.logger.InfoContext(ctx, "[", userName, "] inbound connection from ", source)
   176  		n.logger.InfoContext(ctx, "[", userName, "] inbound connection to ", destination)
   177  	} else {
   178  		n.logger.InfoContext(ctx, "inbound connection from ", source)
   179  		n.logger.InfoContext(ctx, "inbound connection to ", destination)
   180  	}
   181  	hErr := n.router.RouteConnection(ctx, conn, n.createMetadata(conn, adapter.InboundContext{
   182  		Source:      source,
   183  		Destination: destination,
   184  		User:        userName,
   185  	}))
   186  	if hErr != nil {
   187  		conn.Close()
   188  		n.NewError(ctx, E.Cause(hErr, "process connection from ", source))
   189  	}
   190  }
   191  
   192  func (n *Naive) badRequest(ctx context.Context, request *http.Request, err error) {
   193  	n.NewError(ctx, E.Cause(err, "process connection from ", request.RemoteAddr))
   194  }
   195  
   196  func rejectHTTP(writer http.ResponseWriter, statusCode int) {
   197  	hijacker, ok := writer.(http.Hijacker)
   198  	if !ok {
   199  		writer.WriteHeader(statusCode)
   200  		return
   201  	}
   202  	conn, _, err := hijacker.Hijack()
   203  	if err != nil {
   204  		writer.WriteHeader(statusCode)
   205  		return
   206  	}
   207  	if tcpConn, isTCP := common.Cast[*net.TCPConn](conn); isTCP {
   208  		tcpConn.SetLinger(0)
   209  	}
   210  	conn.Close()
   211  }
   212  
   213  func generateNaivePaddingHeader() string {
   214  	paddingLen := rand.Intn(32) + 30
   215  	padding := make([]byte, paddingLen)
   216  	bits := rand.Uint64()
   217  	for i := 0; i < 16; i++ {
   218  		// Codes that won't be Huffman coded.
   219  		padding[i] = "!#$()+<>?@[]^`{}"[bits&15]
   220  		bits >>= 4
   221  	}
   222  	for i := 16; i < paddingLen; i++ {
   223  		padding[i] = '~'
   224  	}
   225  	return string(padding)
   226  }
   227  
// kFirstPaddings is the number of frames in each direction that use the
// padding framing. After that many frames, naiveH1Conn and naiveH2Conn pass
// data through unchanged.
const kFirstPaddings = 8
   229  
// naiveH1Conn wraps a hijacked HTTP/1.x connection and applies naive padding
// framing to the first kFirstPaddings frames read and written.
type naiveH1Conn struct {
	net.Conn
	readPadding      int // padded frames received so far
	writePadding     int // padded frames sent so far
	readRemaining    int // payload bytes of the current received frame not yet returned
	paddingRemaining int // padding bytes of the current received frame still to skip
}
   237  
   238  func (c *naiveH1Conn) Read(p []byte) (n int, err error) {
   239  	n, err = c.read(p)
   240  	return n, wrapHttpError(err)
   241  }
   242  
   243  func (c *naiveH1Conn) read(p []byte) (n int, err error) {
   244  	if c.readRemaining > 0 {
   245  		if len(p) > c.readRemaining {
   246  			p = p[:c.readRemaining]
   247  		}
   248  		n, err = c.Conn.Read(p)
   249  		if err != nil {
   250  			return
   251  		}
   252  		c.readRemaining -= n
   253  		return
   254  	}
   255  	if c.paddingRemaining > 0 {
   256  		err = rw.SkipN(c.Conn, c.paddingRemaining)
   257  		if err != nil {
   258  			return
   259  		}
   260  		c.paddingRemaining = 0
   261  	}
   262  	if c.readPadding < kFirstPaddings {
   263  		var paddingHdr []byte
   264  		if len(p) >= 3 {
   265  			paddingHdr = p[:3]
   266  		} else {
   267  			paddingHdr = make([]byte, 3)
   268  		}
   269  		_, err = io.ReadFull(c.Conn, paddingHdr)
   270  		if err != nil {
   271  			return
   272  		}
   273  		originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2]))
   274  		paddingSize := int(paddingHdr[2])
   275  		if len(p) > originalDataSize {
   276  			p = p[:originalDataSize]
   277  		}
   278  		n, err = c.Conn.Read(p)
   279  		if err != nil {
   280  			return
   281  		}
   282  		c.readPadding++
   283  		c.readRemaining = originalDataSize - n
   284  		c.paddingRemaining = paddingSize
   285  		return
   286  	}
   287  	return c.Conn.Read(p)
   288  }
   289  
   290  func (c *naiveH1Conn) Write(p []byte) (n int, err error) {
   291  	for pLen := len(p); pLen > 0; {
   292  		var data []byte
   293  		if pLen > 65535 {
   294  			data = p[:65535]
   295  			p = p[65535:]
   296  			pLen -= 65535
   297  		} else {
   298  			data = p
   299  			pLen = 0
   300  		}
   301  		var writeN int
   302  		writeN, err = c.write(data)
   303  		n += writeN
   304  		if err != nil {
   305  			break
   306  		}
   307  	}
   308  	return n, wrapHttpError(err)
   309  }
   310  
   311  func (c *naiveH1Conn) write(p []byte) (n int, err error) {
   312  	if c.writePadding < kFirstPaddings {
   313  		paddingSize := rand.Intn(256)
   314  
   315  		buffer := buf.NewSize(3 + len(p) + paddingSize)
   316  		defer buffer.Release()
   317  		header := buffer.Extend(3)
   318  		binary.BigEndian.PutUint16(header, uint16(len(p)))
   319  		header[2] = byte(paddingSize)
   320  
   321  		common.Must1(buffer.Write(p))
   322  		_, err = c.Conn.Write(buffer.Bytes())
   323  		if err == nil {
   324  			n = len(p)
   325  		}
   326  		c.writePadding++
   327  		return
   328  	}
   329  	return c.Conn.Write(p)
   330  }
   331  
   332  func (c *naiveH1Conn) FrontHeadroom() int {
   333  	if c.writePadding < kFirstPaddings {
   334  		return 3
   335  	}
   336  	return 0
   337  }
   338  
   339  func (c *naiveH1Conn) RearHeadroom() int {
   340  	if c.writePadding < kFirstPaddings {
   341  		return 255
   342  	}
   343  	return 0
   344  }
   345  
   346  func (c *naiveH1Conn) WriterMTU() int {
   347  	if c.writePadding < kFirstPaddings {
   348  		return 65535
   349  	}
   350  	return 0
   351  }
   352  
   353  func (c *naiveH1Conn) WriteBuffer(buffer *buf.Buffer) error {
   354  	defer buffer.Release()
   355  	if c.writePadding < kFirstPaddings {
   356  		bufferLen := buffer.Len()
   357  		if bufferLen > 65535 {
   358  			return common.Error(c.Write(buffer.Bytes()))
   359  		}
   360  		paddingSize := rand.Intn(256)
   361  		header := buffer.ExtendHeader(3)
   362  		binary.BigEndian.PutUint16(header, uint16(bufferLen))
   363  		header[2] = byte(paddingSize)
   364  		buffer.Extend(paddingSize)
   365  		c.writePadding++
   366  	}
   367  	return wrapHttpError(common.Error(c.Conn.Write(buffer.Bytes())))
   368  }
   369  
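// FIXME: the WriteTo/ReadFrom fast paths for naiveH1Conn below are disabled (commented out); the reason is not recorded here.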
   371  /*func (c *naiveH1Conn) WriteTo(w io.Writer) (n int64, err error) {
   372  	if c.readPadding < kFirstPaddings {
   373  		n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding)
   374  	} else {
   375  		n, err = bufio.Copy(w, c.Conn)
   376  	}
   377  	return n, wrapHttpError(err)
   378  }
   379  
   380  func (c *naiveH1Conn) ReadFrom(r io.Reader) (n int64, err error) {
   381  	if c.writePadding < kFirstPaddings {
   382  		n, err = bufio.ReadFromN(c, r, kFirstPaddings-c.writePadding)
   383  	} else {
   384  		n, err = bufio.Copy(c.Conn, r)
   385  	}
   386  	return n, wrapHttpError(err)
   387  }
   388  */
   389  
   390  func (c *naiveH1Conn) Upstream() any {
   391  	return c.Conn
   392  }
   393  
   394  func (c *naiveH1Conn) ReaderReplaceable() bool {
   395  	return c.readPadding == kFirstPaddings
   396  }
   397  
   398  func (c *naiveH1Conn) WriterReplaceable() bool {
   399  	return c.writePadding == kFirstPaddings
   400  }
   401  
// naiveH2Conn adapts a non-hijackable HTTP stream into a net.Conn. It reads
// from the request body and writes to the response writer, applying naive
// padding framing to the first kFirstPaddings frames in each direction.
type naiveH2Conn struct {
	reader           io.Reader    // request body
	writer           io.Writer    // response writer
	flusher          http.Flusher // flushed after each successful write
	rAddr            net.Addr     // returned by RemoteAddr; nil unless set by the creator
	readPadding      int          // padded frames received so far
	writePadding     int          // padded frames sent so far
	readRemaining    int          // payload bytes of the current received frame not yet returned
	paddingRemaining int          // padding bytes of the current received frame still to skip
}
   412  
   413  func (c *naiveH2Conn) Read(p []byte) (n int, err error) {
   414  	n, err = c.read(p)
   415  	return n, wrapHttpError(err)
   416  }
   417  
   418  func (c *naiveH2Conn) read(p []byte) (n int, err error) {
   419  	if c.readRemaining > 0 {
   420  		if len(p) > c.readRemaining {
   421  			p = p[:c.readRemaining]
   422  		}
   423  		n, err = c.reader.Read(p)
   424  		if err != nil {
   425  			return
   426  		}
   427  		c.readRemaining -= n
   428  		return
   429  	}
   430  	if c.paddingRemaining > 0 {
   431  		err = rw.SkipN(c.reader, c.paddingRemaining)
   432  		if err != nil {
   433  			return
   434  		}
   435  		c.paddingRemaining = 0
   436  	}
   437  	if c.readPadding < kFirstPaddings {
   438  		var paddingHdr []byte
   439  		if len(p) >= 3 {
   440  			paddingHdr = p[:3]
   441  		} else {
   442  			paddingHdr = make([]byte, 3)
   443  		}
   444  		_, err = io.ReadFull(c.reader, paddingHdr)
   445  		if err != nil {
   446  			return
   447  		}
   448  		originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2]))
   449  		paddingSize := int(paddingHdr[2])
   450  		if len(p) > originalDataSize {
   451  			p = p[:originalDataSize]
   452  		}
   453  		n, err = c.reader.Read(p)
   454  		if err != nil {
   455  			return
   456  		}
   457  		c.readPadding++
   458  		c.readRemaining = originalDataSize - n
   459  		c.paddingRemaining = paddingSize
   460  		return
   461  	}
   462  	return c.reader.Read(p)
   463  }
   464  
   465  func (c *naiveH2Conn) Write(p []byte) (n int, err error) {
   466  	for pLen := len(p); pLen > 0; {
   467  		var data []byte
   468  		if pLen > 65535 {
   469  			data = p[:65535]
   470  			p = p[65535:]
   471  			pLen -= 65535
   472  		} else {
   473  			data = p
   474  			pLen = 0
   475  		}
   476  		var writeN int
   477  		writeN, err = c.write(data)
   478  		n += writeN
   479  		if err != nil {
   480  			break
   481  		}
   482  	}
   483  	if err == nil {
   484  		c.flusher.Flush()
   485  	}
   486  	return n, wrapHttpError(err)
   487  }
   488  
   489  func (c *naiveH2Conn) write(p []byte) (n int, err error) {
   490  	if c.writePadding < kFirstPaddings {
   491  		paddingSize := rand.Intn(256)
   492  
   493  		buffer := buf.NewSize(3 + len(p) + paddingSize)
   494  		defer buffer.Release()
   495  		header := buffer.Extend(3)
   496  		binary.BigEndian.PutUint16(header, uint16(len(p)))
   497  		header[2] = byte(paddingSize)
   498  
   499  		common.Must1(buffer.Write(p))
   500  		_, err = c.writer.Write(buffer.Bytes())
   501  		if err == nil {
   502  			n = len(p)
   503  		}
   504  		c.writePadding++
   505  		return
   506  	}
   507  	return c.writer.Write(p)
   508  }
   509  
   510  func (c *naiveH2Conn) FrontHeadroom() int {
   511  	if c.writePadding < kFirstPaddings {
   512  		return 3
   513  	}
   514  	return 0
   515  }
   516  
   517  func (c *naiveH2Conn) RearHeadroom() int {
   518  	if c.writePadding < kFirstPaddings {
   519  		return 255
   520  	}
   521  	return 0
   522  }
   523  
   524  func (c *naiveH2Conn) WriterMTU() int {
   525  	if c.writePadding < kFirstPaddings {
   526  		return 65535
   527  	}
   528  	return 0
   529  }
   530  
   531  func (c *naiveH2Conn) WriteBuffer(buffer *buf.Buffer) error {
   532  	defer buffer.Release()
   533  	if c.writePadding < kFirstPaddings {
   534  		bufferLen := buffer.Len()
   535  		if bufferLen > 65535 {
   536  			return common.Error(c.Write(buffer.Bytes()))
   537  		}
   538  		paddingSize := rand.Intn(256)
   539  		header := buffer.ExtendHeader(3)
   540  		binary.BigEndian.PutUint16(header, uint16(bufferLen))
   541  		header[2] = byte(paddingSize)
   542  		buffer.Extend(paddingSize)
   543  		c.writePadding++
   544  	}
   545  	err := common.Error(c.writer.Write(buffer.Bytes()))
   546  	if err == nil {
   547  		c.flusher.Flush()
   548  	}
   549  	return wrapHttpError(err)
   550  }
   551  
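// FIXME: the WriteTo/ReadFrom fast paths for naiveH2Conn below are disabled (commented out); the reason is not recorded here.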
   553  /*func (c *naiveH2Conn) WriteTo(w io.Writer) (n int64, err error) {
   554  	if c.readPadding < kFirstPaddings {
   555  		n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding)
   556  	} else {
   557  		n, err = bufio.Copy(w, c.reader)
   558  	}
   559  	return n, wrapHttpError(err)
   560  }
   561  
   562  func (c *naiveH2Conn) ReadFrom(r io.Reader) (n int64, err error) {
   563  	if c.writePadding < kFirstPaddings {
   564  		n, err = bufio.ReadFromN(c, r, kFirstPaddings-c.writePadding)
   565  	} else {
   566  		n, err = bufio.Copy(c.writer, r)
   567  	}
   568  	return n, wrapHttpError(err)
   569  }*/
   570  
   571  func (c *naiveH2Conn) Close() error {
   572  	return common.Close(
   573  		c.reader,
   574  		c.writer,
   575  	)
   576  }
   577  
   578  func (c *naiveH2Conn) LocalAddr() net.Addr {
   579  	return M.Socksaddr{}
   580  }
   581  
   582  func (c *naiveH2Conn) RemoteAddr() net.Addr {
   583  	return c.rAddr
   584  }
   585  
   586  func (c *naiveH2Conn) SetDeadline(t time.Time) error {
   587  	return os.ErrInvalid
   588  }
   589  
   590  func (c *naiveH2Conn) SetReadDeadline(t time.Time) error {
   591  	return os.ErrInvalid
   592  }
   593  
   594  func (c *naiveH2Conn) SetWriteDeadline(t time.Time) error {
   595  	return os.ErrInvalid
   596  }
   597  
   598  func (c *naiveH2Conn) NeedAdditionalReadDeadline() bool {
   599  	return true
   600  }
   601  
   602  func (c *naiveH2Conn) UpstreamReader() any {
   603  	return c.reader
   604  }
   605  
   606  func (c *naiveH2Conn) UpstreamWriter() any {
   607  	return c.writer
   608  }
   609  
   610  func (c *naiveH2Conn) ReaderReplaceable() bool {
   611  	return c.readPadding == kFirstPaddings
   612  }
   613  
   614  func (c *naiveH2Conn) WriterReplaceable() bool {
   615  	return c.writePadding == kFirstPaddings
   616  }
   617  
   618  func wrapHttpError(err error) error {
   619  	if err == nil {
   620  		return err
   621  	}
   622  	if strings.Contains(err.Error(), "client disconnected") {
   623  		return net.ErrClosed
   624  	}
   625  	if strings.Contains(err.Error(), "body closed by handler") {
   626  		return net.ErrClosed
   627  	}
   628  	if strings.Contains(err.Error(), "canceled with error code 268") {
   629  		return io.EOF
   630  	}
   631  	return err
   632  }