github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/proxy/trojan/protocol.go (about) 1 package trojan 2 3 import ( 4 "encoding/binary" 5 "io" 6 7 "github.com/xtls/xray-core/common/buf" 8 "github.com/xtls/xray-core/common/net" 9 "github.com/xtls/xray-core/common/protocol" 10 ) 11 12 var ( 13 crlf = []byte{'\r', '\n'} 14 15 addrParser = protocol.NewAddressParser( 16 protocol.AddressFamilyByte(0x01, net.AddressFamilyIPv4), 17 protocol.AddressFamilyByte(0x04, net.AddressFamilyIPv6), 18 protocol.AddressFamilyByte(0x03, net.AddressFamilyDomain), 19 ) 20 ) 21 22 const ( 23 maxLength = 8192 24 25 commandTCP byte = 1 26 commandUDP byte = 3 27 ) 28 29 // ConnWriter is TCP Connection Writer Wrapper for trojan protocol 30 type ConnWriter struct { 31 io.Writer 32 Target net.Destination 33 Account *MemoryAccount 34 headerSent bool 35 } 36 37 // Write implements io.Writer 38 func (c *ConnWriter) Write(p []byte) (n int, err error) { 39 if !c.headerSent { 40 if err := c.writeHeader(); err != nil { 41 return 0, newError("failed to write request header").Base(err) 42 } 43 } 44 45 return c.Writer.Write(p) 46 } 47 48 // WriteMultiBuffer implements buf.Writer 49 func (c *ConnWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { 50 defer buf.ReleaseMulti(mb) 51 52 for _, b := range mb { 53 if !b.IsEmpty() { 54 if _, err := c.Write(b.Bytes()); err != nil { 55 return err 56 } 57 } 58 } 59 60 return nil 61 } 62 63 func (c *ConnWriter) writeHeader() error { 64 buffer := buf.StackNew() 65 defer buffer.Release() 66 67 command := commandTCP 68 if c.Target.Network == net.Network_UDP { 69 command = commandUDP 70 } 71 72 if _, err := buffer.Write(c.Account.Key); err != nil { 73 return err 74 } 75 if _, err := buffer.Write(crlf); err != nil { 76 return err 77 } 78 if err := buffer.WriteByte(command); err != nil { 79 return err 80 } 81 if err := addrParser.WriteAddressPort(&buffer, c.Target.Address, c.Target.Port); err != nil { 82 return err 83 } 84 if _, err := buffer.Write(crlf); err != nil { 85 return err 86 } 87 88 _, err := c.Writer.Write(buffer.Bytes()) 89 if err == nil { 90 c.headerSent = true 91 } 92 93 return err 94 } 95 96 // PacketWriter UDP Connection Writer Wrapper for trojan protocol 97 type PacketWriter struct { 98 io.Writer 99 Target net.Destination 100 } 101 102 // WriteMultiBuffer implements buf.Writer 103 func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error { 104 for { 105 mb2, b := buf.SplitFirst(mb) 106 mb = mb2 107 if b == nil { 108 break 109 } 110 target := &w.Target 111 if b.UDP != nil { 112 target = b.UDP 113 } 114 if _, err := w.writePacket(b.Bytes(), *target); err != nil { 115 buf.ReleaseMulti(mb) 116 return err 117 } 118 } 119 return nil 120 } 121 122 func (w *PacketWriter) writePacket(payload []byte, dest net.Destination) (int, error) { 123 buffer := buf.StackNew() 124 defer buffer.Release() 125 126 length := len(payload) 127 lengthBuf := [2]byte{} 128 binary.BigEndian.PutUint16(lengthBuf[:], uint16(length)) 129 if err := addrParser.WriteAddressPort(&buffer, dest.Address, dest.Port); err != nil { 130 return 0, err 131 } 132 if _, err := buffer.Write(lengthBuf[:]); err != nil { 133 return 0, err 134 } 135 if _, err := buffer.Write(crlf); err != nil { 136 return 0, err 137 } 138 if _, err := buffer.Write(payload); err != nil { 139 return 0, err 140 } 141 _, err := w.Write(buffer.Bytes()) 142 if err != nil { 143 return 0, err 144 } 145 146 return length, nil 147 } 148 149 // ConnReader is TCP Connection Reader Wrapper for trojan protocol 150 type ConnReader struct { 151 io.Reader 152 Target net.Destination 153 Flow string 154 headerParsed bool 155 } 156 157 // ParseHeader parses the trojan protocol header 158 func (c *ConnReader) ParseHeader() error { 159 var crlf [2]byte 160 var command [1]byte 161 var hash [56]byte 162 if _, err := io.ReadFull(c.Reader, hash[:]); err != nil { 163 return newError("failed to read user hash").Base(err) 164 } 165 166 if _, err := io.ReadFull(c.Reader, crlf[:]); err != nil { 167 return newError("failed to read crlf").Base(err) 168 } 169 170 if _, err := io.ReadFull(c.Reader, command[:]); err != nil { 171 return newError("failed to read command").Base(err) 172 } 173 174 network := net.Network_TCP 175 if command[0] == commandUDP { 176 network = net.Network_UDP 177 } 178 179 addr, port, err := addrParser.ReadAddressPort(nil, c.Reader) 180 if err != nil { 181 return newError("failed to read address and port").Base(err) 182 } 183 c.Target = net.Destination{Network: network, Address: addr, Port: port} 184 185 if _, err := io.ReadFull(c.Reader, crlf[:]); err != nil { 186 return newError("failed to read crlf").Base(err) 187 } 188 189 c.headerParsed = true 190 return nil 191 } 192 193 // Read implements io.Reader 194 func (c *ConnReader) Read(p []byte) (int, error) { 195 if !c.headerParsed { 196 if err := c.ParseHeader(); err != nil { 197 return 0, err 198 } 199 } 200 201 return c.Reader.Read(p) 202 } 203 204 // ReadMultiBuffer implements buf.Reader 205 func (c *ConnReader) ReadMultiBuffer() (buf.MultiBuffer, error) { 206 b := buf.New() 207 _, err := b.ReadFrom(c) 208 return buf.MultiBuffer{b}, err 209 } 210 211 // PacketReader is UDP Connection Reader Wrapper for trojan protocol 212 type PacketReader struct { 213 io.Reader 214 } 215 216 // ReadMultiBuffer implements buf.Reader 217 func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) { 218 addr, port, err := addrParser.ReadAddressPort(nil, r) 219 if err != nil { 220 return nil, newError("failed to read address and port").Base(err) 221 } 222 223 var lengthBuf [2]byte 224 if _, err := io.ReadFull(r, lengthBuf[:]); err != nil { 225 return nil, newError("failed to read payload length").Base(err) 226 } 227 228 remain := int(binary.BigEndian.Uint16(lengthBuf[:])) 229 if remain > maxLength { 230 return nil, newError("oversize payload") 231 } 232 233 var crlf [2]byte 234 if _, err := io.ReadFull(r, crlf[:]); err != nil { 235 return nil, newError("failed to read crlf").Base(err) 236 } 237 238 dest := net.UDPDestination(addr, port) 239 var mb buf.MultiBuffer 240 for remain > 0 { 241 length := buf.Size 242 if remain < length { 243 length = remain 244 } 245 246 b := buf.New() 247 b.UDP = &dest 248 mb = append(mb, b) 249 n, err := b.ReadFullFrom(r, int32(length)) 250 if err != nil { 251 buf.ReleaseMulti(mb) 252 return nil, newError("failed to read payload").Base(err) 253 } 254 255 remain -= int(n) 256 } 257 258 return mb, nil 259 }