github.com/imannamdari/v2ray-core/v5@v5.0.5/proxy/trojan/protocol.go (about)

     1  package trojan
     2  
     3  import (
     4  	"encoding/binary"
     5  	"io"
     6  	gonet "net"
     7  
     8  	"github.com/imannamdari/v2ray-core/v5/common/buf"
     9  	"github.com/imannamdari/v2ray-core/v5/common/net"
    10  	"github.com/imannamdari/v2ray-core/v5/common/protocol"
    11  )
    12  
    13  var (
    14  	crlf = []byte{'\r', '\n'}
    15  
    16  	addrParser = protocol.NewAddressParser(
    17  		protocol.AddressFamilyByte(0x01, net.AddressFamilyIPv4),
    18  		protocol.AddressFamilyByte(0x04, net.AddressFamilyIPv6),
    19  		protocol.AddressFamilyByte(0x03, net.AddressFamilyDomain),
    20  	)
    21  )
    22  
    23  const (
    24  	maxLength       = 8192
    25  	commandTCP byte = 1
    26  	commandUDP byte = 3
    27  )
    28  
    29  // ConnWriter is TCP Connection Writer Wrapper for trojan protocol
    30  type ConnWriter struct {
    31  	io.Writer
    32  	Target     net.Destination
    33  	Account    *MemoryAccount
    34  	headerSent bool
    35  }
    36  
    37  // Write implements io.Writer
    38  func (c *ConnWriter) Write(p []byte) (n int, err error) {
    39  	if !c.headerSent {
    40  		if err := c.writeHeader(); err != nil {
    41  			return 0, newError("failed to write request header").Base(err)
    42  		}
    43  	}
    44  
    45  	return c.Writer.Write(p)
    46  }
    47  
    48  // WriteMultiBuffer implements buf.Writer
    49  func (c *ConnWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
    50  	defer buf.ReleaseMulti(mb)
    51  
    52  	for _, b := range mb {
    53  		if !b.IsEmpty() {
    54  			if _, err := c.Write(b.Bytes()); err != nil {
    55  				return err
    56  			}
    57  		}
    58  	}
    59  
    60  	return nil
    61  }
    62  
    63  func (c *ConnWriter) writeHeader() error {
    64  	buffer := buf.StackNew()
    65  	defer buffer.Release()
    66  
    67  	command := commandTCP
    68  	if c.Target.Network == net.Network_UDP {
    69  		command = commandUDP
    70  	}
    71  
    72  	if _, err := buffer.Write(c.Account.Key); err != nil {
    73  		return err
    74  	}
    75  	if _, err := buffer.Write(crlf); err != nil {
    76  		return err
    77  	}
    78  	if err := buffer.WriteByte(command); err != nil {
    79  		return err
    80  	}
    81  	if err := addrParser.WriteAddressPort(&buffer, c.Target.Address, c.Target.Port); err != nil {
    82  		return err
    83  	}
    84  	if _, err := buffer.Write(crlf); err != nil {
    85  		return err
    86  	}
    87  
    88  	_, err := c.Writer.Write(buffer.Bytes())
    89  	if err == nil {
    90  		c.headerSent = true
    91  	}
    92  
    93  	return err
    94  }
    95  
    96  // PacketWriter UDP Connection Writer Wrapper for trojan protocol
    97  type PacketWriter struct {
    98  	io.Writer
    99  	Target net.Destination
   100  }
   101  
   102  // WriteMultiBuffer implements buf.Writer
   103  func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
   104  	b := make([]byte, maxLength)
   105  	for !mb.IsEmpty() {
   106  		var length int
   107  		mb, length = buf.SplitBytes(mb, b)
   108  		if _, err := w.writePacket(b[:length], w.Target); err != nil {
   109  			buf.ReleaseMulti(mb)
   110  			return err
   111  		}
   112  	}
   113  
   114  	return nil
   115  }
   116  
   117  // WriteMultiBufferWithMetadata writes udp packet with destination specified
   118  func (w *PacketWriter) WriteMultiBufferWithMetadata(mb buf.MultiBuffer, dest net.Destination) error {
   119  	b := make([]byte, maxLength)
   120  	for !mb.IsEmpty() {
   121  		var length int
   122  		mb, length = buf.SplitBytes(mb, b)
   123  		if _, err := w.writePacket(b[:length], dest); err != nil {
   124  			buf.ReleaseMulti(mb)
   125  			return err
   126  		}
   127  	}
   128  
   129  	return nil
   130  }
   131  
   132  func (w *PacketWriter) WriteTo(payload []byte, addr gonet.Addr) (int, error) {
   133  	dest := net.DestinationFromAddr(addr)
   134  
   135  	return w.writePacket(payload, dest)
   136  }
   137  
   138  func (w *PacketWriter) writePacket(payload []byte, dest net.Destination) (int, error) { // nolint: unparam
   139  	buffer := buf.StackNew()
   140  	defer buffer.Release()
   141  
   142  	length := len(payload)
   143  	lengthBuf := [2]byte{}
   144  	binary.BigEndian.PutUint16(lengthBuf[:], uint16(length))
   145  	if err := addrParser.WriteAddressPort(&buffer, dest.Address, dest.Port); err != nil {
   146  		return 0, err
   147  	}
   148  	if _, err := buffer.Write(lengthBuf[:]); err != nil {
   149  		return 0, err
   150  	}
   151  	if _, err := buffer.Write(crlf); err != nil {
   152  		return 0, err
   153  	}
   154  	if _, err := buffer.Write(payload); err != nil {
   155  		return 0, err
   156  	}
   157  	_, err := w.Write(buffer.Bytes())
   158  	if err != nil {
   159  		return 0, err
   160  	}
   161  
   162  	return length, nil
   163  }
   164  
   165  // ConnReader is TCP Connection Reader Wrapper for trojan protocol
   166  type ConnReader struct {
   167  	io.Reader
   168  	Target       net.Destination
   169  	headerParsed bool
   170  }
   171  
   172  // ParseHeader parses the trojan protocol header
   173  func (c *ConnReader) ParseHeader() error {
   174  	var crlf [2]byte
   175  	var command [1]byte
   176  	var hash [56]byte
   177  	if _, err := io.ReadFull(c.Reader, hash[:]); err != nil {
   178  		return newError("failed to read user hash").Base(err)
   179  	}
   180  
   181  	if _, err := io.ReadFull(c.Reader, crlf[:]); err != nil {
   182  		return newError("failed to read crlf").Base(err)
   183  	}
   184  
   185  	if _, err := io.ReadFull(c.Reader, command[:]); err != nil {
   186  		return newError("failed to read command").Base(err)
   187  	}
   188  
   189  	network := net.Network_TCP
   190  	if command[0] == commandUDP {
   191  		network = net.Network_UDP
   192  	}
   193  
   194  	addr, port, err := addrParser.ReadAddressPort(nil, c.Reader)
   195  	if err != nil {
   196  		return newError("failed to read address and port").Base(err)
   197  	}
   198  	c.Target = net.Destination{Network: network, Address: addr, Port: port}
   199  
   200  	if _, err := io.ReadFull(c.Reader, crlf[:]); err != nil {
   201  		return newError("failed to read crlf").Base(err)
   202  	}
   203  
   204  	c.headerParsed = true
   205  	return nil
   206  }
   207  
   208  // Read implements io.Reader
   209  func (c *ConnReader) Read(p []byte) (int, error) {
   210  	if !c.headerParsed {
   211  		if err := c.ParseHeader(); err != nil {
   212  			return 0, err
   213  		}
   214  	}
   215  
   216  	return c.Reader.Read(p)
   217  }
   218  
   219  // ReadMultiBuffer implements buf.Reader
   220  func (c *ConnReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
   221  	b := buf.New()
   222  	_, err := b.ReadFrom(c)
   223  	return buf.MultiBuffer{b}, err
   224  }
   225  
   226  // PacketPayload combines udp payload and destination
   227  type PacketPayload struct {
   228  	Target net.Destination
   229  	Buffer buf.MultiBuffer
   230  }
   231  
   232  // PacketReader is UDP Connection Reader Wrapper for trojan protocol
   233  type PacketReader struct {
   234  	io.Reader
   235  }
   236  
   237  // ReadMultiBuffer implements buf.Reader
   238  func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
   239  	p, err := r.ReadMultiBufferWithMetadata()
   240  	if p != nil {
   241  		return p.Buffer, err
   242  	}
   243  	return nil, err
   244  }
   245  
   246  // ReadMultiBufferWithMetadata reads udp packet with destination
   247  func (r *PacketReader) ReadMultiBufferWithMetadata() (*PacketPayload, error) {
   248  	addr, port, err := addrParser.ReadAddressPort(nil, r)
   249  	if err != nil {
   250  		return nil, newError("failed to read address and port").Base(err)
   251  	}
   252  
   253  	var lengthBuf [2]byte
   254  	if _, err := io.ReadFull(r, lengthBuf[:]); err != nil {
   255  		return nil, newError("failed to read payload length").Base(err)
   256  	}
   257  
   258  	remain := int(binary.BigEndian.Uint16(lengthBuf[:]))
   259  	if remain > maxLength {
   260  		return nil, newError("oversize payload")
   261  	}
   262  
   263  	var crlf [2]byte
   264  	if _, err := io.ReadFull(r, crlf[:]); err != nil {
   265  		return nil, newError("failed to read crlf").Base(err)
   266  	}
   267  
   268  	dest := net.UDPDestination(addr, port)
   269  	var mb buf.MultiBuffer
   270  	for remain > 0 {
   271  		length := buf.Size
   272  		if remain < length {
   273  			length = remain
   274  		}
   275  
   276  		b := buf.New()
   277  		mb = append(mb, b)
   278  		n, err := b.ReadFullFrom(r, int32(length))
   279  		if err != nil {
   280  			buf.ReleaseMulti(mb)
   281  			return nil, newError("failed to read payload").Base(err)
   282  		}
   283  
   284  		remain -= int(n)
   285  	}
   286  
   287  	return &PacketPayload{Target: dest, Buffer: mb}, nil
   288  }
   289  
   290  type PacketConnectionReader struct {
   291  	reader  *PacketReader
   292  	payload *PacketPayload
   293  }
   294  
   295  func (r *PacketConnectionReader) ReadFrom(p []byte) (n int, addr gonet.Addr, err error) {
   296  	if r.payload == nil || r.payload.Buffer.IsEmpty() {
   297  		r.payload, err = r.reader.ReadMultiBufferWithMetadata()
   298  		if err != nil {
   299  			return
   300  		}
   301  	}
   302  
   303  	addr = &gonet.UDPAddr{
   304  		IP:   r.payload.Target.Address.IP(),
   305  		Port: int(r.payload.Target.Port),
   306  	}
   307  
   308  	r.payload.Buffer, n = buf.SplitFirstBytes(r.payload.Buffer, p)
   309  
   310  	return
   311  }