github.com/sagernet/sing-mux@v0.2.1-0.20240124034317-9bfb33698bb6/client.go (about) 1 package mux 2 3 import ( 4 "context" 5 "net" 6 "sync" 7 8 "github.com/sagernet/sing/common" 9 "github.com/sagernet/sing/common/bufio" 10 E "github.com/sagernet/sing/common/exceptions" 11 "github.com/sagernet/sing/common/logger" 12 M "github.com/sagernet/sing/common/metadata" 13 N "github.com/sagernet/sing/common/network" 14 "github.com/sagernet/sing/common/x/list" 15 ) 16 17 type Client struct { 18 dialer N.Dialer 19 logger logger.Logger 20 protocol byte 21 maxConnections int 22 minStreams int 23 maxStreams int 24 padding bool 25 access sync.Mutex 26 connections list.List[abstractSession] 27 brutal BrutalOptions 28 } 29 30 type Options struct { 31 Dialer N.Dialer 32 Logger logger.Logger 33 Protocol string 34 MaxConnections int 35 MinStreams int 36 MaxStreams int 37 Padding bool 38 Brutal BrutalOptions 39 } 40 41 type BrutalOptions struct { 42 Enabled bool 43 SendBPS uint64 44 ReceiveBPS uint64 45 } 46 47 func NewClient(options Options) (*Client, error) { 48 client := &Client{ 49 dialer: options.Dialer, 50 logger: options.Logger, 51 maxConnections: options.MaxConnections, 52 minStreams: options.MinStreams, 53 maxStreams: options.MaxStreams, 54 padding: options.Padding, 55 brutal: options.Brutal, 56 } 57 if client.dialer == nil { 58 client.dialer = N.SystemDialer 59 } 60 if client.maxStreams == 0 && client.maxConnections == 0 { 61 client.minStreams = 8 62 } 63 switch options.Protocol { 64 case "", "h2mux": 65 client.protocol = ProtocolH2Mux 66 case "smux": 67 client.protocol = ProtocolSmux 68 case "yamux": 69 client.protocol = ProtocolYAMux 70 default: 71 return nil, E.New("unknown protocol: " + options.Protocol) 72 } 73 return client, nil 74 } 75 76 func (c *Client) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { 77 switch N.NetworkName(network) { 78 case N.NetworkTCP: 79 stream, err := c.openStream(ctx) 80 if err != nil { 81 return nil, err 82 } 83 return &clientConn{Conn: stream, destination: destination}, nil 84 case N.NetworkUDP: 85 stream, err := c.openStream(ctx) 86 if err != nil { 87 return nil, err 88 } 89 extendedConn := bufio.NewExtendedConn(stream) 90 return &clientPacketConn{AbstractConn: extendedConn, conn: extendedConn, destination: destination}, nil 91 default: 92 return nil, E.Extend(N.ErrUnknownNetwork, network) 93 } 94 } 95 96 func (c *Client) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { 97 stream, err := c.openStream(ctx) 98 if err != nil { 99 return nil, err 100 } 101 extendedConn := bufio.NewExtendedConn(stream) 102 return &clientPacketAddrConn{AbstractConn: extendedConn, conn: extendedConn, destination: destination}, nil 103 } 104 105 func (c *Client) openStream(ctx context.Context) (net.Conn, error) { 106 var ( 107 session abstractSession 108 stream net.Conn 109 err error 110 ) 111 for attempts := 0; attempts < 2; attempts++ { 112 session, err = c.offer(ctx) 113 if err != nil { 114 continue 115 } 116 stream, err = session.Open() 117 if err != nil { 118 continue 119 } 120 break 121 } 122 if err != nil { 123 return nil, err 124 } 125 return &wrapStream{stream}, nil 126 } 127 128 func (c *Client) offer(ctx context.Context) (abstractSession, error) { 129 c.access.Lock() 130 defer c.access.Unlock() 131 132 var sessions []abstractSession 133 for element := c.connections.Front(); element != nil; { 134 if element.Value.IsClosed() { 135 element.Value.Close() 136 nextElement := element.Next() 137 c.connections.Remove(element) 138 element = nextElement 139 continue 140 } 141 sessions = append(sessions, element.Value) 142 element = element.Next() 143 } 144 if c.brutal.Enabled { 145 if len(sessions) > 0 { 146 return sessions[0], nil 147 } 148 return c.offerNew(ctx) 149 } 150 session := common.MinBy(common.Filter(sessions, abstractSession.CanTakeNewRequest), abstractSession.NumStreams) 151 if session == nil { 152 return c.offerNew(ctx) 153 } 154 numStreams := session.NumStreams() 155 if numStreams == 0 { 156 return session, nil 157 } 158 if c.maxConnections > 0 { 159 if len(sessions) >= c.maxConnections || numStreams < c.minStreams { 160 return session, nil 161 } 162 } else { 163 if c.maxStreams > 0 && numStreams < c.maxStreams { 164 return session, nil 165 } 166 } 167 return c.offerNew(ctx) 168 } 169 170 func (c *Client) offerNew(ctx context.Context) (abstractSession, error) { 171 ctx, cancel := context.WithTimeout(ctx, TCPTimeout) 172 defer cancel() 173 conn, err := c.dialer.DialContext(ctx, N.NetworkTCP, Destination) 174 if err != nil { 175 return nil, err 176 } 177 var version byte 178 if c.padding { 179 version = Version1 180 } else { 181 version = Version0 182 } 183 conn = newProtocolConn(conn, Request{ 184 Version: version, 185 Protocol: c.protocol, 186 Padding: c.padding, 187 }) 188 if c.padding { 189 conn = newPaddingConn(conn) 190 } 191 session, err := newClientSession(conn, c.protocol) 192 if err != nil { 193 conn.Close() 194 return nil, err 195 } 196 if c.brutal.Enabled { 197 err = c.brutalExchange(ctx, conn, session) 198 if err != nil { 199 conn.Close() 200 session.Close() 201 return nil, E.Cause(err, "brutal exchange") 202 } 203 } 204 c.connections.PushBack(session) 205 return session, nil 206 } 207 208 func (c *Client) brutalExchange(ctx context.Context, sessionConn net.Conn, session abstractSession) error { 209 stream, err := session.Open() 210 if err != nil { 211 return err 212 } 213 conn := &clientConn{Conn: &wrapStream{stream}, destination: M.Socksaddr{Fqdn: BrutalExchangeDomain}} 214 err = WriteBrutalRequest(conn, c.brutal.ReceiveBPS) 215 if err != nil { 216 return err 217 } 218 serverReceiveBPS, err := ReadBrutalResponse(conn) 219 if err != nil { 220 return err 221 } 222 conn.Close() 223 sendBPS := c.brutal.SendBPS 224 if serverReceiveBPS < sendBPS { 225 sendBPS = serverReceiveBPS 226 } 227 clientBrutalErr := SetBrutalOptions(sessionConn, sendBPS) 228 if clientBrutalErr != nil { 229 c.logger.Debug(E.Cause(clientBrutalErr, "failed to enable TCP Brutal at client")) 230 } 231 return nil 232 } 233 234 func (c *Client) Reset() { 235 c.access.Lock() 236 defer c.access.Unlock() 237 for _, session := range c.connections.Array() { 238 session.Close() 239 } 240 c.connections.Init() 241 } 242 243 func (c *Client) Close() error { 244 c.Reset() 245 return nil 246 }