github.com/v2fly/v2ray-core/v4@v4.45.2/proxy/trojan/protocol.go (about)

     1  package trojan
     2  
     3  import (
     4  	"encoding/binary"
     5  	"io"
     6  
     7  	"github.com/v2fly/v2ray-core/v4/common/buf"
     8  	"github.com/v2fly/v2ray-core/v4/common/net"
     9  	"github.com/v2fly/v2ray-core/v4/common/protocol"
    10  )
    11  
    12  var (
    13  	crlf = []byte{'\r', '\n'}
    14  
    15  	addrParser = protocol.NewAddressParser(
    16  		protocol.AddressFamilyByte(0x01, net.AddressFamilyIPv4),
    17  		protocol.AddressFamilyByte(0x04, net.AddressFamilyIPv6),
    18  		protocol.AddressFamilyByte(0x03, net.AddressFamilyDomain),
    19  	)
    20  )
    21  
    22  const (
    23  	maxLength       = 8192
    24  	commandTCP byte = 1
    25  	commandUDP byte = 3
    26  )
    27  
    28  // ConnWriter is TCP Connection Writer Wrapper for trojan protocol
    29  type ConnWriter struct {
    30  	io.Writer
    31  	Target     net.Destination
    32  	Account    *MemoryAccount
    33  	headerSent bool
    34  }
    35  
    36  // Write implements io.Writer
    37  func (c *ConnWriter) Write(p []byte) (n int, err error) {
    38  	if !c.headerSent {
    39  		if err := c.writeHeader(); err != nil {
    40  			return 0, newError("failed to write request header").Base(err)
    41  		}
    42  	}
    43  
    44  	return c.Writer.Write(p)
    45  }
    46  
    47  // WriteMultiBuffer implements buf.Writer
    48  func (c *ConnWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
    49  	defer buf.ReleaseMulti(mb)
    50  
    51  	for _, b := range mb {
    52  		if !b.IsEmpty() {
    53  			if _, err := c.Write(b.Bytes()); err != nil {
    54  				return err
    55  			}
    56  		}
    57  	}
    58  
    59  	return nil
    60  }
    61  
    62  func (c *ConnWriter) writeHeader() error {
    63  	buffer := buf.StackNew()
    64  	defer buffer.Release()
    65  
    66  	command := commandTCP
    67  	if c.Target.Network == net.Network_UDP {
    68  		command = commandUDP
    69  	}
    70  
    71  	if _, err := buffer.Write(c.Account.Key); err != nil {
    72  		return err
    73  	}
    74  	if _, err := buffer.Write(crlf); err != nil {
    75  		return err
    76  	}
    77  	if err := buffer.WriteByte(command); err != nil {
    78  		return err
    79  	}
    80  	if err := addrParser.WriteAddressPort(&buffer, c.Target.Address, c.Target.Port); err != nil {
    81  		return err
    82  	}
    83  	if _, err := buffer.Write(crlf); err != nil {
    84  		return err
    85  	}
    86  
    87  	_, err := c.Writer.Write(buffer.Bytes())
    88  	if err == nil {
    89  		c.headerSent = true
    90  	}
    91  
    92  	return err
    93  }
    94  
    95  // PacketWriter UDP Connection Writer Wrapper for trojan protocol
    96  type PacketWriter struct {
    97  	io.Writer
    98  	Target net.Destination
    99  }
   100  
   101  // WriteMultiBuffer implements buf.Writer
   102  func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
   103  	b := make([]byte, maxLength)
   104  	for !mb.IsEmpty() {
   105  		var length int
   106  		mb, length = buf.SplitBytes(mb, b)
   107  		if _, err := w.writePacket(b[:length], w.Target); err != nil {
   108  			buf.ReleaseMulti(mb)
   109  			return err
   110  		}
   111  	}
   112  
   113  	return nil
   114  }
   115  
   116  // WriteMultiBufferWithMetadata writes udp packet with destination specified
   117  func (w *PacketWriter) WriteMultiBufferWithMetadata(mb buf.MultiBuffer, dest net.Destination) error {
   118  	b := make([]byte, maxLength)
   119  	for !mb.IsEmpty() {
   120  		var length int
   121  		mb, length = buf.SplitBytes(mb, b)
   122  		if _, err := w.writePacket(b[:length], dest); err != nil {
   123  			buf.ReleaseMulti(mb)
   124  			return err
   125  		}
   126  	}
   127  
   128  	return nil
   129  }
   130  
   131  func (w *PacketWriter) writePacket(payload []byte, dest net.Destination) (int, error) { // nolint: unparam
   132  	buffer := buf.StackNew()
   133  	defer buffer.Release()
   134  
   135  	length := len(payload)
   136  	lengthBuf := [2]byte{}
   137  	binary.BigEndian.PutUint16(lengthBuf[:], uint16(length))
   138  	if err := addrParser.WriteAddressPort(&buffer, dest.Address, dest.Port); err != nil {
   139  		return 0, err
   140  	}
   141  	if _, err := buffer.Write(lengthBuf[:]); err != nil {
   142  		return 0, err
   143  	}
   144  	if _, err := buffer.Write(crlf); err != nil {
   145  		return 0, err
   146  	}
   147  	if _, err := buffer.Write(payload); err != nil {
   148  		return 0, err
   149  	}
   150  	_, err := w.Write(buffer.Bytes())
   151  	if err != nil {
   152  		return 0, err
   153  	}
   154  
   155  	return length, nil
   156  }
   157  
   158  // ConnReader is TCP Connection Reader Wrapper for trojan protocol
   159  type ConnReader struct {
   160  	io.Reader
   161  	Target       net.Destination
   162  	headerParsed bool
   163  }
   164  
   165  // ParseHeader parses the trojan protocol header
   166  func (c *ConnReader) ParseHeader() error {
   167  	var crlf [2]byte
   168  	var command [1]byte
   169  	var hash [56]byte
   170  	if _, err := io.ReadFull(c.Reader, hash[:]); err != nil {
   171  		return newError("failed to read user hash").Base(err)
   172  	}
   173  
   174  	if _, err := io.ReadFull(c.Reader, crlf[:]); err != nil {
   175  		return newError("failed to read crlf").Base(err)
   176  	}
   177  
   178  	if _, err := io.ReadFull(c.Reader, command[:]); err != nil {
   179  		return newError("failed to read command").Base(err)
   180  	}
   181  
   182  	network := net.Network_TCP
   183  	if command[0] == commandUDP {
   184  		network = net.Network_UDP
   185  	}
   186  
   187  	addr, port, err := addrParser.ReadAddressPort(nil, c.Reader)
   188  	if err != nil {
   189  		return newError("failed to read address and port").Base(err)
   190  	}
   191  	c.Target = net.Destination{Network: network, Address: addr, Port: port}
   192  
   193  	if _, err := io.ReadFull(c.Reader, crlf[:]); err != nil {
   194  		return newError("failed to read crlf").Base(err)
   195  	}
   196  
   197  	c.headerParsed = true
   198  	return nil
   199  }
   200  
   201  // Read implements io.Reader
   202  func (c *ConnReader) Read(p []byte) (int, error) {
   203  	if !c.headerParsed {
   204  		if err := c.ParseHeader(); err != nil {
   205  			return 0, err
   206  		}
   207  	}
   208  
   209  	return c.Reader.Read(p)
   210  }
   211  
   212  // ReadMultiBuffer implements buf.Reader
   213  func (c *ConnReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
   214  	b := buf.New()
   215  	_, err := b.ReadFrom(c)
   216  	return buf.MultiBuffer{b}, err
   217  }
   218  
   219  // PacketPayload combines udp payload and destination
   220  type PacketPayload struct {
   221  	Target net.Destination
   222  	Buffer buf.MultiBuffer
   223  }
   224  
   225  // PacketReader is UDP Connection Reader Wrapper for trojan protocol
   226  type PacketReader struct {
   227  	io.Reader
   228  }
   229  
   230  // ReadMultiBuffer implements buf.Reader
   231  func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
   232  	p, err := r.ReadMultiBufferWithMetadata()
   233  	if p != nil {
   234  		return p.Buffer, err
   235  	}
   236  	return nil, err
   237  }
   238  
   239  // ReadMultiBufferWithMetadata reads udp packet with destination
   240  func (r *PacketReader) ReadMultiBufferWithMetadata() (*PacketPayload, error) {
   241  	addr, port, err := addrParser.ReadAddressPort(nil, r)
   242  	if err != nil {
   243  		return nil, newError("failed to read address and port").Base(err)
   244  	}
   245  
   246  	var lengthBuf [2]byte
   247  	if _, err := io.ReadFull(r, lengthBuf[:]); err != nil {
   248  		return nil, newError("failed to read payload length").Base(err)
   249  	}
   250  
   251  	remain := int(binary.BigEndian.Uint16(lengthBuf[:]))
   252  	if remain > maxLength {
   253  		return nil, newError("oversize payload")
   254  	}
   255  
   256  	var crlf [2]byte
   257  	if _, err := io.ReadFull(r, crlf[:]); err != nil {
   258  		return nil, newError("failed to read crlf").Base(err)
   259  	}
   260  
   261  	dest := net.UDPDestination(addr, port)
   262  	var mb buf.MultiBuffer
   263  	for remain > 0 {
   264  		length := buf.Size
   265  		if remain < length {
   266  			length = remain
   267  		}
   268  
   269  		b := buf.New()
   270  		mb = append(mb, b)
   271  		n, err := b.ReadFullFrom(r, int32(length))
   272  		if err != nil {
   273  			buf.ReleaseMulti(mb)
   274  			return nil, newError("failed to read payload").Base(err)
   275  		}
   276  
   277  		remain -= int(n)
   278  	}
   279  
   280  	return &PacketPayload{Target: dest, Buffer: mb}, nil
   281  }