github.com/sagernet/sing-mux@v0.2.1-0.20240124034317-9bfb33698bb6/server.go (about)

     1  package mux
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  
     7  	"github.com/sagernet/sing/common/bufio"
     8  	"github.com/sagernet/sing/common/debug"
     9  	E "github.com/sagernet/sing/common/exceptions"
    10  	"github.com/sagernet/sing/common/logger"
    11  	M "github.com/sagernet/sing/common/metadata"
    12  	N "github.com/sagernet/sing/common/network"
    13  	"github.com/sagernet/sing/common/task"
    14  )
    15  
    16  type ServiceHandler interface {
    17  	N.TCPConnectionHandler
    18  	N.UDPConnectionHandler
    19  }
    20  
// Service is the server side of the multiplex protocol. It is created with
// NewService and serves one multiplexed session per NewConnection call.
type Service struct {
	// newStreamContext derives the context used for each accepted stream.
	newStreamContext func(context.Context, net.Conn) context.Context
	// logger receives per-stream info messages and stream handling errors.
	logger           logger.ContextLogger
	// handler receives TCP and packet streams once their request is decoded.
	handler          ServiceHandler
	// padding, when true, rejects sessions whose request does not enable padding.
	padding          bool
	// brutal is the local TCP Brutal configuration used to answer
	// brutal exchange requests.
	brutal           BrutalOptions
}
    28  
// ServiceOptions configures NewService.
type ServiceOptions struct {
	// NewStreamContext is called for every accepted stream with the session
	// context and the stream; the returned context is used while handling
	// that stream.
	NewStreamContext func(context.Context, net.Conn) context.Context
	// Logger receives per-stream info messages and stream handling errors.
	Logger           logger.ContextLogger
	// Handler receives the decoded TCP and packet streams.
	Handler          ServiceHandler
	// Padding, when true, makes the service reject sessions whose request
	// does not enable padding. Padded sessions are accepted either way.
	Padding          bool
	// Brutal configures TCP Brutal. Enabling it makes NewService fail unless
	// BrutalAvailable is true or debug.Enabled is set.
	Brutal           BrutalOptions
}
    36  
    37  func NewService(options ServiceOptions) (*Service, error) {
    38  	if options.Brutal.Enabled && !BrutalAvailable && !debug.Enabled {
    39  		return nil, E.New("TCP Brutal is only supported on Linux")
    40  	}
    41  	return &Service{
    42  		newStreamContext: options.NewStreamContext,
    43  		logger:           options.Logger,
    44  		handler:          options.Handler,
    45  		padding:          options.Padding,
    46  		brutal:           options.Brutal,
    47  	}, nil
    48  }
    49  
    50  func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
    51  	request, err := ReadRequest(conn)
    52  	if err != nil {
    53  		return err
    54  	}
    55  	if request.Padding {
    56  		conn = newPaddingConn(conn)
    57  	} else if s.padding {
    58  		return E.New("non-padded connection rejected")
    59  	}
    60  	session, err := newServerSession(conn, request.Protocol)
    61  	if err != nil {
    62  		return err
    63  	}
    64  	var group task.Group
    65  	group.Append0(func(_ context.Context) error {
    66  		var stream net.Conn
    67  		for {
    68  			stream, err = session.Accept()
    69  			if err != nil {
    70  				return err
    71  			}
    72  			streamCtx := s.newStreamContext(ctx, stream)
    73  			go func() {
    74  				hErr := s.newConnection(streamCtx, conn, stream, metadata)
    75  				if hErr != nil {
    76  					s.logger.ErrorContext(streamCtx, E.Cause(hErr, "handle connection"))
    77  				}
    78  			}()
    79  		}
    80  	})
    81  	group.Cleanup(func() {
    82  		session.Close()
    83  	})
    84  	return group.Run(ctx)
    85  }
    86  
    87  func (s *Service) newConnection(ctx context.Context, sessionConn net.Conn, stream net.Conn, metadata M.Metadata) error {
    88  	stream = &wrapStream{stream}
    89  	request, err := ReadStreamRequest(stream)
    90  	if err != nil {
    91  		return E.Cause(err, "read multiplex stream request")
    92  	}
    93  	metadata.Destination = request.Destination
    94  	if request.Network == N.NetworkTCP {
    95  		conn := &serverConn{ExtendedConn: bufio.NewExtendedConn(stream)}
    96  		if request.Destination.Fqdn == BrutalExchangeDomain {
    97  			defer stream.Close()
    98  			var clientReceiveBPS uint64
    99  			clientReceiveBPS, err = ReadBrutalRequest(conn)
   100  			if err != nil {
   101  				return E.Cause(err, "read brutal request")
   102  			}
   103  			if !s.brutal.Enabled {
   104  				err = WriteBrutalResponse(conn, 0, false, "brutal is not enabled by the server")
   105  				if err != nil {
   106  					return E.Cause(err, "write brutal response")
   107  				}
   108  				return nil
   109  			}
   110  			sendBPS := s.brutal.SendBPS
   111  			if clientReceiveBPS < sendBPS {
   112  				sendBPS = clientReceiveBPS
   113  			}
   114  			err = SetBrutalOptions(sessionConn, sendBPS)
   115  			if err != nil {
   116  				// ignore error in test
   117  				if !debug.Enabled {
   118  					err = WriteBrutalResponse(conn, 0, false, E.Cause(err, "enable TCP Brutal").Error())
   119  					if err != nil {
   120  						return E.Cause(err, "write brutal response")
   121  					}
   122  					return nil
   123  				}
   124  			}
   125  			err = WriteBrutalResponse(conn, s.brutal.ReceiveBPS, true, "")
   126  			if err != nil {
   127  				return E.Cause(err, "write brutal response")
   128  			}
   129  			return nil
   130  		}
   131  		s.logger.InfoContext(ctx, "inbound multiplex connection to ", metadata.Destination)
   132  		s.handler.NewConnection(ctx, conn, metadata)
   133  		stream.Close()
   134  	} else {
   135  		var packetConn N.PacketConn
   136  		if !request.PacketAddr {
   137  			s.logger.InfoContext(ctx, "inbound multiplex packet connection to ", metadata.Destination)
   138  			packetConn = &serverPacketConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: request.Destination}
   139  		} else {
   140  			s.logger.InfoContext(ctx, "inbound multiplex packet connection")
   141  			packetConn = &serverPacketAddrConn{ExtendedConn: bufio.NewExtendedConn(stream)}
   142  		}
   143  		s.handler.NewPacketConnection(ctx, packetConn, metadata)
   144  		stream.Close()
   145  	}
   146  	return nil
   147  }