github.com/sagernet/sing-mux@v0.2.1-0.20240124034317-9bfb33698bb6/server.go (about) 1 package mux 2 3 import ( 4 "context" 5 "net" 6 7 "github.com/sagernet/sing/common/bufio" 8 "github.com/sagernet/sing/common/debug" 9 E "github.com/sagernet/sing/common/exceptions" 10 "github.com/sagernet/sing/common/logger" 11 M "github.com/sagernet/sing/common/metadata" 12 N "github.com/sagernet/sing/common/network" 13 "github.com/sagernet/sing/common/task" 14 ) 15 16 type ServiceHandler interface { 17 N.TCPConnectionHandler 18 N.UDPConnectionHandler 19 } 20 21 type Service struct { 22 newStreamContext func(context.Context, net.Conn) context.Context 23 logger logger.ContextLogger 24 handler ServiceHandler 25 padding bool 26 brutal BrutalOptions 27 } 28 29 type ServiceOptions struct { 30 NewStreamContext func(context.Context, net.Conn) context.Context 31 Logger logger.ContextLogger 32 Handler ServiceHandler 33 Padding bool 34 Brutal BrutalOptions 35 } 36 37 func NewService(options ServiceOptions) (*Service, error) { 38 if options.Brutal.Enabled && !BrutalAvailable && !debug.Enabled { 39 return nil, E.New("TCP Brutal is only supported on Linux") 40 } 41 return &Service{ 42 newStreamContext: options.NewStreamContext, 43 logger: options.Logger, 44 handler: options.Handler, 45 padding: options.Padding, 46 brutal: options.Brutal, 47 }, nil 48 } 49 50 func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { 51 request, err := ReadRequest(conn) 52 if err != nil { 53 return err 54 } 55 if request.Padding { 56 conn = newPaddingConn(conn) 57 } else if s.padding { 58 return E.New("non-padded connection rejected") 59 } 60 session, err := newServerSession(conn, request.Protocol) 61 if err != nil { 62 return err 63 } 64 var group task.Group 65 group.Append0(func(_ context.Context) error { 66 var stream net.Conn 67 for { 68 stream, err = session.Accept() 69 if err != nil { 70 return err 71 } 72 streamCtx := s.newStreamContext(ctx, stream) 73 go func() { 74 hErr := s.newConnection(streamCtx, conn, stream, metadata) 75 if hErr != nil { 76 s.logger.ErrorContext(streamCtx, E.Cause(hErr, "handle connection")) 77 } 78 }() 79 } 80 }) 81 group.Cleanup(func() { 82 session.Close() 83 }) 84 return group.Run(ctx) 85 } 86 87 func (s *Service) newConnection(ctx context.Context, sessionConn net.Conn, stream net.Conn, metadata M.Metadata) error { 88 stream = &wrapStream{stream} 89 request, err := ReadStreamRequest(stream) 90 if err != nil { 91 return E.Cause(err, "read multiplex stream request") 92 } 93 metadata.Destination = request.Destination 94 if request.Network == N.NetworkTCP { 95 conn := &serverConn{ExtendedConn: bufio.NewExtendedConn(stream)} 96 if request.Destination.Fqdn == BrutalExchangeDomain { 97 defer stream.Close() 98 var clientReceiveBPS uint64 99 clientReceiveBPS, err = ReadBrutalRequest(conn) 100 if err != nil { 101 return E.Cause(err, "read brutal request") 102 } 103 if !s.brutal.Enabled { 104 err = WriteBrutalResponse(conn, 0, false, "brutal is not enabled by the server") 105 if err != nil { 106 return E.Cause(err, "write brutal response") 107 } 108 return nil 109 } 110 sendBPS := s.brutal.SendBPS 111 if clientReceiveBPS < sendBPS { 112 sendBPS = clientReceiveBPS 113 } 114 err = SetBrutalOptions(sessionConn, sendBPS) 115 if err != nil { 116 // ignore error in test 117 if !debug.Enabled { 118 err = WriteBrutalResponse(conn, 0, false, E.Cause(err, "enable TCP Brutal").Error()) 119 if err != nil { 120 return E.Cause(err, "write brutal response") 121 } 122 return nil 123 } 124 } 125 err = WriteBrutalResponse(conn, s.brutal.ReceiveBPS, true, "") 126 if err != nil { 127 return E.Cause(err, "write brutal response") 128 } 129 return nil 130 } 131 s.logger.InfoContext(ctx, "inbound multiplex connection to ", metadata.Destination) 132 s.handler.NewConnection(ctx, conn, metadata) 133 stream.Close() 134 } else { 135 var packetConn N.PacketConn 136 if !request.PacketAddr { 137 s.logger.InfoContext(ctx, "inbound multiplex packet connection to ", metadata.Destination) 138 packetConn = &serverPacketConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: request.Destination} 139 } else { 140 s.logger.InfoContext(ctx, "inbound multiplex packet connection") 141 packetConn = &serverPacketAddrConn{ExtendedConn: bufio.NewExtendedConn(stream)} 142 } 143 s.handler.NewPacketConnection(ctx, packetConn, metadata) 144 stream.Close() 145 } 146 return nil 147 }