github.com/v2fly/v2ray-core/v5@v5.16.2-0.20240507031116-8191faa6e095/proxy/trojan/protocol.go (about)

     1  package trojan
     2  
     3  import (
     4  	"encoding/binary"
     5  	"io"
     6  	gonet "net"
     7  
     8  	"github.com/v2fly/v2ray-core/v5/common/buf"
     9  	"github.com/v2fly/v2ray-core/v5/common/net"
    10  	"github.com/v2fly/v2ray-core/v5/common/protocol"
    11  )
    12  
    13  var (
    14  	crlf = []byte{'\r', '\n'}
    15  
    16  	addrParser = protocol.NewAddressParser(
    17  		protocol.AddressFamilyByte(0x01, net.AddressFamilyIPv4),
    18  		protocol.AddressFamilyByte(0x04, net.AddressFamilyIPv6),
    19  		protocol.AddressFamilyByte(0x03, net.AddressFamilyDomain),
    20  	)
    21  )
    22  
    23  const (
    24  	commandTCP byte = 1
    25  	commandUDP byte = 3
    26  )
    27  
// ConnWriter is a TCP connection writer wrapper for the trojan protocol.
//
// The first successful Write or WriteHeader call emits the trojan request
// header (account key, command byte, and Target address/port). Data is then
// passed through to the embedded Writer.
type ConnWriter struct {
	io.Writer
	// Target is the destination encoded in the header; its Network selects
	// commandUDP (for UDP) or commandTCP (otherwise).
	Target     net.Destination
	// Account supplies the Key written at the very start of the header.
	Account    *MemoryAccount
	// headerSent is set once the header has been written without error.
	headerSent bool
}
    35  
    36  // Write implements io.Writer
    37  func (c *ConnWriter) Write(p []byte) (n int, err error) {
    38  	if !c.headerSent {
    39  		if err := c.writeHeader(); err != nil {
    40  			return 0, newError("failed to write request header").Base(err)
    41  		}
    42  	}
    43  
    44  	return c.Writer.Write(p)
    45  }
    46  
    47  // WriteMultiBuffer implements buf.Writer
    48  func (c *ConnWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
    49  	defer buf.ReleaseMulti(mb)
    50  
    51  	for _, b := range mb {
    52  		if !b.IsEmpty() {
    53  			if _, err := c.Write(b.Bytes()); err != nil {
    54  				return err
    55  			}
    56  		}
    57  	}
    58  
    59  	return nil
    60  }
    61  
    62  func (c *ConnWriter) WriteHeader() error {
    63  	if !c.headerSent {
    64  		if err := c.writeHeader(); err != nil {
    65  			return err
    66  		}
    67  	}
    68  	return nil
    69  }
    70  
    71  func (c *ConnWriter) writeHeader() error {
    72  	buffer := buf.StackNew()
    73  	defer buffer.Release()
    74  
    75  	command := commandTCP
    76  	if c.Target.Network == net.Network_UDP {
    77  		command = commandUDP
    78  	}
    79  
    80  	if _, err := buffer.Write(c.Account.Key); err != nil {
    81  		return err
    82  	}
    83  	if _, err := buffer.Write(crlf); err != nil {
    84  		return err
    85  	}
    86  	if err := buffer.WriteByte(command); err != nil {
    87  		return err
    88  	}
    89  	if err := addrParser.WriteAddressPort(&buffer, c.Target.Address, c.Target.Port); err != nil {
    90  		return err
    91  	}
    92  	if _, err := buffer.Write(crlf); err != nil {
    93  		return err
    94  	}
    95  
    96  	_, err := c.Writer.Write(buffer.Bytes())
    97  	if err == nil {
    98  		c.headerSent = true
    99  	}
   100  
   101  	return err
   102  }
   103  
// PacketWriter is a UDP connection writer wrapper for the trojan protocol.
//
// Each payload is framed with a destination address/port, a 16-bit payload
// length and a CRLF before being written to the embedded Writer.
type PacketWriter struct {
	io.Writer
	// Target is the destination used by WriteMultiBuffer; other write
	// methods take the destination explicitly.
	Target net.Destination
}
   109  
   110  // WriteMultiBuffer implements buf.Writer
   111  func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
   112  	for _, b := range mb {
   113  		if b.IsEmpty() {
   114  			continue
   115  		}
   116  		if _, err := w.writePacket(b.Bytes(), w.Target); err != nil {
   117  			buf.ReleaseMulti(mb)
   118  			return err
   119  		}
   120  	}
   121  
   122  	return nil
   123  }
   124  
   125  // WriteMultiBufferWithMetadata writes udp packet with destination specified
   126  func (w *PacketWriter) WriteMultiBufferWithMetadata(mb buf.MultiBuffer, dest net.Destination) error {
   127  	for _, b := range mb {
   128  		if b.IsEmpty() {
   129  			continue
   130  		}
   131  		if _, err := w.writePacket(b.Bytes(), dest); err != nil {
   132  			buf.ReleaseMulti(mb)
   133  			return err
   134  		}
   135  	}
   136  
   137  	return nil
   138  }
   139  
   140  func (w *PacketWriter) WriteTo(payload []byte, addr gonet.Addr) (int, error) {
   141  	dest := net.DestinationFromAddr(addr)
   142  
   143  	return w.writePacket(payload, dest)
   144  }
   145  
   146  func (w *PacketWriter) writePacket(payload []byte, dest net.Destination) (int, error) { // nolint: unparam
   147  	var addrPortLen int32
   148  	switch dest.Address.Family() {
   149  	case net.AddressFamilyDomain:
   150  		if protocol.IsDomainTooLong(dest.Address.Domain()) {
   151  			return 0, newError("Super long domain is not supported: ", dest.Address.Domain())
   152  		}
   153  		addrPortLen = 1 + 1 + int32(len(dest.Address.Domain())) + 2
   154  	case net.AddressFamilyIPv4:
   155  		addrPortLen = 1 + 4 + 2
   156  	case net.AddressFamilyIPv6:
   157  		addrPortLen = 1 + 16 + 2
   158  	default:
   159  		panic("Unknown address type.")
   160  	}
   161  
   162  	length := len(payload)
   163  	lengthBuf := [2]byte{}
   164  	binary.BigEndian.PutUint16(lengthBuf[:], uint16(length))
   165  
   166  	buffer := buf.NewWithSize(addrPortLen + 2 + 2 + int32(length))
   167  	defer buffer.Release()
   168  
   169  	if err := addrParser.WriteAddressPort(buffer, dest.Address, dest.Port); err != nil {
   170  		return 0, err
   171  	}
   172  	if _, err := buffer.Write(lengthBuf[:]); err != nil {
   173  		return 0, err
   174  	}
   175  	if _, err := buffer.Write(crlf); err != nil {
   176  		return 0, err
   177  	}
   178  	if _, err := buffer.Write(payload); err != nil {
   179  		return 0, err
   180  	}
   181  	_, err := w.Write(buffer.Bytes())
   182  	if err != nil {
   183  		return 0, err
   184  	}
   185  
   186  	return length, nil
   187  }
   188  
// ConnReader is a TCP connection reader wrapper for the trojan protocol.
//
// It consumes the trojan request header from the embedded Reader, either on
// the first Read or through an explicit ParseHeader call. After that, data
// passes through unchanged.
type ConnReader struct {
	io.Reader
	// Target is the destination decoded from the request header by
	// ParseHeader.
	Target       net.Destination
	// headerParsed is set once ParseHeader has completed without error.
	headerParsed bool
}
   195  
   196  // ParseHeader parses the trojan protocol header
   197  func (c *ConnReader) ParseHeader() error {
   198  	var crlf [2]byte
   199  	var command [1]byte
   200  	var hash [56]byte
   201  	if _, err := io.ReadFull(c.Reader, hash[:]); err != nil {
   202  		return newError("failed to read user hash").Base(err)
   203  	}
   204  
   205  	if _, err := io.ReadFull(c.Reader, crlf[:]); err != nil {
   206  		return newError("failed to read crlf").Base(err)
   207  	}
   208  
   209  	if _, err := io.ReadFull(c.Reader, command[:]); err != nil {
   210  		return newError("failed to read command").Base(err)
   211  	}
   212  
   213  	network := net.Network_TCP
   214  	if command[0] == commandUDP {
   215  		network = net.Network_UDP
   216  	}
   217  
   218  	addr, port, err := addrParser.ReadAddressPort(nil, c.Reader)
   219  	if err != nil {
   220  		return newError("failed to read address and port").Base(err)
   221  	}
   222  	c.Target = net.Destination{Network: network, Address: addr, Port: port}
   223  
   224  	if _, err := io.ReadFull(c.Reader, crlf[:]); err != nil {
   225  		return newError("failed to read crlf").Base(err)
   226  	}
   227  
   228  	c.headerParsed = true
   229  	return nil
   230  }
   231  
   232  // Read implements io.Reader
   233  func (c *ConnReader) Read(p []byte) (int, error) {
   234  	if !c.headerParsed {
   235  		if err := c.ParseHeader(); err != nil {
   236  			return 0, err
   237  		}
   238  	}
   239  
   240  	return c.Reader.Read(p)
   241  }
   242  
   243  // ReadMultiBuffer implements buf.Reader
   244  func (c *ConnReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
   245  	b := buf.New()
   246  	_, err := b.ReadFrom(c)
   247  	return buf.MultiBuffer{b}, err
   248  }
   249  
// PacketPayload pairs a decoded UDP payload with the address carried in its
// trojan packet header.
type PacketPayload struct {
	// Target is the UDP destination decoded from the packet header.
	Target net.Destination
	// Buffer holds the packet payload.
	Buffer buf.MultiBuffer
}
   255  
// PacketReader is a UDP connection reader wrapper for the trojan protocol: it
// decodes framed UDP packets from the embedded Reader.
type PacketReader struct {
	io.Reader
}
   260  
   261  // ReadMultiBuffer implements buf.Reader
   262  func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
   263  	p, err := r.ReadMultiBufferWithMetadata()
   264  	if p != nil {
   265  		return p.Buffer, err
   266  	}
   267  	return nil, err
   268  }
   269  
   270  // ReadMultiBufferWithMetadata reads udp packet with destination
   271  func (r *PacketReader) ReadMultiBufferWithMetadata() (*PacketPayload, error) {
   272  	addr, port, err := addrParser.ReadAddressPort(nil, r)
   273  	if err != nil {
   274  		return nil, newError("failed to read address and port").Base(err)
   275  	}
   276  
   277  	var lengthBuf [2]byte
   278  	if _, err := io.ReadFull(r, lengthBuf[:]); err != nil {
   279  		return nil, newError("failed to read payload length").Base(err)
   280  	}
   281  
   282  	length := binary.BigEndian.Uint16(lengthBuf[:])
   283  
   284  	var crlf [2]byte
   285  	if _, err := io.ReadFull(r, crlf[:]); err != nil {
   286  		return nil, newError("failed to read crlf").Base(err)
   287  	}
   288  
   289  	dest := net.UDPDestination(addr, port)
   290  
   291  	b := buf.NewWithSize(int32(length))
   292  	_, err = b.ReadFullFrom(r, int32(length))
   293  	if err != nil {
   294  		return nil, newError("failed to read payload").Base(err)
   295  	}
   296  
   297  	return &PacketPayload{Target: dest, Buffer: buf.MultiBuffer{b}}, nil
   298  }
   299  
// PacketConnectionReader exposes a PacketReader through a ReadFrom-style,
// packet-oriented method. It keeps at most one decoded packet between calls.
type PacketConnectionReader struct {
	// reader is the source of decoded trojan UDP packets.
	reader  *PacketReader
	// payload is the most recently decoded packet, or nil.
	payload *PacketPayload
}
   304  
   305  func (r *PacketConnectionReader) ReadFrom(p []byte) (n int, addr gonet.Addr, err error) {
   306  	if r.payload == nil || r.payload.Buffer.IsEmpty() {
   307  		r.payload, err = r.reader.ReadMultiBufferWithMetadata()
   308  		if err != nil {
   309  			return
   310  		}
   311  	}
   312  
   313  	addr = &gonet.UDPAddr{
   314  		IP:   r.payload.Target.Address.IP(),
   315  		Port: int(r.payload.Target.Port),
   316  	}
   317  
   318  	r.payload.Buffer, n = buf.SplitFirstBytes(r.payload.Buffer, p)
   319  
   320  	return
   321  }