github.com/sagernet/sing-box@v1.2.7/transport/trojan/protocol.go (about) 1 package trojan 2 3 import ( 4 "crypto/sha256" 5 "encoding/binary" 6 "encoding/hex" 7 "io" 8 "net" 9 "os" 10 11 "github.com/sagernet/sing/common" 12 "github.com/sagernet/sing/common/buf" 13 "github.com/sagernet/sing/common/bufio" 14 E "github.com/sagernet/sing/common/exceptions" 15 M "github.com/sagernet/sing/common/metadata" 16 N "github.com/sagernet/sing/common/network" 17 "github.com/sagernet/sing/common/rw" 18 ) 19 20 const ( 21 KeyLength = 56 22 CommandTCP = 1 23 CommandUDP = 3 24 CommandMux = 0x7f 25 ) 26 27 var CRLF = []byte{'\r', '\n'} 28 29 var _ N.EarlyConn = (*ClientConn)(nil) 30 31 type ClientConn struct { 32 N.ExtendedConn 33 key [KeyLength]byte 34 destination M.Socksaddr 35 headerWritten bool 36 } 37 38 func NewClientConn(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr) *ClientConn { 39 return &ClientConn{ 40 ExtendedConn: bufio.NewExtendedConn(conn), 41 key: key, 42 destination: destination, 43 } 44 } 45 46 func (c *ClientConn) NeedHandshake() bool { 47 return !c.headerWritten 48 } 49 50 func (c *ClientConn) Write(p []byte) (n int, err error) { 51 if c.headerWritten { 52 return c.ExtendedConn.Write(p) 53 } 54 err = ClientHandshake(c.ExtendedConn, c.key, c.destination, p) 55 if err != nil { 56 return 57 } 58 n = len(p) 59 c.headerWritten = true 60 return 61 } 62 63 func (c *ClientConn) WriteBuffer(buffer *buf.Buffer) error { 64 if c.headerWritten { 65 return c.ExtendedConn.WriteBuffer(buffer) 66 } 67 err := ClientHandshakeBuffer(c.ExtendedConn, c.key, c.destination, buffer) 68 if err != nil { 69 return err 70 } 71 c.headerWritten = true 72 return nil 73 } 74 75 func (c *ClientConn) ReadFrom(r io.Reader) (n int64, err error) { 76 if !c.headerWritten { 77 return bufio.ReadFrom0(c, r) 78 } 79 return bufio.Copy(c.ExtendedConn, r) 80 } 81 82 func (c *ClientConn) WriteTo(w io.Writer) (n int64, err error) { 83 return bufio.Copy(w, c.ExtendedConn) 84 } 85 86 func (c *ClientConn) FrontHeadroom() int { 87 if !c.headerWritten { 88 return KeyLength + 5 + M.MaxSocksaddrLength 89 } 90 return 0 91 } 92 93 func (c *ClientConn) Upstream() any { 94 return c.ExtendedConn 95 } 96 97 type ClientPacketConn struct { 98 net.Conn 99 key [KeyLength]byte 100 headerWritten bool 101 } 102 103 func NewClientPacketConn(conn net.Conn, key [KeyLength]byte) *ClientPacketConn { 104 return &ClientPacketConn{ 105 Conn: conn, 106 key: key, 107 } 108 } 109 110 func (c *ClientPacketConn) NeedHandshake() bool { 111 return !c.headerWritten 112 } 113 114 func (c *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { 115 return ReadPacket(c.Conn, buffer) 116 } 117 118 func (c *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 119 if !c.headerWritten { 120 err := ClientHandshakePacket(c.Conn, c.key, destination, buffer) 121 c.headerWritten = true 122 return err 123 } 124 return WritePacket(c.Conn, buffer, destination) 125 } 126 127 func (c *ClientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 128 buffer := buf.With(p) 129 destination, err := c.ReadPacket(buffer) 130 if err != nil { 131 return 132 } 133 n = buffer.Len() 134 if destination.IsFqdn() { 135 addr = destination 136 } else { 137 addr = destination.UDPAddr() 138 } 139 return 140 } 141 142 func (c *ClientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 143 return bufio.WritePacket(c, p, addr) 144 } 145 146 func (c *ClientPacketConn) Read(p []byte) (n int, err error) { 147 n, _, err = c.ReadFrom(p) 148 return 149 } 150 151 func (c *ClientPacketConn) Write(p []byte) (n int, err error) { 152 return 0, os.ErrInvalid 153 } 154 155 func (c *ClientPacketConn) FrontHeadroom() int { 156 if !c.headerWritten { 157 return KeyLength + 2*M.MaxSocksaddrLength + 9 158 } 159 return M.MaxSocksaddrLength + 4 160 } 161 162 func (c *ClientPacketConn) Upstream() any { 163 return c.Conn 164 } 165 166 func Key(password string) [KeyLength]byte { 167 var key [KeyLength]byte 168 hash := sha256.New224() 169 common.Must1(hash.Write([]byte(password))) 170 hex.Encode(key[:], hash.Sum(nil)) 171 return key 172 } 173 174 func ClientHandshakeRaw(conn net.Conn, key [KeyLength]byte, command byte, destination M.Socksaddr, payload []byte) error { 175 _, err := conn.Write(key[:]) 176 if err != nil { 177 return err 178 } 179 _, err = conn.Write(CRLF) 180 if err != nil { 181 return err 182 } 183 _, err = conn.Write([]byte{command}) 184 if err != nil { 185 return err 186 } 187 err = M.SocksaddrSerializer.WriteAddrPort(conn, destination) 188 if err != nil { 189 return err 190 } 191 _, err = conn.Write(CRLF) 192 if err != nil { 193 return err 194 } 195 if len(payload) > 0 { 196 _, err = conn.Write(payload) 197 if err != nil { 198 return err 199 } 200 } 201 return nil 202 } 203 204 func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload []byte) error { 205 headerLen := KeyLength + M.SocksaddrSerializer.AddrPortLen(destination) + 5 206 var header *buf.Buffer 207 defer header.Release() 208 var writeHeader bool 209 if len(payload) > 0 && headerLen+len(payload) < 65535 { 210 buffer := buf.StackNewSize(headerLen + len(payload)) 211 defer common.KeepAlive(buffer) 212 header = common.Dup(buffer) 213 } else { 214 buffer := buf.StackNewSize(headerLen) 215 defer common.KeepAlive(buffer) 216 header = common.Dup(buffer) 217 writeHeader = true 218 } 219 common.Must1(header.Write(key[:])) 220 common.Must1(header.Write(CRLF)) 221 common.Must(header.WriteByte(CommandTCP)) 222 common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination)) 223 common.Must1(header.Write(CRLF)) 224 if !writeHeader { 225 common.Must1(header.Write(payload)) 226 } 227 228 _, err := conn.Write(header.Bytes()) 229 if err != nil { 230 return E.Cause(err, "write request") 231 } 232 233 if writeHeader { 234 _, err = conn.Write(payload) 235 if err != nil { 236 return E.Cause(err, "write payload") 237 } 238 } 239 return nil 240 } 241 242 func ClientHandshakeBuffer(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload *buf.Buffer) error { 243 header := buf.With(payload.ExtendHeader(KeyLength + M.SocksaddrSerializer.AddrPortLen(destination) + 5)) 244 common.Must1(header.Write(key[:])) 245 common.Must1(header.Write(CRLF)) 246 common.Must(header.WriteByte(CommandTCP)) 247 common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination)) 248 common.Must1(header.Write(CRLF)) 249 250 _, err := conn.Write(payload.Bytes()) 251 if err != nil { 252 return E.Cause(err, "write request") 253 } 254 return nil 255 } 256 257 func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload *buf.Buffer) error { 258 headerLen := KeyLength + 2*M.SocksaddrSerializer.AddrPortLen(destination) + 9 259 payloadLen := payload.Len() 260 var header *buf.Buffer 261 defer header.Release() 262 var writeHeader bool 263 if payload.Start() >= headerLen { 264 header = buf.With(payload.ExtendHeader(headerLen)) 265 } else { 266 buffer := buf.StackNewSize(headerLen) 267 defer common.KeepAlive(buffer) 268 header = common.Dup(buffer) 269 writeHeader = true 270 } 271 common.Must1(header.Write(key[:])) 272 common.Must1(header.Write(CRLF)) 273 common.Must(header.WriteByte(CommandUDP)) 274 common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination)) 275 common.Must1(header.Write(CRLF)) 276 common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination)) 277 common.Must(binary.Write(header, binary.BigEndian, uint16(payloadLen))) 278 common.Must1(header.Write(CRLF)) 279 280 if writeHeader { 281 _, err := conn.Write(header.Bytes()) 282 if err != nil { 283 return E.Cause(err, "write request") 284 } 285 } 286 287 _, err := conn.Write(payload.Bytes()) 288 if err != nil { 289 return E.Cause(err, "write payload") 290 } 291 return nil 292 } 293 294 func ReadPacket(conn net.Conn, buffer *buf.Buffer) (M.Socksaddr, error) { 295 destination, err := M.SocksaddrSerializer.ReadAddrPort(conn) 296 if err != nil { 297 return M.Socksaddr{}, E.Cause(err, "read destination") 298 } 299 300 var length uint16 301 err = binary.Read(conn, binary.BigEndian, &length) 302 if err != nil { 303 return M.Socksaddr{}, E.Cause(err, "read chunk length") 304 } 305 306 err = rw.SkipN(conn, 2) 307 if err != nil { 308 return M.Socksaddr{}, E.Cause(err, "skip crlf") 309 } 310 311 _, err = buffer.ReadFullFrom(conn, int(length)) 312 return destination, err 313 } 314 315 func WritePacket(conn net.Conn, buffer *buf.Buffer, destination M.Socksaddr) error { 316 defer buffer.Release() 317 bufferLen := buffer.Len() 318 header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination) + 4)) 319 common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination)) 320 common.Must(binary.Write(header, binary.BigEndian, uint16(bufferLen))) 321 common.Must1(header.Write(CRLF)) 322 _, err := conn.Write(buffer.Bytes()) 323 if err != nil { 324 return E.Cause(err, "write packet") 325 } 326 return nil 327 }