github.com/sagernet/sing-box@v1.2.7/transport/vless/client.go (about) 1 package vless 2 3 import ( 4 "encoding/binary" 5 "io" 6 "net" 7 8 "github.com/sagernet/sing-vmess" 9 "github.com/sagernet/sing/common" 10 "github.com/sagernet/sing/common/buf" 11 E "github.com/sagernet/sing/common/exceptions" 12 "github.com/sagernet/sing/common/logger" 13 M "github.com/sagernet/sing/common/metadata" 14 N "github.com/sagernet/sing/common/network" 15 16 "github.com/gofrs/uuid/v5" 17 ) 18 19 type Client struct { 20 key [16]byte 21 flow string 22 logger logger.Logger 23 } 24 25 func NewClient(userId string, flow string, logger logger.Logger) (*Client, error) { 26 user := uuid.FromStringOrNil(userId) 27 if user == uuid.Nil { 28 user = uuid.NewV5(user, userId) 29 } 30 switch flow { 31 case "", "xtls-rprx-vision": 32 default: 33 return nil, E.New("unsupported flow: " + flow) 34 } 35 return &Client{user, flow, logger}, nil 36 } 37 38 func (c *Client) prepareConn(conn net.Conn) (net.Conn, error) { 39 if c.flow == FlowVision { 40 vConn, err := NewVisionConn(conn, c.key, c.logger) 41 if err != nil { 42 return nil, E.Cause(err, "initialize vision") 43 } 44 conn = vConn 45 } 46 return conn, nil 47 } 48 49 func (c *Client) DialConn(conn net.Conn, destination M.Socksaddr) (*Conn, error) { 50 vConn, err := c.prepareConn(conn) 51 if err != nil { 52 return nil, err 53 } 54 serverConn := &Conn{Conn: conn, protocolConn: vConn, key: c.key, command: vmess.CommandTCP, destination: destination, flow: c.flow} 55 return serverConn, common.Error(serverConn.Write(nil)) 56 } 57 58 func (c *Client) DialEarlyConn(conn net.Conn, destination M.Socksaddr) (*Conn, error) { 59 vConn, err := c.prepareConn(conn) 60 if err != nil { 61 return nil, err 62 } 63 return &Conn{Conn: conn, protocolConn: vConn, key: c.key, command: vmess.CommandTCP, destination: destination, flow: c.flow}, nil 64 } 65 66 func (c *Client) DialPacketConn(conn net.Conn, destination M.Socksaddr) (*PacketConn, error) { 67 serverConn := &PacketConn{Conn: conn, key: c.key, destination: destination, flow: c.flow} 68 return serverConn, common.Error(serverConn.Write(nil)) 69 } 70 71 func (c *Client) DialEarlyPacketConn(conn net.Conn, destination M.Socksaddr) (*PacketConn, error) { 72 return &PacketConn{Conn: conn, key: c.key, destination: destination, flow: c.flow}, nil 73 } 74 75 func (c *Client) DialXUDPPacketConn(conn net.Conn, destination M.Socksaddr) (vmess.PacketConn, error) { 76 serverConn := &Conn{Conn: conn, protocolConn: conn, key: c.key, command: vmess.CommandMux, destination: destination, flow: c.flow} 77 err := common.Error(serverConn.Write(nil)) 78 if err != nil { 79 return nil, err 80 } 81 return vmess.NewXUDPConn(serverConn, destination), nil 82 } 83 84 func (c *Client) DialEarlyXUDPPacketConn(conn net.Conn, destination M.Socksaddr) (vmess.PacketConn, error) { 85 return vmess.NewXUDPConn(&Conn{Conn: conn, protocolConn: conn, key: c.key, command: vmess.CommandMux, destination: destination, flow: c.flow}, destination), nil 86 } 87 88 var _ N.EarlyConn = (*Conn)(nil) 89 90 type Conn struct { 91 net.Conn 92 protocolConn net.Conn 93 key [16]byte 94 command byte 95 destination M.Socksaddr 96 flow string 97 requestWritten bool 98 responseRead bool 99 } 100 101 func (c *Conn) NeedHandshake() bool { 102 return !c.requestWritten 103 } 104 105 func (c *Conn) Read(b []byte) (n int, err error) { 106 if !c.responseRead { 107 err = ReadResponse(c.Conn) 108 if err != nil { 109 return 110 } 111 c.responseRead = true 112 } 113 return c.protocolConn.Read(b) 114 } 115 116 func (c *Conn) Write(b []byte) (n int, err error) { 117 if !c.requestWritten { 118 request := Request{c.key, c.command, c.destination, c.flow} 119 if c.protocolConn != nil { 120 err = WriteRequest(c.Conn, request, nil) 121 } else { 122 err = WriteRequest(c.Conn, request, b) 123 } 124 if err == nil { 125 n = len(b) 126 } 127 c.requestWritten = true 128 if c.protocolConn == nil { 129 return 130 } 131 } 132 return c.protocolConn.Write(b) 133 } 134 135 func (c *Conn) NeedAdditionalReadDeadline() bool { 136 return true 137 } 138 139 func (c *Conn) Upstream() any { 140 return c.Conn 141 } 142 143 type PacketConn struct { 144 net.Conn 145 key [16]byte 146 destination M.Socksaddr 147 flow string 148 requestWritten bool 149 responseRead bool 150 } 151 152 func (c *PacketConn) Read(b []byte) (n int, err error) { 153 if !c.responseRead { 154 err = ReadResponse(c.Conn) 155 if err != nil { 156 return 157 } 158 c.responseRead = true 159 } 160 var length uint16 161 err = binary.Read(c.Conn, binary.BigEndian, &length) 162 if err != nil { 163 return 164 } 165 if cap(b) < int(length) { 166 return 0, io.ErrShortBuffer 167 } 168 return io.ReadFull(c.Conn, b[:length]) 169 } 170 171 func (c *PacketConn) Write(b []byte) (n int, err error) { 172 if !c.requestWritten { 173 err = WritePacketRequest(c.Conn, Request{c.key, vmess.CommandUDP, c.destination, c.flow}, nil) 174 if err == nil { 175 n = len(b) 176 } 177 c.requestWritten = true 178 } 179 err = binary.Write(c.Conn, binary.BigEndian, uint16(len(b))) 180 if err != nil { 181 return 182 } 183 return c.Conn.Write(b) 184 } 185 186 func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 187 defer buffer.Release() 188 dataLen := buffer.Len() 189 binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(dataLen)) 190 if !c.requestWritten { 191 err := WritePacketRequest(c.Conn, Request{c.key, vmess.CommandUDP, c.destination, c.flow}, buffer.Bytes()) 192 c.requestWritten = true 193 return err 194 } 195 return common.Error(c.Conn.Write(buffer.Bytes())) 196 } 197 198 func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 199 n, err = c.Read(p) 200 if err != nil { 201 return 202 } 203 if c.destination.IsFqdn() { 204 addr = c.destination 205 } else { 206 addr = c.destination.UDPAddr() 207 } 208 return 209 } 210 211 func (c *PacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 212 return c.Write(p) 213 } 214 215 func (c *PacketConn) FrontHeadroom() int { 216 return 2 217 } 218 219 func (c *PacketConn) NeedAdditionalReadDeadline() bool { 220 return true 221 } 222 223 func (c *PacketConn) Upstream() any { 224 return c.Conn 225 }