github.com/xraypb/Xray-core@v1.8.1/proxy/trojan/protocol.go (about)

     1  package trojan
     2  
     3  import (
     4  	"encoding/binary"
     5  	"io"
     6  
     7  	"github.com/xraypb/Xray-core/common/buf"
     8  	"github.com/xraypb/Xray-core/common/net"
     9  	"github.com/xraypb/Xray-core/common/protocol"
    10  )
    11  
    12  var (
    13  	crlf = []byte{'\r', '\n'}
    14  
    15  	addrParser = protocol.NewAddressParser(
    16  		protocol.AddressFamilyByte(0x01, net.AddressFamilyIPv4),
    17  		protocol.AddressFamilyByte(0x04, net.AddressFamilyIPv6),
    18  		protocol.AddressFamilyByte(0x03, net.AddressFamilyDomain),
    19  	)
    20  )
    21  
    22  const (
    23  	maxLength = 8192
    24  
    25  	commandTCP byte = 1
    26  	commandUDP byte = 3
    27  )
    28  
    29  // ConnWriter is TCP Connection Writer Wrapper for trojan protocol
    30  type ConnWriter struct {
    31  	io.Writer
    32  	Target     net.Destination
    33  	Account    *MemoryAccount
    34  	Flow       string
    35  	headerSent bool
    36  }
    37  
    38  // Write implements io.Writer
    39  func (c *ConnWriter) Write(p []byte) (n int, err error) {
    40  	if !c.headerSent {
    41  		if err := c.writeHeader(); err != nil {
    42  			return 0, newError("failed to write request header").Base(err)
    43  		}
    44  	}
    45  
    46  	return c.Writer.Write(p)
    47  }
    48  
    49  // WriteMultiBuffer implements buf.Writer
    50  func (c *ConnWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
    51  	defer buf.ReleaseMulti(mb)
    52  
    53  	for _, b := range mb {
    54  		if !b.IsEmpty() {
    55  			if _, err := c.Write(b.Bytes()); err != nil {
    56  				return err
    57  			}
    58  		}
    59  	}
    60  
    61  	return nil
    62  }
    63  
    64  func (c *ConnWriter) writeHeader() error {
    65  	buffer := buf.StackNew()
    66  	defer buffer.Release()
    67  
    68  	command := commandTCP
    69  	if c.Target.Network == net.Network_UDP {
    70  		command = commandUDP
    71  	}
    72  
    73  	if _, err := buffer.Write(c.Account.Key); err != nil {
    74  		return err
    75  	}
    76  	if _, err := buffer.Write(crlf); err != nil {
    77  		return err
    78  	}
    79  	if err := buffer.WriteByte(command); err != nil {
    80  		return err
    81  	}
    82  	if err := addrParser.WriteAddressPort(&buffer, c.Target.Address, c.Target.Port); err != nil {
    83  		return err
    84  	}
    85  	if _, err := buffer.Write(crlf); err != nil {
    86  		return err
    87  	}
    88  
    89  	_, err := c.Writer.Write(buffer.Bytes())
    90  	if err == nil {
    91  		c.headerSent = true
    92  	}
    93  
    94  	return err
    95  }
    96  
    97  // PacketWriter UDP Connection Writer Wrapper for trojan protocol
    98  type PacketWriter struct {
    99  	io.Writer
   100  	Target net.Destination
   101  }
   102  
   103  // WriteMultiBuffer implements buf.Writer
   104  func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
   105  	for {
   106  		mb2, b := buf.SplitFirst(mb)
   107  		mb = mb2
   108  		if b == nil {
   109  			break
   110  		}
   111  		target := &w.Target
   112  		if b.UDP != nil {
   113  			target = b.UDP
   114  		}
   115  		if _, err := w.writePacket(b.Bytes(), *target); err != nil {
   116  			buf.ReleaseMulti(mb)
   117  			return err
   118  		}
   119  	}
   120  	return nil
   121  }
   122  
   123  func (w *PacketWriter) writePacket(payload []byte, dest net.Destination) (int, error) {
   124  	buffer := buf.StackNew()
   125  	defer buffer.Release()
   126  
   127  	length := len(payload)
   128  	lengthBuf := [2]byte{}
   129  	binary.BigEndian.PutUint16(lengthBuf[:], uint16(length))
   130  	if err := addrParser.WriteAddressPort(&buffer, dest.Address, dest.Port); err != nil {
   131  		return 0, err
   132  	}
   133  	if _, err := buffer.Write(lengthBuf[:]); err != nil {
   134  		return 0, err
   135  	}
   136  	if _, err := buffer.Write(crlf); err != nil {
   137  		return 0, err
   138  	}
   139  	if _, err := buffer.Write(payload); err != nil {
   140  		return 0, err
   141  	}
   142  	_, err := w.Write(buffer.Bytes())
   143  	if err != nil {
   144  		return 0, err
   145  	}
   146  
   147  	return length, nil
   148  }
   149  
   150  // ConnReader is TCP Connection Reader Wrapper for trojan protocol
   151  type ConnReader struct {
   152  	io.Reader
   153  	Target       net.Destination
   154  	Flow         string
   155  	headerParsed bool
   156  }
   157  
   158  // ParseHeader parses the trojan protocol header
   159  func (c *ConnReader) ParseHeader() error {
   160  	var crlf [2]byte
   161  	var command [1]byte
   162  	var hash [56]byte
   163  	if _, err := io.ReadFull(c.Reader, hash[:]); err != nil {
   164  		return newError("failed to read user hash").Base(err)
   165  	}
   166  
   167  	if _, err := io.ReadFull(c.Reader, crlf[:]); err != nil {
   168  		return newError("failed to read crlf").Base(err)
   169  	}
   170  
   171  	if _, err := io.ReadFull(c.Reader, command[:]); err != nil {
   172  		return newError("failed to read command").Base(err)
   173  	}
   174  
   175  	network := net.Network_TCP
   176  	if command[0] == commandUDP {
   177  		network = net.Network_UDP
   178  	}
   179  
   180  	addr, port, err := addrParser.ReadAddressPort(nil, c.Reader)
   181  	if err != nil {
   182  		return newError("failed to read address and port").Base(err)
   183  	}
   184  	c.Target = net.Destination{Network: network, Address: addr, Port: port}
   185  
   186  	if _, err := io.ReadFull(c.Reader, crlf[:]); err != nil {
   187  		return newError("failed to read crlf").Base(err)
   188  	}
   189  
   190  	c.headerParsed = true
   191  	return nil
   192  }
   193  
   194  // Read implements io.Reader
   195  func (c *ConnReader) Read(p []byte) (int, error) {
   196  	if !c.headerParsed {
   197  		if err := c.ParseHeader(); err != nil {
   198  			return 0, err
   199  		}
   200  	}
   201  
   202  	return c.Reader.Read(p)
   203  }
   204  
   205  // ReadMultiBuffer implements buf.Reader
   206  func (c *ConnReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
   207  	b := buf.New()
   208  	_, err := b.ReadFrom(c)
   209  	return buf.MultiBuffer{b}, err
   210  }
   211  
   212  // PacketReader is UDP Connection Reader Wrapper for trojan protocol
   213  type PacketReader struct {
   214  	io.Reader
   215  }
   216  
   217  // ReadMultiBuffer implements buf.Reader
   218  func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
   219  	addr, port, err := addrParser.ReadAddressPort(nil, r)
   220  	if err != nil {
   221  		return nil, newError("failed to read address and port").Base(err)
   222  	}
   223  
   224  	var lengthBuf [2]byte
   225  	if _, err := io.ReadFull(r, lengthBuf[:]); err != nil {
   226  		return nil, newError("failed to read payload length").Base(err)
   227  	}
   228  
   229  	remain := int(binary.BigEndian.Uint16(lengthBuf[:]))
   230  	if remain > maxLength {
   231  		return nil, newError("oversize payload")
   232  	}
   233  
   234  	var crlf [2]byte
   235  	if _, err := io.ReadFull(r, crlf[:]); err != nil {
   236  		return nil, newError("failed to read crlf").Base(err)
   237  	}
   238  
   239  	dest := net.UDPDestination(addr, port)
   240  	var mb buf.MultiBuffer
   241  	for remain > 0 {
   242  		length := buf.Size
   243  		if remain < length {
   244  			length = remain
   245  		}
   246  
   247  		b := buf.New()
   248  		b.UDP = &dest
   249  		mb = append(mb, b)
   250  		n, err := b.ReadFullFrom(r, int32(length))
   251  		if err != nil {
   252  			buf.ReleaseMulti(mb)
   253  			return nil, newError("failed to read payload").Base(err)
   254  		}
   255  
   256  		remain -= int(n)
   257  	}
   258  
   259  	return mb, nil
   260  }