github.com/sagernet/sing-box@v1.2.7/inbound/naive.go (about)

     1  package inbound
     2  
     3  import (
     4  	"context"
     5  	"encoding/base64"
     6  	"encoding/binary"
     7  	"io"
     8  	"math/rand"
     9  	"net"
    10  	"net/http"
    11  	"os"
    12  	"strings"
    13  	"time"
    14  
    15  	"github.com/sagernet/sing-box/adapter"
    16  	"github.com/sagernet/sing-box/common/tls"
    17  	C "github.com/sagernet/sing-box/constant"
    18  	"github.com/sagernet/sing-box/include"
    19  	"github.com/sagernet/sing-box/log"
    20  	"github.com/sagernet/sing-box/option"
    21  	"github.com/sagernet/sing/common"
    22  	"github.com/sagernet/sing/common/auth"
    23  	"github.com/sagernet/sing/common/buf"
    24  	E "github.com/sagernet/sing/common/exceptions"
    25  	M "github.com/sagernet/sing/common/metadata"
    26  	N "github.com/sagernet/sing/common/network"
    27  	"github.com/sagernet/sing/common/rw"
    28  	sHttp "github.com/sagernet/sing/protocol/http"
    29  )
    30  
    31  var _ adapter.Inbound = (*Naive)(nil)
    32  
    33  type Naive struct {
    34  	myInboundAdapter
    35  	authenticator auth.Authenticator
    36  	tlsConfig     tls.ServerConfig
    37  	httpServer    *http.Server
    38  	h3Server      any
    39  }
    40  
    41  func NewNaive(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.NaiveInboundOptions) (*Naive, error) {
    42  	inbound := &Naive{
    43  		myInboundAdapter: myInboundAdapter{
    44  			protocol:      C.TypeNaive,
    45  			network:       options.Network.Build(),
    46  			ctx:           ctx,
    47  			router:        router,
    48  			logger:        logger,
    49  			tag:           tag,
    50  			listenOptions: options.ListenOptions,
    51  		},
    52  		authenticator: auth.NewAuthenticator(options.Users),
    53  	}
    54  	if common.Contains(inbound.network, N.NetworkUDP) {
    55  		if options.TLS == nil || !options.TLS.Enabled {
    56  			return nil, E.New("TLS is required for QUIC server")
    57  		}
    58  	}
    59  	if len(options.Users) == 0 {
    60  		return nil, E.New("missing users")
    61  	}
    62  	if options.TLS != nil {
    63  		tlsConfig, err := tls.NewServer(ctx, router, logger, common.PtrValueOrDefault(options.TLS))
    64  		if err != nil {
    65  			return nil, err
    66  		}
    67  		inbound.tlsConfig = tlsConfig
    68  	}
    69  	return inbound, nil
    70  }
    71  
    72  func (n *Naive) Start() error {
    73  	var tlsConfig *tls.STDConfig
    74  	if n.tlsConfig != nil {
    75  		err := n.tlsConfig.Start()
    76  		if err != nil {
    77  			return E.Cause(err, "create TLS config")
    78  		}
    79  		tlsConfig, err = n.tlsConfig.Config()
    80  		if err != nil {
    81  			return err
    82  		}
    83  	}
    84  
    85  	if common.Contains(n.network, N.NetworkTCP) {
    86  		tcpListener, err := n.ListenTCP()
    87  		if err != nil {
    88  			return err
    89  		}
    90  		n.httpServer = &http.Server{
    91  			Handler:   n,
    92  			TLSConfig: tlsConfig,
    93  			BaseContext: func(listener net.Listener) context.Context {
    94  				return n.ctx
    95  			},
    96  		}
    97  		go func() {
    98  			var sErr error
    99  			if tlsConfig != nil {
   100  				sErr = n.httpServer.ServeTLS(tcpListener, "", "")
   101  			} else {
   102  				sErr = n.httpServer.Serve(tcpListener)
   103  			}
   104  			if sErr != nil && !E.IsClosedOrCanceled(sErr) {
   105  				n.logger.Error("http server serve error: ", sErr)
   106  			}
   107  		}()
   108  	}
   109  
   110  	if common.Contains(n.network, N.NetworkUDP) {
   111  		err := n.configureHTTP3Listener()
   112  		if !include.WithQUIC && len(n.network) > 1 {
   113  			log.Warn(E.Cause(err, "naive http3 disabled"))
   114  		} else if err != nil {
   115  			return err
   116  		}
   117  	}
   118  
   119  	return nil
   120  }
   121  
   122  func (n *Naive) Close() error {
   123  	return common.Close(
   124  		&n.myInboundAdapter,
   125  		common.PtrOrNil(n.httpServer),
   126  		n.h3Server,
   127  		n.tlsConfig,
   128  	)
   129  }
   130  
   131  func (n *Naive) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
   132  	ctx := log.ContextWithNewID(request.Context())
   133  	if request.Method != "CONNECT" {
   134  		rejectHTTP(writer, http.StatusBadRequest)
   135  		n.badRequest(ctx, request, E.New("not CONNECT request"))
   136  		return
   137  	} else if request.Header.Get("Padding") == "" {
   138  		rejectHTTP(writer, http.StatusBadRequest)
   139  		n.badRequest(ctx, request, E.New("missing naive padding"))
   140  		return
   141  	}
   142  	var authOk bool
   143  	var userName string
   144  	authorization := request.Header.Get("Proxy-Authorization")
   145  	if strings.HasPrefix(authorization, "BASIC ") || strings.HasPrefix(authorization, "Basic ") {
   146  		userPassword, _ := base64.URLEncoding.DecodeString(authorization[6:])
   147  		userPswdArr := strings.SplitN(string(userPassword), ":", 2)
   148  		userName = userPswdArr[0]
   149  		authOk = n.authenticator.Verify(userPswdArr[0], userPswdArr[1])
   150  	}
   151  	if !authOk {
   152  		rejectHTTP(writer, http.StatusProxyAuthRequired)
   153  		n.badRequest(ctx, request, E.New("authorization failed"))
   154  		return
   155  	}
   156  	writer.Header().Set("Padding", generateNaivePaddingHeader())
   157  	writer.WriteHeader(http.StatusOK)
   158  	writer.(http.Flusher).Flush()
   159  
   160  	hostPort := request.URL.Host
   161  	if hostPort == "" {
   162  		hostPort = request.Host
   163  	}
   164  	source := sHttp.SourceAddress(request)
   165  	destination := M.ParseSocksaddr(hostPort)
   166  
   167  	if hijacker, isHijacker := writer.(http.Hijacker); isHijacker {
   168  		conn, _, err := hijacker.Hijack()
   169  		if err != nil {
   170  			n.badRequest(ctx, request, E.New("hijack failed"))
   171  			return
   172  		}
   173  		n.newConnection(ctx, &naiveH1Conn{Conn: conn}, userName, source, destination)
   174  	} else {
   175  		n.newConnection(ctx, &naiveH2Conn{reader: request.Body, writer: writer, flusher: writer.(http.Flusher)}, userName, source, destination)
   176  	}
   177  }
   178  
   179  func (n *Naive) newConnection(ctx context.Context, conn net.Conn, userName string, source, destination M.Socksaddr) {
   180  	if userName != "" {
   181  		n.logger.InfoContext(ctx, "[", userName, "] inbound connection from ", source)
   182  		n.logger.InfoContext(ctx, "[", userName, "] inbound connection to ", destination)
   183  	} else {
   184  		n.logger.InfoContext(ctx, "inbound connection from ", source)
   185  		n.logger.InfoContext(ctx, "inbound connection to ", destination)
   186  	}
   187  	hErr := n.router.RouteConnection(ctx, conn, n.createMetadata(conn, adapter.InboundContext{
   188  		Source:      source,
   189  		Destination: destination,
   190  		User:        userName,
   191  	}))
   192  	if hErr != nil {
   193  		conn.Close()
   194  		n.NewError(ctx, E.Cause(hErr, "process connection from ", source))
   195  	}
   196  }
   197  
   198  func (n *Naive) badRequest(ctx context.Context, request *http.Request, err error) {
   199  	n.NewError(ctx, E.Cause(err, "process connection from ", request.RemoteAddr))
   200  }
   201  
   202  func rejectHTTP(writer http.ResponseWriter, statusCode int) {
   203  	hijacker, ok := writer.(http.Hijacker)
   204  	if !ok {
   205  		writer.WriteHeader(statusCode)
   206  		return
   207  	}
   208  	conn, _, err := hijacker.Hijack()
   209  	if err != nil {
   210  		writer.WriteHeader(statusCode)
   211  		return
   212  	}
   213  	if tcpConn, isTCP := common.Cast[*net.TCPConn](conn); isTCP {
   214  		tcpConn.SetLinger(0)
   215  	}
   216  	conn.Close()
   217  }
   218  
   219  func generateNaivePaddingHeader() string {
   220  	paddingLen := rand.Intn(32) + 30
   221  	padding := make([]byte, paddingLen)
   222  	bits := rand.Uint64()
   223  	for i := 0; i < 16; i++ {
   224  		// Codes that won't be Huffman coded.
   225  		padding[i] = "!#$()+<>?@[]^`{}"[bits&15]
   226  		bits >>= 4
   227  	}
   228  	for i := 16; i < paddingLen; i++ {
   229  		padding[i] = '~'
   230  	}
   231  	return string(padding)
   232  }
   233  
   234  const kFirstPaddings = 8
   235  
   236  type naiveH1Conn struct {
   237  	net.Conn
   238  	readPadding      int
   239  	writePadding     int
   240  	readRemaining    int
   241  	paddingRemaining int
   242  }
   243  
   244  func (c *naiveH1Conn) Read(p []byte) (n int, err error) {
   245  	n, err = c.read(p)
   246  	return n, wrapHttpError(err)
   247  }
   248  
   249  func (c *naiveH1Conn) read(p []byte) (n int, err error) {
   250  	if c.readRemaining > 0 {
   251  		if len(p) > c.readRemaining {
   252  			p = p[:c.readRemaining]
   253  		}
   254  		n, err = c.Conn.Read(p)
   255  		if err != nil {
   256  			return
   257  		}
   258  		c.readRemaining -= n
   259  		return
   260  	}
   261  	if c.paddingRemaining > 0 {
   262  		err = rw.SkipN(c.Conn, c.paddingRemaining)
   263  		if err != nil {
   264  			return
   265  		}
   266  		c.paddingRemaining = 0
   267  	}
   268  	if c.readPadding < kFirstPaddings {
   269  		var paddingHdr []byte
   270  		if len(p) >= 3 {
   271  			paddingHdr = p[:3]
   272  		} else {
   273  			_paddingHdr := make([]byte, 3)
   274  			defer common.KeepAlive(_paddingHdr)
   275  			paddingHdr = common.Dup(_paddingHdr)
   276  		}
   277  		_, err = io.ReadFull(c.Conn, paddingHdr)
   278  		if err != nil {
   279  			return
   280  		}
   281  		originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2]))
   282  		paddingSize := int(paddingHdr[2])
   283  		if len(p) > originalDataSize {
   284  			p = p[:originalDataSize]
   285  		}
   286  		n, err = c.Conn.Read(p)
   287  		if err != nil {
   288  			return
   289  		}
   290  		c.readPadding++
   291  		c.readRemaining = originalDataSize - n
   292  		c.paddingRemaining = paddingSize
   293  		return
   294  	}
   295  	return c.Conn.Read(p)
   296  }
   297  
   298  func (c *naiveH1Conn) Write(p []byte) (n int, err error) {
   299  	for pLen := len(p); pLen > 0; {
   300  		var data []byte
   301  		if pLen > 65535 {
   302  			data = p[:65535]
   303  			p = p[65535:]
   304  			pLen -= 65535
   305  		} else {
   306  			data = p
   307  			pLen = 0
   308  		}
   309  		var writeN int
   310  		writeN, err = c.write(data)
   311  		n += writeN
   312  		if err != nil {
   313  			break
   314  		}
   315  	}
   316  	return n, wrapHttpError(err)
   317  }
   318  
   319  func (c *naiveH1Conn) write(p []byte) (n int, err error) {
   320  	if c.writePadding < kFirstPaddings {
   321  		paddingSize := rand.Intn(256)
   322  
   323  		_buffer := buf.StackNewSize(3 + len(p) + paddingSize)
   324  		defer common.KeepAlive(_buffer)
   325  		buffer := common.Dup(_buffer)
   326  		defer buffer.Release()
   327  		header := buffer.Extend(3)
   328  		binary.BigEndian.PutUint16(header, uint16(len(p)))
   329  		header[2] = byte(paddingSize)
   330  
   331  		common.Must1(buffer.Write(p))
   332  		_, err = c.Conn.Write(buffer.Bytes())
   333  		if err == nil {
   334  			n = len(p)
   335  		}
   336  		c.writePadding++
   337  		return
   338  	}
   339  	return c.Conn.Write(p)
   340  }
   341  
   342  func (c *naiveH1Conn) FrontHeadroom() int {
   343  	if c.writePadding < kFirstPaddings {
   344  		return 3
   345  	}
   346  	return 0
   347  }
   348  
   349  func (c *naiveH1Conn) RearHeadroom() int {
   350  	if c.writePadding < kFirstPaddings {
   351  		return 255
   352  	}
   353  	return 0
   354  }
   355  
   356  func (c *naiveH1Conn) WriterMTU() int {
   357  	if c.writePadding < kFirstPaddings {
   358  		return 65535
   359  	}
   360  	return 0
   361  }
   362  
   363  func (c *naiveH1Conn) WriteBuffer(buffer *buf.Buffer) error {
   364  	defer buffer.Release()
   365  	if c.writePadding < kFirstPaddings {
   366  		bufferLen := buffer.Len()
   367  		if bufferLen > 65535 {
   368  			return common.Error(c.Write(buffer.Bytes()))
   369  		}
   370  		paddingSize := rand.Intn(256)
   371  		header := buffer.ExtendHeader(3)
   372  		binary.BigEndian.PutUint16(header, uint16(bufferLen))
   373  		header[2] = byte(paddingSize)
   374  		buffer.Extend(paddingSize)
   375  		c.writePadding++
   376  	}
   377  	return wrapHttpError(common.Error(c.Conn.Write(buffer.Bytes())))
   378  }
   379  
   380  // FIXME
   381  /*func (c *naiveH1Conn) WriteTo(w io.Writer) (n int64, err error) {
   382  	if c.readPadding < kFirstPaddings {
   383  		n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding)
   384  	} else {
   385  		n, err = bufio.Copy(w, c.Conn)
   386  	}
   387  	return n, wrapHttpError(err)
   388  }
   389  
   390  func (c *naiveH1Conn) ReadFrom(r io.Reader) (n int64, err error) {
   391  	if c.writePadding < kFirstPaddings {
   392  		n, err = bufio.ReadFromN(c, r, kFirstPaddings-c.writePadding)
   393  	} else {
   394  		n, err = bufio.Copy(c.Conn, r)
   395  	}
   396  	return n, wrapHttpError(err)
   397  }
   398  */
   399  
   400  func (c *naiveH1Conn) Upstream() any {
   401  	return c.Conn
   402  }
   403  
   404  func (c *naiveH1Conn) ReaderReplaceable() bool {
   405  	return c.readPadding == kFirstPaddings
   406  }
   407  
   408  func (c *naiveH1Conn) WriterReplaceable() bool {
   409  	return c.writePadding == kFirstPaddings
   410  }
   411  
   412  type naiveH2Conn struct {
   413  	reader           io.Reader
   414  	writer           io.Writer
   415  	flusher          http.Flusher
   416  	rAddr            net.Addr
   417  	readPadding      int
   418  	writePadding     int
   419  	readRemaining    int
   420  	paddingRemaining int
   421  }
   422  
   423  func (c *naiveH2Conn) Read(p []byte) (n int, err error) {
   424  	n, err = c.read(p)
   425  	return n, wrapHttpError(err)
   426  }
   427  
   428  func (c *naiveH2Conn) read(p []byte) (n int, err error) {
   429  	if c.readRemaining > 0 {
   430  		if len(p) > c.readRemaining {
   431  			p = p[:c.readRemaining]
   432  		}
   433  		n, err = c.reader.Read(p)
   434  		if err != nil {
   435  			return
   436  		}
   437  		c.readRemaining -= n
   438  		return
   439  	}
   440  	if c.paddingRemaining > 0 {
   441  		err = rw.SkipN(c.reader, c.paddingRemaining)
   442  		if err != nil {
   443  			return
   444  		}
   445  		c.paddingRemaining = 0
   446  	}
   447  	if c.readPadding < kFirstPaddings {
   448  		var paddingHdr []byte
   449  		if len(p) >= 3 {
   450  			paddingHdr = p[:3]
   451  		} else {
   452  			_paddingHdr := make([]byte, 3)
   453  			defer common.KeepAlive(_paddingHdr)
   454  			paddingHdr = common.Dup(_paddingHdr)
   455  		}
   456  		_, err = io.ReadFull(c.reader, paddingHdr)
   457  		if err != nil {
   458  			return
   459  		}
   460  		originalDataSize := int(binary.BigEndian.Uint16(paddingHdr[:2]))
   461  		paddingSize := int(paddingHdr[2])
   462  		if len(p) > originalDataSize {
   463  			p = p[:originalDataSize]
   464  		}
   465  		n, err = c.reader.Read(p)
   466  		if err != nil {
   467  			return
   468  		}
   469  		c.readPadding++
   470  		c.readRemaining = originalDataSize - n
   471  		c.paddingRemaining = paddingSize
   472  		return
   473  	}
   474  	return c.reader.Read(p)
   475  }
   476  
   477  func (c *naiveH2Conn) Write(p []byte) (n int, err error) {
   478  	for pLen := len(p); pLen > 0; {
   479  		var data []byte
   480  		if pLen > 65535 {
   481  			data = p[:65535]
   482  			p = p[65535:]
   483  			pLen -= 65535
   484  		} else {
   485  			data = p
   486  			pLen = 0
   487  		}
   488  		var writeN int
   489  		writeN, err = c.write(data)
   490  		n += writeN
   491  		if err != nil {
   492  			break
   493  		}
   494  	}
   495  	if err == nil {
   496  		c.flusher.Flush()
   497  	}
   498  	return n, wrapHttpError(err)
   499  }
   500  
   501  func (c *naiveH2Conn) write(p []byte) (n int, err error) {
   502  	if c.writePadding < kFirstPaddings {
   503  		paddingSize := rand.Intn(256)
   504  
   505  		_buffer := buf.StackNewSize(3 + len(p) + paddingSize)
   506  		defer common.KeepAlive(_buffer)
   507  		buffer := common.Dup(_buffer)
   508  		defer buffer.Release()
   509  		header := buffer.Extend(3)
   510  		binary.BigEndian.PutUint16(header, uint16(len(p)))
   511  		header[2] = byte(paddingSize)
   512  
   513  		common.Must1(buffer.Write(p))
   514  		_, err = c.writer.Write(buffer.Bytes())
   515  		if err == nil {
   516  			n = len(p)
   517  		}
   518  		c.writePadding++
   519  		return
   520  	}
   521  	return c.writer.Write(p)
   522  }
   523  
   524  func (c *naiveH2Conn) FrontHeadroom() int {
   525  	if c.writePadding < kFirstPaddings {
   526  		return 3
   527  	}
   528  	return 0
   529  }
   530  
   531  func (c *naiveH2Conn) RearHeadroom() int {
   532  	if c.writePadding < kFirstPaddings {
   533  		return 255
   534  	}
   535  	return 0
   536  }
   537  
   538  func (c *naiveH2Conn) WriterMTU() int {
   539  	if c.writePadding < kFirstPaddings {
   540  		return 65535
   541  	}
   542  	return 0
   543  }
   544  
   545  func (c *naiveH2Conn) WriteBuffer(buffer *buf.Buffer) error {
   546  	defer buffer.Release()
   547  	if c.writePadding < kFirstPaddings {
   548  		bufferLen := buffer.Len()
   549  		if bufferLen > 65535 {
   550  			return common.Error(c.Write(buffer.Bytes()))
   551  		}
   552  		paddingSize := rand.Intn(256)
   553  		header := buffer.ExtendHeader(3)
   554  		binary.BigEndian.PutUint16(header, uint16(bufferLen))
   555  		header[2] = byte(paddingSize)
   556  		buffer.Extend(paddingSize)
   557  		c.writePadding++
   558  	}
   559  	err := common.Error(c.writer.Write(buffer.Bytes()))
   560  	if err == nil {
   561  		c.flusher.Flush()
   562  	}
   563  	return wrapHttpError(err)
   564  }
   565  
   566  // FIXME
   567  /*func (c *naiveH2Conn) WriteTo(w io.Writer) (n int64, err error) {
   568  	if c.readPadding < kFirstPaddings {
   569  		n, err = bufio.WriteToN(c, w, kFirstPaddings-c.readPadding)
   570  	} else {
   571  		n, err = bufio.Copy(w, c.reader)
   572  	}
   573  	return n, wrapHttpError(err)
   574  }
   575  
   576  func (c *naiveH2Conn) ReadFrom(r io.Reader) (n int64, err error) {
   577  	if c.writePadding < kFirstPaddings {
   578  		n, err = bufio.ReadFromN(c, r, kFirstPaddings-c.writePadding)
   579  	} else {
   580  		n, err = bufio.Copy(c.writer, r)
   581  	}
   582  	return n, wrapHttpError(err)
   583  }*/
   584  
   585  func (c *naiveH2Conn) Close() error {
   586  	return common.Close(
   587  		c.reader,
   588  		c.writer,
   589  	)
   590  }
   591  
   592  func (c *naiveH2Conn) LocalAddr() net.Addr {
   593  	return nil
   594  }
   595  
   596  func (c *naiveH2Conn) RemoteAddr() net.Addr {
   597  	return c.rAddr
   598  }
   599  
   600  func (c *naiveH2Conn) SetDeadline(t time.Time) error {
   601  	return os.ErrInvalid
   602  }
   603  
   604  func (c *naiveH2Conn) SetReadDeadline(t time.Time) error {
   605  	return os.ErrInvalid
   606  }
   607  
   608  func (c *naiveH2Conn) SetWriteDeadline(t time.Time) error {
   609  	return os.ErrInvalid
   610  }
   611  
   612  func (c *naiveH2Conn) NeedAdditionalReadDeadline() bool {
   613  	return true
   614  }
   615  
   616  func (c *naiveH2Conn) UpstreamReader() any {
   617  	return c.reader
   618  }
   619  
   620  func (c *naiveH2Conn) UpstreamWriter() any {
   621  	return c.writer
   622  }
   623  
   624  func (c *naiveH2Conn) ReaderReplaceable() bool {
   625  	return c.readPadding == kFirstPaddings
   626  }
   627  
   628  func (c *naiveH2Conn) WriterReplaceable() bool {
   629  	return c.writePadding == kFirstPaddings
   630  }
   631  
   632  func wrapHttpError(err error) error {
   633  	if err == nil {
   634  		return err
   635  	}
   636  	if strings.Contains(err.Error(), "client disconnected") {
   637  		return net.ErrClosed
   638  	}
   639  	if strings.Contains(err.Error(), "body closed by handler") {
   640  		return net.ErrClosed
   641  	}
   642  	if strings.Contains(err.Error(), "canceled with error code 268") {
   643  		return io.EOF
   644  	}
   645  	return err
   646  }