github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/transport/tuic/client.go (about) 1 //go:build with_quic 2 3 package tuic 4 5 import ( 6 "context" 7 "github.com/sagernet/quic-go" 8 "io" 9 "net" 10 "os" 11 "runtime" 12 "sync" 13 "time" 14 15 "github.com/inazumav/sing-box/common/baderror" 16 "github.com/inazumav/sing-box/common/qtls" 17 "github.com/inazumav/sing-box/common/tls" 18 "github.com/sagernet/sing/common" 19 "github.com/sagernet/sing/common/buf" 20 "github.com/sagernet/sing/common/bufio" 21 E "github.com/sagernet/sing/common/exceptions" 22 M "github.com/sagernet/sing/common/metadata" 23 N "github.com/sagernet/sing/common/network" 24 25 "github.com/gofrs/uuid/v5" 26 ) 27 28 type ClientOptions struct { 29 Context context.Context 30 Dialer N.Dialer 31 ServerAddress M.Socksaddr 32 TLSConfig tls.Config 33 UUID uuid.UUID 34 Password string 35 CongestionControl string 36 UDPStream bool 37 ZeroRTTHandshake bool 38 Heartbeat time.Duration 39 } 40 41 type Client struct { 42 ctx context.Context 43 dialer N.Dialer 44 serverAddr M.Socksaddr 45 tlsConfig tls.Config 46 quicConfig *quic.Config 47 uuid uuid.UUID 48 password string 49 congestionControl string 50 udpStream bool 51 zeroRTTHandshake bool 52 heartbeat time.Duration 53 54 connAccess sync.RWMutex 55 conn *clientQUICConnection 56 } 57 58 func NewClient(options ClientOptions) (*Client, error) { 59 if options.Heartbeat == 0 { 60 options.Heartbeat = 10 * time.Second 61 } 62 quicConfig := &quic.Config{ 63 DisablePathMTUDiscovery: !(runtime.GOOS == "windows" || runtime.GOOS == "linux" || runtime.GOOS == "android" || runtime.GOOS == "darwin"), 64 MaxDatagramFrameSize: 1400, 65 EnableDatagrams: true, 66 MaxIncomingUniStreams: 1 << 60, 67 } 68 switch options.CongestionControl { 69 case "": 70 options.CongestionControl = "cubic" 71 case "cubic", "new_reno", "bbr": 72 default: 73 return nil, E.New("unknown congestion control algorithm: ", options.CongestionControl) 74 } 75 return &Client{ 76 ctx: options.Context, 77 dialer: options.Dialer, 78 serverAddr: options.ServerAddress, 79 tlsConfig: options.TLSConfig, 80 quicConfig: quicConfig, 81 uuid: options.UUID, 82 password: options.Password, 83 congestionControl: options.CongestionControl, 84 udpStream: options.UDPStream, 85 zeroRTTHandshake: options.ZeroRTTHandshake, 86 heartbeat: options.Heartbeat, 87 }, nil 88 } 89 90 func (c *Client) offer(ctx context.Context) (*clientQUICConnection, error) { 91 conn := c.conn 92 if conn != nil && conn.active() { 93 return conn, nil 94 } 95 c.connAccess.Lock() 96 defer c.connAccess.Unlock() 97 conn = c.conn 98 if conn != nil && conn.active() { 99 return conn, nil 100 } 101 conn, err := c.offerNew(ctx) 102 if err != nil { 103 return nil, err 104 } 105 return conn, nil 106 } 107 108 func (c *Client) offerNew(ctx context.Context) (*clientQUICConnection, error) { 109 udpConn, err := c.dialer.DialContext(ctx, "udp", c.serverAddr) 110 if err != nil { 111 return nil, err 112 } 113 var quicConn quic.Connection 114 if c.zeroRTTHandshake { 115 quicConn, err = qtls.DialEarly(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), c.tlsConfig, c.quicConfig) 116 } else { 117 quicConn, err = qtls.Dial(ctx, bufio.NewUnbindPacketConn(udpConn), udpConn.RemoteAddr(), c.tlsConfig, c.quicConfig) 118 } 119 if err != nil { 120 udpConn.Close() 121 return nil, E.Cause(err, "open connection") 122 } 123 setCongestion(c.ctx, quicConn, c.congestionControl) 124 conn := &clientQUICConnection{ 125 quicConn: quicConn, 126 rawConn: udpConn, 127 connDone: make(chan struct{}), 128 udpConnMap: make(map[uint16]*udpPacketConn), 129 } 130 go func() { 131 hErr := c.clientHandshake(quicConn) 132 if hErr != nil { 133 conn.closeWithError(hErr) 134 } 135 }() 136 if c.udpStream { 137 go c.loopUniStreams(conn) 138 } 139 go c.loopMessages(conn) 140 go c.loopHeartbeats(conn) 141 c.conn = conn 142 return conn, nil 143 } 144 145 func (c *Client) clientHandshake(conn quic.Connection) error { 146 authStream, err := conn.OpenUniStream() 147 if err != nil { 148 return E.Cause(err, "open handshake stream") 149 } 150 defer authStream.Close() 151 handshakeState := conn.ConnectionState() 152 tuicAuthToken, err := handshakeState.ExportKeyingMaterial(string(c.uuid[:]), []byte(c.password), 32) 153 if err != nil { 154 return E.Cause(err, "export keying material") 155 } 156 authRequest := buf.NewSize(AuthenticateLen) 157 authRequest.WriteByte(Version) 158 authRequest.WriteByte(CommandAuthenticate) 159 authRequest.Write(c.uuid[:]) 160 authRequest.Write(tuicAuthToken) 161 return common.Error(authStream.Write(authRequest.Bytes())) 162 } 163 164 func (c *Client) loopHeartbeats(conn *clientQUICConnection) { 165 ticker := time.NewTicker(c.heartbeat) 166 defer ticker.Stop() 167 for { 168 select { 169 case <-conn.connDone: 170 return 171 case <-ticker.C: 172 err := conn.quicConn.SendMessage([]byte{Version, CommandHeartbeat}) 173 if err != nil { 174 conn.closeWithError(E.Cause(err, "send heartbeat")) 175 } 176 } 177 } 178 } 179 180 func (c *Client) DialConn(ctx context.Context, destination M.Socksaddr) (net.Conn, error) { 181 conn, err := c.offer(ctx) 182 if err != nil { 183 return nil, err 184 } 185 stream, err := conn.quicConn.OpenStream() 186 if err != nil { 187 return nil, err 188 } 189 return &clientConn{ 190 parent: conn, 191 stream: stream, 192 destination: destination, 193 }, nil 194 } 195 196 func (c *Client) ListenPacket(ctx context.Context) (net.PacketConn, error) { 197 conn, err := c.offer(ctx) 198 if err != nil { 199 return nil, err 200 } 201 var sessionID uint16 202 clientPacketConn := newUDPPacketConn(ctx, conn.quicConn, c.udpStream, false, func() { 203 conn.udpAccess.Lock() 204 delete(conn.udpConnMap, sessionID) 205 conn.udpAccess.Unlock() 206 }) 207 conn.udpAccess.Lock() 208 sessionID = conn.udpSessionID 209 conn.udpSessionID++ 210 conn.udpConnMap[sessionID] = clientPacketConn 211 conn.udpAccess.Unlock() 212 clientPacketConn.sessionID = sessionID 213 return clientPacketConn, nil 214 } 215 216 func (c *Client) CloseWithError(err error) error { 217 conn := c.conn 218 if conn != nil { 219 conn.closeWithError(err) 220 } 221 return nil 222 } 223 224 type clientQUICConnection struct { 225 quicConn quic.Connection 226 rawConn io.Closer 227 closeOnce sync.Once 228 connDone chan struct{} 229 connErr error 230 udpAccess sync.RWMutex 231 udpConnMap map[uint16]*udpPacketConn 232 udpSessionID uint16 233 } 234 235 func (c *clientQUICConnection) active() bool { 236 select { 237 case <-c.quicConn.Context().Done(): 238 return false 239 default: 240 } 241 select { 242 case <-c.connDone: 243 return false 244 default: 245 } 246 return true 247 } 248 249 func (c *clientQUICConnection) closeWithError(err error) { 250 c.closeOnce.Do(func() { 251 c.connErr = err 252 close(c.connDone) 253 _ = c.quicConn.CloseWithError(0, "") 254 _ = c.rawConn.Close() 255 }) 256 } 257 258 type clientConn struct { 259 parent *clientQUICConnection 260 stream quic.Stream 261 destination M.Socksaddr 262 requestWritten bool 263 } 264 265 func (c *clientConn) NeedHandshake() bool { 266 return !c.requestWritten 267 } 268 269 func (c *clientConn) Read(b []byte) (n int, err error) { 270 n, err = c.stream.Read(b) 271 return n, baderror.WrapQUIC(err) 272 } 273 274 func (c *clientConn) Write(b []byte) (n int, err error) { 275 if !c.requestWritten { 276 request := buf.NewSize(2 + addressSerializer.AddrPortLen(c.destination) + len(b)) 277 defer request.Release() 278 request.WriteByte(Version) 279 request.WriteByte(CommandConnect) 280 err = addressSerializer.WriteAddrPort(request, c.destination) 281 if err != nil { 282 return 283 } 284 request.Write(b) 285 _, err = c.stream.Write(request.Bytes()) 286 if err != nil { 287 c.parent.closeWithError(E.Cause(err, "create new connection")) 288 return 0, baderror.WrapQUIC(err) 289 } 290 c.requestWritten = true 291 return len(b), nil 292 } 293 n, err = c.stream.Write(b) 294 return n, baderror.WrapQUIC(err) 295 } 296 297 func (c *clientConn) Close() error { 298 stream := c.stream 299 if stream == nil { 300 return nil 301 } 302 stream.CancelRead(0) 303 return stream.Close() 304 } 305 306 func (c *clientConn) LocalAddr() net.Addr { 307 return M.Socksaddr{} 308 } 309 310 func (c *clientConn) RemoteAddr() net.Addr { 311 return c.destination 312 } 313 314 func (c *clientConn) SetDeadline(t time.Time) error { 315 if c.stream == nil { 316 return os.ErrInvalid 317 } 318 return c.stream.SetDeadline(t) 319 } 320 321 func (c *clientConn) SetReadDeadline(t time.Time) error { 322 if c.stream == nil { 323 return os.ErrInvalid 324 } 325 return c.stream.SetReadDeadline(t) 326 } 327 328 func (c *clientConn) SetWriteDeadline(t time.Time) error { 329 if c.stream == nil { 330 return os.ErrInvalid 331 } 332 return c.stream.SetWriteDeadline(t) 333 }