github.com/sagernet/sing-box@v1.2.7/common/mux/service.go (about)

     1  package mux
     2  
     3  import (
     4  	"context"
     5  	"encoding/binary"
     6  	"net"
     7  
     8  	"github.com/sagernet/sing-box/adapter"
     9  	"github.com/sagernet/sing-box/log"
    10  	"github.com/sagernet/sing/common"
    11  	"github.com/sagernet/sing/common/buf"
    12  	"github.com/sagernet/sing/common/bufio"
    13  	E "github.com/sagernet/sing/common/exceptions"
    14  	M "github.com/sagernet/sing/common/metadata"
    15  	N "github.com/sagernet/sing/common/network"
    16  	"github.com/sagernet/sing/common/rw"
    17  	"github.com/sagernet/sing/common/task"
    18  )
    19  
    20  func NewConnection(ctx context.Context, router adapter.Router, errorHandler E.Handler, logger log.ContextLogger, conn net.Conn, metadata adapter.InboundContext) error {
    21  	request, err := ReadRequest(conn)
    22  	if err != nil {
    23  		return err
    24  	}
    25  	session, err := request.Protocol.newServer(conn)
    26  	if err != nil {
    27  		return err
    28  	}
    29  	var group task.Group
    30  	group.Append0(func(ctx context.Context) error {
    31  		var stream net.Conn
    32  		for {
    33  			stream, err = session.Accept()
    34  			if err != nil {
    35  				return err
    36  			}
    37  			go newConnection(ctx, router, errorHandler, logger, stream, metadata)
    38  		}
    39  	})
    40  	group.Cleanup(func() {
    41  		session.Close()
    42  	})
    43  	return group.Run(ctx)
    44  }
    45  
    46  func newConnection(ctx context.Context, router adapter.Router, errorHandler E.Handler, logger log.ContextLogger, stream net.Conn, metadata adapter.InboundContext) {
    47  	stream = &wrapStream{stream}
    48  	request, err := ReadStreamRequest(stream)
    49  	if err != nil {
    50  		logger.ErrorContext(ctx, err)
    51  		return
    52  	}
    53  	metadata.Destination = request.Destination
    54  	if request.Network == N.NetworkTCP {
    55  		logger.InfoContext(ctx, "inbound multiplex connection to ", metadata.Destination)
    56  		hErr := router.RouteConnection(ctx, &ServerConn{ExtendedConn: bufio.NewExtendedConn(stream)}, metadata)
    57  		stream.Close()
    58  		if hErr != nil {
    59  			errorHandler.NewError(ctx, hErr)
    60  		}
    61  	} else {
    62  		var packetConn N.PacketConn
    63  		if !request.PacketAddr {
    64  			logger.InfoContext(ctx, "inbound multiplex packet connection to ", metadata.Destination)
    65  			packetConn = &ServerPacketConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: request.Destination}
    66  		} else {
    67  			logger.InfoContext(ctx, "inbound multiplex packet connection")
    68  			packetConn = &ServerPacketAddrConn{ExtendedConn: bufio.NewExtendedConn(stream)}
    69  		}
    70  		hErr := router.RoutePacketConnection(ctx, packetConn, metadata)
    71  		stream.Close()
    72  		if hErr != nil {
    73  			errorHandler.NewError(ctx, hErr)
    74  		}
    75  	}
    76  }
    77  
    78  var _ N.HandshakeConn = (*ServerConn)(nil)
    79  
    80  type ServerConn struct {
    81  	N.ExtendedConn
    82  	responseWrite bool
    83  }
    84  
    85  func (c *ServerConn) HandshakeFailure(err error) error {
    86  	errMessage := err.Error()
    87  	_buffer := buf.StackNewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage))
    88  	defer common.KeepAlive(_buffer)
    89  	buffer := common.Dup(_buffer)
    90  	defer buffer.Release()
    91  	common.Must(
    92  		buffer.WriteByte(statusError),
    93  		rw.WriteVString(_buffer, errMessage),
    94  	)
    95  	return c.ExtendedConn.WriteBuffer(buffer)
    96  }
    97  
    98  func (c *ServerConn) Write(b []byte) (n int, err error) {
    99  	if c.responseWrite {
   100  		return c.ExtendedConn.Write(b)
   101  	}
   102  	_buffer := buf.StackNewSize(1 + len(b))
   103  	defer common.KeepAlive(_buffer)
   104  	buffer := common.Dup(_buffer)
   105  	defer buffer.Release()
   106  	common.Must(
   107  		buffer.WriteByte(statusSuccess),
   108  		common.Error(buffer.Write(b)),
   109  	)
   110  	_, err = c.ExtendedConn.Write(buffer.Bytes())
   111  	if err != nil {
   112  		return
   113  	}
   114  	c.responseWrite = true
   115  	return len(b), nil
   116  }
   117  
   118  func (c *ServerConn) WriteBuffer(buffer *buf.Buffer) error {
   119  	if c.responseWrite {
   120  		return c.ExtendedConn.WriteBuffer(buffer)
   121  	}
   122  	buffer.ExtendHeader(1)[0] = statusSuccess
   123  	c.responseWrite = true
   124  	return c.ExtendedConn.WriteBuffer(buffer)
   125  }
   126  
   127  func (c *ServerConn) FrontHeadroom() int {
   128  	if !c.responseWrite {
   129  		return 1
   130  	}
   131  	return 0
   132  }
   133  
   134  func (c *ServerConn) NeedAdditionalReadDeadline() bool {
   135  	return true
   136  }
   137  
   138  func (c *ServerConn) Upstream() any {
   139  	return c.ExtendedConn
   140  }
   141  
   142  var (
   143  	_ N.HandshakeConn = (*ServerPacketConn)(nil)
   144  	_ N.PacketConn    = (*ServerPacketConn)(nil)
   145  )
   146  
   147  type ServerPacketConn struct {
   148  	N.ExtendedConn
   149  	destination   M.Socksaddr
   150  	responseWrite bool
   151  }
   152  
   153  func (c *ServerPacketConn) HandshakeFailure(err error) error {
   154  	errMessage := err.Error()
   155  	_buffer := buf.StackNewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage))
   156  	defer common.KeepAlive(_buffer)
   157  	buffer := common.Dup(_buffer)
   158  	defer buffer.Release()
   159  	common.Must(
   160  		buffer.WriteByte(statusError),
   161  		rw.WriteVString(_buffer, errMessage),
   162  	)
   163  	return c.ExtendedConn.WriteBuffer(buffer)
   164  }
   165  
   166  func (c *ServerPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
   167  	var length uint16
   168  	err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
   169  	if err != nil {
   170  		return
   171  	}
   172  	_, err = buffer.ReadFullFrom(c.ExtendedConn, int(length))
   173  	if err != nil {
   174  		return
   175  	}
   176  	destination = c.destination
   177  	return
   178  }
   179  
   180  func (c *ServerPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   181  	pLen := buffer.Len()
   182  	common.Must(binary.Write(buf.With(buffer.ExtendHeader(2)), binary.BigEndian, uint16(pLen)))
   183  	if !c.responseWrite {
   184  		buffer.ExtendHeader(1)[0] = statusSuccess
   185  		c.responseWrite = true
   186  	}
   187  	return c.ExtendedConn.WriteBuffer(buffer)
   188  }
   189  
   190  func (c *ServerPacketConn) NeedAdditionalReadDeadline() bool {
   191  	return true
   192  }
   193  
   194  func (c *ServerPacketConn) Upstream() any {
   195  	return c.ExtendedConn
   196  }
   197  
   198  func (c *ServerPacketConn) FrontHeadroom() int {
   199  	if !c.responseWrite {
   200  		return 3
   201  	}
   202  	return 2
   203  }
   204  
   205  var (
   206  	_ N.HandshakeConn = (*ServerPacketAddrConn)(nil)
   207  	_ N.PacketConn    = (*ServerPacketAddrConn)(nil)
   208  )
   209  
   210  type ServerPacketAddrConn struct {
   211  	N.ExtendedConn
   212  	responseWrite bool
   213  }
   214  
   215  func (c *ServerPacketAddrConn) HandshakeFailure(err error) error {
   216  	errMessage := err.Error()
   217  	_buffer := buf.StackNewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage))
   218  	defer common.KeepAlive(_buffer)
   219  	buffer := common.Dup(_buffer)
   220  	defer buffer.Release()
   221  	common.Must(
   222  		buffer.WriteByte(statusError),
   223  		rw.WriteVString(_buffer, errMessage),
   224  	)
   225  	return c.ExtendedConn.WriteBuffer(buffer)
   226  }
   227  
   228  func (c *ServerPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
   229  	destination, err = M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn)
   230  	if err != nil {
   231  		return
   232  	}
   233  	var length uint16
   234  	err = binary.Read(c.ExtendedConn, binary.BigEndian, &length)
   235  	if err != nil {
   236  		return
   237  	}
   238  	_, err = buffer.ReadFullFrom(c.ExtendedConn, int(length))
   239  	if err != nil {
   240  		return
   241  	}
   242  	return
   243  }
   244  
   245  func (c *ServerPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   246  	pLen := buffer.Len()
   247  	common.Must(binary.Write(buf.With(buffer.ExtendHeader(2)), binary.BigEndian, uint16(pLen)))
   248  	common.Must(M.SocksaddrSerializer.WriteAddrPort(buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination))), destination))
   249  	if !c.responseWrite {
   250  		buffer.ExtendHeader(1)[0] = statusSuccess
   251  		c.responseWrite = true
   252  	}
   253  	return c.ExtendedConn.WriteBuffer(buffer)
   254  }
   255  
   256  func (c *ServerPacketAddrConn) NeedAdditionalReadDeadline() bool {
   257  	return true
   258  }
   259  
   260  func (c *ServerPacketAddrConn) Upstream() any {
   261  	return c.ExtendedConn
   262  }
   263  
   264  func (c *ServerPacketAddrConn) FrontHeadroom() int {
   265  	if !c.responseWrite {
   266  		return 3 + M.MaxSocksaddrLength
   267  	}
   268  	return 2 + M.MaxSocksaddrLength
   269  }