github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/proxy/trojan/protocol.go (about)

     1  package trojan
     2  
     3  import (
     4  	"encoding/binary"
     5  	"io"
     6  
     7  	"github.com/xtls/xray-core/common/buf"
     8  	"github.com/xtls/xray-core/common/net"
     9  	"github.com/xtls/xray-core/common/protocol"
    10  )
    11  
    12  var (
    13  	crlf = []byte{'\r', '\n'}
    14  
    15  	addrParser = protocol.NewAddressParser(
    16  		protocol.AddressFamilyByte(0x01, net.AddressFamilyIPv4),
    17  		protocol.AddressFamilyByte(0x04, net.AddressFamilyIPv6),
    18  		protocol.AddressFamilyByte(0x03, net.AddressFamilyDomain),
    19  	)
    20  )
    21  
    22  const (
    23  	maxLength = 8192
    24  
    25  	commandTCP byte = 1
    26  	commandUDP byte = 3
    27  )
    28  
    29  // ConnWriter is TCP Connection Writer Wrapper for trojan protocol
    30  type ConnWriter struct {
    31  	io.Writer
    32  	Target     net.Destination
    33  	Account    *MemoryAccount
    34  	headerSent bool
    35  }
    36  
    37  // Write implements io.Writer
    38  func (c *ConnWriter) Write(p []byte) (n int, err error) {
    39  	if !c.headerSent {
    40  		if err := c.writeHeader(); err != nil {
    41  			return 0, newError("failed to write request header").Base(err)
    42  		}
    43  	}
    44  
    45  	return c.Writer.Write(p)
    46  }
    47  
    48  // WriteMultiBuffer implements buf.Writer
    49  func (c *ConnWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
    50  	defer buf.ReleaseMulti(mb)
    51  
    52  	for _, b := range mb {
    53  		if !b.IsEmpty() {
    54  			if _, err := c.Write(b.Bytes()); err != nil {
    55  				return err
    56  			}
    57  		}
    58  	}
    59  
    60  	return nil
    61  }
    62  
    63  func (c *ConnWriter) writeHeader() error {
    64  	buffer := buf.StackNew()
    65  	defer buffer.Release()
    66  
    67  	command := commandTCP
    68  	if c.Target.Network == net.Network_UDP {
    69  		command = commandUDP
    70  	}
    71  
    72  	if _, err := buffer.Write(c.Account.Key); err != nil {
    73  		return err
    74  	}
    75  	if _, err := buffer.Write(crlf); err != nil {
    76  		return err
    77  	}
    78  	if err := buffer.WriteByte(command); err != nil {
    79  		return err
    80  	}
    81  	if err := addrParser.WriteAddressPort(&buffer, c.Target.Address, c.Target.Port); err != nil {
    82  		return err
    83  	}
    84  	if _, err := buffer.Write(crlf); err != nil {
    85  		return err
    86  	}
    87  
    88  	_, err := c.Writer.Write(buffer.Bytes())
    89  	if err == nil {
    90  		c.headerSent = true
    91  	}
    92  
    93  	return err
    94  }
    95  
    96  // PacketWriter UDP Connection Writer Wrapper for trojan protocol
    97  type PacketWriter struct {
    98  	io.Writer
    99  	Target net.Destination
   100  }
   101  
   102  // WriteMultiBuffer implements buf.Writer
   103  func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
   104  	for {
   105  		mb2, b := buf.SplitFirst(mb)
   106  		mb = mb2
   107  		if b == nil {
   108  			break
   109  		}
   110  		target := &w.Target
   111  		if b.UDP != nil {
   112  			target = b.UDP
   113  		}
   114  		if _, err := w.writePacket(b.Bytes(), *target); err != nil {
   115  			buf.ReleaseMulti(mb)
   116  			return err
   117  		}
   118  	}
   119  	return nil
   120  }
   121  
   122  func (w *PacketWriter) writePacket(payload []byte, dest net.Destination) (int, error) {
   123  	buffer := buf.StackNew()
   124  	defer buffer.Release()
   125  
   126  	length := len(payload)
   127  	lengthBuf := [2]byte{}
   128  	binary.BigEndian.PutUint16(lengthBuf[:], uint16(length))
   129  	if err := addrParser.WriteAddressPort(&buffer, dest.Address, dest.Port); err != nil {
   130  		return 0, err
   131  	}
   132  	if _, err := buffer.Write(lengthBuf[:]); err != nil {
   133  		return 0, err
   134  	}
   135  	if _, err := buffer.Write(crlf); err != nil {
   136  		return 0, err
   137  	}
   138  	if _, err := buffer.Write(payload); err != nil {
   139  		return 0, err
   140  	}
   141  	_, err := w.Write(buffer.Bytes())
   142  	if err != nil {
   143  		return 0, err
   144  	}
   145  
   146  	return length, nil
   147  }
   148  
   149  // ConnReader is TCP Connection Reader Wrapper for trojan protocol
   150  type ConnReader struct {
   151  	io.Reader
   152  	Target       net.Destination
   153  	Flow         string
   154  	headerParsed bool
   155  }
   156  
   157  // ParseHeader parses the trojan protocol header
   158  func (c *ConnReader) ParseHeader() error {
   159  	var crlf [2]byte
   160  	var command [1]byte
   161  	var hash [56]byte
   162  	if _, err := io.ReadFull(c.Reader, hash[:]); err != nil {
   163  		return newError("failed to read user hash").Base(err)
   164  	}
   165  
   166  	if _, err := io.ReadFull(c.Reader, crlf[:]); err != nil {
   167  		return newError("failed to read crlf").Base(err)
   168  	}
   169  
   170  	if _, err := io.ReadFull(c.Reader, command[:]); err != nil {
   171  		return newError("failed to read command").Base(err)
   172  	}
   173  
   174  	network := net.Network_TCP
   175  	if command[0] == commandUDP {
   176  		network = net.Network_UDP
   177  	}
   178  
   179  	addr, port, err := addrParser.ReadAddressPort(nil, c.Reader)
   180  	if err != nil {
   181  		return newError("failed to read address and port").Base(err)
   182  	}
   183  	c.Target = net.Destination{Network: network, Address: addr, Port: port}
   184  
   185  	if _, err := io.ReadFull(c.Reader, crlf[:]); err != nil {
   186  		return newError("failed to read crlf").Base(err)
   187  	}
   188  
   189  	c.headerParsed = true
   190  	return nil
   191  }
   192  
   193  // Read implements io.Reader
   194  func (c *ConnReader) Read(p []byte) (int, error) {
   195  	if !c.headerParsed {
   196  		if err := c.ParseHeader(); err != nil {
   197  			return 0, err
   198  		}
   199  	}
   200  
   201  	return c.Reader.Read(p)
   202  }
   203  
   204  // ReadMultiBuffer implements buf.Reader
   205  func (c *ConnReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
   206  	b := buf.New()
   207  	_, err := b.ReadFrom(c)
   208  	return buf.MultiBuffer{b}, err
   209  }
   210  
   211  // PacketReader is UDP Connection Reader Wrapper for trojan protocol
   212  type PacketReader struct {
   213  	io.Reader
   214  }
   215  
   216  // ReadMultiBuffer implements buf.Reader
   217  func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
   218  	addr, port, err := addrParser.ReadAddressPort(nil, r)
   219  	if err != nil {
   220  		return nil, newError("failed to read address and port").Base(err)
   221  	}
   222  
   223  	var lengthBuf [2]byte
   224  	if _, err := io.ReadFull(r, lengthBuf[:]); err != nil {
   225  		return nil, newError("failed to read payload length").Base(err)
   226  	}
   227  
   228  	remain := int(binary.BigEndian.Uint16(lengthBuf[:]))
   229  	if remain > maxLength {
   230  		return nil, newError("oversize payload")
   231  	}
   232  
   233  	var crlf [2]byte
   234  	if _, err := io.ReadFull(r, crlf[:]); err != nil {
   235  		return nil, newError("failed to read crlf").Base(err)
   236  	}
   237  
   238  	dest := net.UDPDestination(addr, port)
   239  	var mb buf.MultiBuffer
   240  	for remain > 0 {
   241  		length := buf.Size
   242  		if remain < length {
   243  			length = remain
   244  		}
   245  
   246  		b := buf.New()
   247  		b.UDP = &dest
   248  		mb = append(mb, b)
   249  		n, err := b.ReadFullFrom(r, int32(length))
   250  		if err != nil {
   251  			buf.ReleaseMulti(mb)
   252  			return nil, newError("failed to read payload").Base(err)
   253  		}
   254  
   255  		remain -= int(n)
   256  	}
   257  
   258  	return mb, nil
   259  }