github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/transport/trojan/protocol.go (about) 1 package trojan 2 3 import ( 4 "crypto/sha256" 5 "encoding/binary" 6 "encoding/hex" 7 "net" 8 "os" 9 "sync" 10 11 "github.com/sagernet/sing/common" 12 "github.com/sagernet/sing/common/buf" 13 "github.com/sagernet/sing/common/bufio" 14 E "github.com/sagernet/sing/common/exceptions" 15 M "github.com/sagernet/sing/common/metadata" 16 N "github.com/sagernet/sing/common/network" 17 "github.com/sagernet/sing/common/rw" 18 ) 19 20 const ( 21 KeyLength = 56 22 CommandTCP = 1 23 CommandUDP = 3 24 CommandMux = 0x7f 25 ) 26 27 var CRLF = []byte{'\r', '\n'} 28 29 var _ N.EarlyConn = (*ClientConn)(nil) 30 31 type ClientConn struct { 32 N.ExtendedConn 33 key [KeyLength]byte 34 destination M.Socksaddr 35 headerWritten bool 36 } 37 38 func NewClientConn(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr) *ClientConn { 39 return &ClientConn{ 40 ExtendedConn: bufio.NewExtendedConn(conn), 41 key: key, 42 destination: destination, 43 } 44 } 45 46 func (c *ClientConn) NeedHandshake() bool { 47 return !c.headerWritten 48 } 49 50 func (c *ClientConn) Write(p []byte) (n int, err error) { 51 if c.headerWritten { 52 return c.ExtendedConn.Write(p) 53 } 54 err = ClientHandshake(c.ExtendedConn, c.key, c.destination, p) 55 if err != nil { 56 return 57 } 58 n = len(p) 59 c.headerWritten = true 60 return 61 } 62 63 func (c *ClientConn) WriteBuffer(buffer *buf.Buffer) error { 64 if c.headerWritten { 65 return c.ExtendedConn.WriteBuffer(buffer) 66 } 67 err := ClientHandshakeBuffer(c.ExtendedConn, c.key, c.destination, buffer) 68 if err != nil { 69 return err 70 } 71 c.headerWritten = true 72 return nil 73 } 74 75 func (c *ClientConn) FrontHeadroom() int { 76 if !c.headerWritten { 77 return KeyLength + 5 + M.MaxSocksaddrLength 78 } 79 return 0 80 } 81 82 func (c *ClientConn) Upstream() any { 83 return c.ExtendedConn 84 } 85 86 type ClientPacketConn struct { 87 net.Conn 88 access sync.Mutex 89 key [KeyLength]byte 90 headerWritten bool 91 } 92 93 func NewClientPacketConn(conn net.Conn, key [KeyLength]byte) *ClientPacketConn { 94 return &ClientPacketConn{ 95 Conn: conn, 96 key: key, 97 } 98 } 99 100 func (c *ClientPacketConn) NeedHandshake() bool { 101 return !c.headerWritten 102 } 103 104 func (c *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { 105 return ReadPacket(c.Conn, buffer) 106 } 107 108 func (c *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 109 if !c.headerWritten { 110 c.access.Lock() 111 if c.headerWritten { 112 c.access.Unlock() 113 } else { 114 err := ClientHandshakePacket(c.Conn, c.key, destination, buffer) 115 c.headerWritten = true 116 c.access.Unlock() 117 return err 118 } 119 } 120 return WritePacket(c.Conn, buffer, destination) 121 } 122 123 func (c *ClientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 124 buffer := buf.With(p) 125 destination, err := c.ReadPacket(buffer) 126 if err != nil { 127 return 128 } 129 n = buffer.Len() 130 if destination.IsFqdn() { 131 addr = destination 132 } else { 133 addr = destination.UDPAddr() 134 } 135 return 136 } 137 138 func (c *ClientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 139 return bufio.WritePacket(c, p, addr) 140 } 141 142 func (c *ClientPacketConn) Read(p []byte) (n int, err error) { 143 n, _, err = c.ReadFrom(p) 144 return 145 } 146 147 func (c *ClientPacketConn) Write(p []byte) (n int, err error) { 148 return 0, os.ErrInvalid 149 } 150 151 func (c *ClientPacketConn) FrontHeadroom() int { 152 if !c.headerWritten { 153 return KeyLength + 2*M.MaxSocksaddrLength + 9 154 } 155 return M.MaxSocksaddrLength + 4 156 } 157 158 func (c *ClientPacketConn) Upstream() any { 159 return c.Conn 160 } 161 162 func Key(password string) [KeyLength]byte { 163 var key [KeyLength]byte 164 hash := sha256.New224() 165 common.Must1(hash.Write([]byte(password))) 166 hex.Encode(key[:], hash.Sum(nil)) 167 return key 168 } 169 170 func ClientHandshakeRaw(conn net.Conn, key [KeyLength]byte, command byte, destination M.Socksaddr, payload []byte) error { 171 _, err := conn.Write(key[:]) 172 if err != nil { 173 return err 174 } 175 _, err = conn.Write(CRLF) 176 if err != nil { 177 return err 178 } 179 _, err = conn.Write([]byte{command}) 180 if err != nil { 181 return err 182 } 183 err = M.SocksaddrSerializer.WriteAddrPort(conn, destination) 184 if err != nil { 185 return err 186 } 187 _, err = conn.Write(CRLF) 188 if err != nil { 189 return err 190 } 191 if len(payload) > 0 { 192 _, err = conn.Write(payload) 193 if err != nil { 194 return err 195 } 196 } 197 return nil 198 } 199 200 func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload []byte) error { 201 headerLen := KeyLength + M.SocksaddrSerializer.AddrPortLen(destination) + 5 202 header := buf.NewSize(headerLen + len(payload)) 203 defer header.Release() 204 common.Must1(header.Write(key[:])) 205 common.Must1(header.Write(CRLF)) 206 common.Must(header.WriteByte(CommandTCP)) 207 err := M.SocksaddrSerializer.WriteAddrPort(header, destination) 208 if err != nil { 209 return err 210 } 211 common.Must1(header.Write(CRLF)) 212 common.Must1(header.Write(payload)) 213 _, err = conn.Write(header.Bytes()) 214 if err != nil { 215 return E.Cause(err, "write request") 216 } 217 return nil 218 } 219 220 func ClientHandshakeBuffer(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload *buf.Buffer) error { 221 header := buf.With(payload.ExtendHeader(KeyLength + M.SocksaddrSerializer.AddrPortLen(destination) + 5)) 222 common.Must1(header.Write(key[:])) 223 common.Must1(header.Write(CRLF)) 224 common.Must(header.WriteByte(CommandTCP)) 225 err := M.SocksaddrSerializer.WriteAddrPort(header, destination) 226 if err != nil { 227 return err 228 } 229 common.Must1(header.Write(CRLF)) 230 231 _, err = conn.Write(payload.Bytes()) 232 if err != nil { 233 return E.Cause(err, "write request") 234 } 235 return nil 236 } 237 238 func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload *buf.Buffer) error { 239 headerLen := KeyLength + 2*M.SocksaddrSerializer.AddrPortLen(destination) + 9 240 payloadLen := payload.Len() 241 var header *buf.Buffer 242 var writeHeader bool 243 if payload.Start() >= headerLen { 244 header = buf.With(payload.ExtendHeader(headerLen)) 245 } else { 246 header = buf.NewSize(headerLen) 247 defer header.Release() 248 writeHeader = true 249 } 250 common.Must1(header.Write(key[:])) 251 common.Must1(header.Write(CRLF)) 252 common.Must(header.WriteByte(CommandUDP)) 253 err := M.SocksaddrSerializer.WriteAddrPort(header, destination) 254 if err != nil { 255 return err 256 } 257 common.Must1(header.Write(CRLF)) 258 common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination)) 259 common.Must(binary.Write(header, binary.BigEndian, uint16(payloadLen))) 260 common.Must1(header.Write(CRLF)) 261 262 if writeHeader { 263 _, err := conn.Write(header.Bytes()) 264 if err != nil { 265 return E.Cause(err, "write request") 266 } 267 } 268 269 _, err = conn.Write(payload.Bytes()) 270 if err != nil { 271 return E.Cause(err, "write payload") 272 } 273 return nil 274 } 275 276 func ReadPacket(conn net.Conn, buffer *buf.Buffer) (M.Socksaddr, error) { 277 destination, err := M.SocksaddrSerializer.ReadAddrPort(conn) 278 if err != nil { 279 return M.Socksaddr{}, E.Cause(err, "read destination") 280 } 281 282 var length uint16 283 err = binary.Read(conn, binary.BigEndian, &length) 284 if err != nil { 285 return M.Socksaddr{}, E.Cause(err, "read chunk length") 286 } 287 288 err = rw.SkipN(conn, 2) 289 if err != nil { 290 return M.Socksaddr{}, E.Cause(err, "skip crlf") 291 } 292 293 _, err = buffer.ReadFullFrom(conn, int(length)) 294 return destination.Unwrap(), err 295 } 296 297 func WritePacket(conn net.Conn, buffer *buf.Buffer, destination M.Socksaddr) error { 298 defer buffer.Release() 299 bufferLen := buffer.Len() 300 header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination) + 4)) 301 err := M.SocksaddrSerializer.WriteAddrPort(header, destination) 302 if err != nil { 303 return err 304 } 305 common.Must(binary.Write(header, binary.BigEndian, uint16(bufferLen))) 306 common.Must1(header.Write(CRLF)) 307 _, err = conn.Write(buffer.Bytes()) 308 if err != nil { 309 return E.Cause(err, "write packet") 310 } 311 return nil 312 }