github.com/sagernet/sing-box@v1.9.0-rc.20/transport/trojan/protocol.go (about) 1 package trojan 2 3 import ( 4 "crypto/sha256" 5 "encoding/binary" 6 "encoding/hex" 7 "net" 8 "os" 9 "sync" 10 11 "github.com/sagernet/sing/common" 12 "github.com/sagernet/sing/common/buf" 13 "github.com/sagernet/sing/common/bufio" 14 E "github.com/sagernet/sing/common/exceptions" 15 M "github.com/sagernet/sing/common/metadata" 16 N "github.com/sagernet/sing/common/network" 17 "github.com/sagernet/sing/common/rw" 18 ) 19 20 const ( 21 KeyLength = 56 22 CommandTCP = 1 23 CommandUDP = 3 24 CommandMux = 0x7f 25 ) 26 27 var CRLF = []byte{'\r', '\n'} 28 29 var _ N.EarlyConn = (*ClientConn)(nil) 30 31 type ClientConn struct { 32 N.ExtendedConn 33 key [KeyLength]byte 34 destination M.Socksaddr 35 headerWritten bool 36 } 37 38 func NewClientConn(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr) *ClientConn { 39 return &ClientConn{ 40 ExtendedConn: bufio.NewExtendedConn(conn), 41 key: key, 42 destination: destination, 43 } 44 } 45 46 func (c *ClientConn) NeedHandshake() bool { 47 return !c.headerWritten 48 } 49 50 func (c *ClientConn) Write(p []byte) (n int, err error) { 51 if c.headerWritten { 52 return c.ExtendedConn.Write(p) 53 } 54 err = ClientHandshake(c.ExtendedConn, c.key, c.destination, p) 55 if err != nil { 56 return 57 } 58 n = len(p) 59 c.headerWritten = true 60 return 61 } 62 63 func (c *ClientConn) WriteBuffer(buffer *buf.Buffer) error { 64 if c.headerWritten { 65 return c.ExtendedConn.WriteBuffer(buffer) 66 } 67 err := ClientHandshakeBuffer(c.ExtendedConn, c.key, c.destination, buffer) 68 if err != nil { 69 return err 70 } 71 c.headerWritten = true 72 return nil 73 } 74 75 func (c *ClientConn) FrontHeadroom() int { 76 if !c.headerWritten { 77 return KeyLength + 5 + M.MaxSocksaddrLength 78 } 79 return 0 80 } 81 82 func (c *ClientConn) Upstream() any { 83 return c.ExtendedConn 84 } 85 86 type ClientPacketConn struct { 87 net.Conn 88 access sync.Mutex 89 key [KeyLength]byte 90 headerWritten bool 91 readWaitOptions N.ReadWaitOptions 92 } 93 94 func NewClientPacketConn(conn net.Conn, key [KeyLength]byte) *ClientPacketConn { 95 return &ClientPacketConn{ 96 Conn: conn, 97 key: key, 98 } 99 } 100 101 func (c *ClientPacketConn) NeedHandshake() bool { 102 return !c.headerWritten 103 } 104 105 func (c *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { 106 return ReadPacket(c.Conn, buffer) 107 } 108 109 func (c *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 110 if !c.headerWritten { 111 c.access.Lock() 112 if c.headerWritten { 113 c.access.Unlock() 114 } else { 115 err := ClientHandshakePacket(c.Conn, c.key, destination, buffer) 116 c.headerWritten = true 117 c.access.Unlock() 118 return err 119 } 120 } 121 return WritePacket(c.Conn, buffer, destination) 122 } 123 124 func (c *ClientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 125 buffer := buf.With(p) 126 destination, err := c.ReadPacket(buffer) 127 if err != nil { 128 return 129 } 130 n = buffer.Len() 131 if destination.IsFqdn() { 132 addr = destination 133 } else { 134 addr = destination.UDPAddr() 135 } 136 return 137 } 138 139 func (c *ClientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 140 return bufio.WritePacket(c, p, addr) 141 } 142 143 func (c *ClientPacketConn) Read(p []byte) (n int, err error) { 144 n, _, err = c.ReadFrom(p) 145 return 146 } 147 148 func (c *ClientPacketConn) Write(p []byte) (n int, err error) { 149 return 0, os.ErrInvalid 150 } 151 152 func (c *ClientPacketConn) FrontHeadroom() int { 153 if !c.headerWritten { 154 return KeyLength + 2*M.MaxSocksaddrLength + 9 155 } 156 return M.MaxSocksaddrLength + 4 157 } 158 159 func (c *ClientPacketConn) Upstream() any { 160 return c.Conn 161 } 162 163 func Key(password string) [KeyLength]byte { 164 var key [KeyLength]byte 165 hash := sha256.New224() 166 common.Must1(hash.Write([]byte(password))) 167 hex.Encode(key[:], hash.Sum(nil)) 168 return key 169 } 170 171 func ClientHandshakeRaw(conn net.Conn, key [KeyLength]byte, command byte, destination M.Socksaddr, payload []byte) error { 172 _, err := conn.Write(key[:]) 173 if err != nil { 174 return err 175 } 176 _, err = conn.Write(CRLF) 177 if err != nil { 178 return err 179 } 180 _, err = conn.Write([]byte{command}) 181 if err != nil { 182 return err 183 } 184 err = M.SocksaddrSerializer.WriteAddrPort(conn, destination) 185 if err != nil { 186 return err 187 } 188 _, err = conn.Write(CRLF) 189 if err != nil { 190 return err 191 } 192 if len(payload) > 0 { 193 _, err = conn.Write(payload) 194 if err != nil { 195 return err 196 } 197 } 198 return nil 199 } 200 201 func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload []byte) error { 202 headerLen := KeyLength + M.SocksaddrSerializer.AddrPortLen(destination) + 5 203 header := buf.NewSize(headerLen + len(payload)) 204 defer header.Release() 205 common.Must1(header.Write(key[:])) 206 common.Must1(header.Write(CRLF)) 207 common.Must(header.WriteByte(CommandTCP)) 208 err := M.SocksaddrSerializer.WriteAddrPort(header, destination) 209 if err != nil { 210 return err 211 } 212 common.Must1(header.Write(CRLF)) 213 common.Must1(header.Write(payload)) 214 _, err = conn.Write(header.Bytes()) 215 if err != nil { 216 return E.Cause(err, "write request") 217 } 218 return nil 219 } 220 221 func ClientHandshakeBuffer(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload *buf.Buffer) error { 222 header := buf.With(payload.ExtendHeader(KeyLength + M.SocksaddrSerializer.AddrPortLen(destination) + 5)) 223 common.Must1(header.Write(key[:])) 224 common.Must1(header.Write(CRLF)) 225 common.Must(header.WriteByte(CommandTCP)) 226 err := M.SocksaddrSerializer.WriteAddrPort(header, destination) 227 if err != nil { 228 return err 229 } 230 common.Must1(header.Write(CRLF)) 231 232 _, err = conn.Write(payload.Bytes()) 233 if err != nil { 234 return E.Cause(err, "write request") 235 } 236 return nil 237 } 238 239 func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload *buf.Buffer) error { 240 headerLen := KeyLength + 2*M.SocksaddrSerializer.AddrPortLen(destination) + 9 241 payloadLen := payload.Len() 242 var header *buf.Buffer 243 var writeHeader bool 244 if payload.Start() >= headerLen { 245 header = buf.With(payload.ExtendHeader(headerLen)) 246 } else { 247 header = buf.NewSize(headerLen) 248 defer header.Release() 249 writeHeader = true 250 } 251 common.Must1(header.Write(key[:])) 252 common.Must1(header.Write(CRLF)) 253 common.Must(header.WriteByte(CommandUDP)) 254 err := M.SocksaddrSerializer.WriteAddrPort(header, destination) 255 if err != nil { 256 return err 257 } 258 common.Must1(header.Write(CRLF)) 259 common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination)) 260 common.Must(binary.Write(header, binary.BigEndian, uint16(payloadLen))) 261 common.Must1(header.Write(CRLF)) 262 263 if writeHeader { 264 _, err := conn.Write(header.Bytes()) 265 if err != nil { 266 return E.Cause(err, "write request") 267 } 268 } 269 270 _, err = conn.Write(payload.Bytes()) 271 if err != nil { 272 return E.Cause(err, "write payload") 273 } 274 return nil 275 } 276 277 func ReadPacket(conn net.Conn, buffer *buf.Buffer) (M.Socksaddr, error) { 278 destination, err := M.SocksaddrSerializer.ReadAddrPort(conn) 279 if err != nil { 280 return M.Socksaddr{}, E.Cause(err, "read destination") 281 } 282 283 var length uint16 284 err = binary.Read(conn, binary.BigEndian, &length) 285 if err != nil { 286 return M.Socksaddr{}, E.Cause(err, "read chunk length") 287 } 288 289 err = rw.SkipN(conn, 2) 290 if err != nil { 291 return M.Socksaddr{}, E.Cause(err, "skip crlf") 292 } 293 294 _, err = buffer.ReadFullFrom(conn, int(length)) 295 return destination.Unwrap(), err 296 } 297 298 func WritePacket(conn net.Conn, buffer *buf.Buffer, destination M.Socksaddr) error { 299 defer buffer.Release() 300 bufferLen := buffer.Len() 301 header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination) + 4)) 302 err := M.SocksaddrSerializer.WriteAddrPort(header, destination) 303 if err != nil { 304 return err 305 } 306 common.Must(binary.Write(header, binary.BigEndian, uint16(bufferLen))) 307 common.Must1(header.Write(CRLF)) 308 _, err = conn.Write(buffer.Bytes()) 309 if err != nil { 310 return E.Cause(err, "write packet") 311 } 312 return nil 313 }