github.com/xraypb/xray-core@v1.6.6/proxy/trojan/protocol.go (about)

     1  package trojan
     2  
     3  import (
     4  	"context"
     5  	"encoding/binary"
     6  	fmt "fmt"
     7  	"io"
     8  	"runtime"
     9  	"syscall"
    10  
    11  	"github.com/xraypb/xray-core/common/buf"
    12  	"github.com/xraypb/xray-core/common/errors"
    13  	"github.com/xraypb/xray-core/common/net"
    14  	"github.com/xraypb/xray-core/common/protocol"
    15  	"github.com/xraypb/xray-core/common/session"
    16  	"github.com/xraypb/xray-core/common/signal"
    17  	"github.com/xraypb/xray-core/features/stats"
    18  	"github.com/xraypb/xray-core/transport/internet/stat"
    19  	"github.com/xraypb/xray-core/transport/internet/xtls"
    20  )
    21  
    22  var (
    23  	crlf = []byte{'\r', '\n'}
    24  
    25  	addrParser = protocol.NewAddressParser(
    26  		protocol.AddressFamilyByte(0x01, net.AddressFamilyIPv4),
    27  		protocol.AddressFamilyByte(0x04, net.AddressFamilyIPv6),
    28  		protocol.AddressFamilyByte(0x03, net.AddressFamilyDomain),
    29  	)
    30  
    31  	xtls_show = false
    32  )
    33  
    34  const (
    35  	maxLength = 8192
    36  	// XRS is constant for XTLS splice mode
    37  	XRS = "xtls-rprx-splice"
    38  	// XRD is constant for XTLS direct mode
    39  	XRD = "xtls-rprx-direct"
    40  	// XRO is constant for XTLS origin mode
    41  	XRO = "xtls-rprx-origin"
    42  
    43  	commandTCP byte = 1
    44  	commandUDP byte = 3
    45  
    46  	// for XTLS
    47  	commandXRD byte = 0xf0 // XTLS direct mode
    48  	commandXRO byte = 0xf1 // XTLS origin mode
    49  )
    50  
// ConnWriter is TCP Connection Writer Wrapper for trojan protocol.
// The first Write (or WriteMultiBuffer) sends the trojan request header built
// by writeHeader ahead of the payload; later writes pass straight through to
// the embedded Writer.
type ConnWriter struct {
	io.Writer
	// Target is encoded into the request header; a UDP network selects commandUDP.
	Target     net.Destination
	// Account supplies Key, which is written first in the request header.
	Account    *MemoryAccount
	// Flow selects commandXRD or commandXRO when set to XRD or XRO (non-UDP targets only).
	Flow       string
	// headerSent records that the header has been written successfully.
	headerSent bool
}
    59  
    60  // Write implements io.Writer
    61  func (c *ConnWriter) Write(p []byte) (n int, err error) {
    62  	if !c.headerSent {
    63  		if err := c.writeHeader(); err != nil {
    64  			return 0, newError("failed to write request header").Base(err)
    65  		}
    66  	}
    67  
    68  	return c.Writer.Write(p)
    69  }
    70  
    71  // WriteMultiBuffer implements buf.Writer
    72  func (c *ConnWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
    73  	defer buf.ReleaseMulti(mb)
    74  
    75  	for _, b := range mb {
    76  		if !b.IsEmpty() {
    77  			if _, err := c.Write(b.Bytes()); err != nil {
    78  				return err
    79  			}
    80  		}
    81  	}
    82  
    83  	return nil
    84  }
    85  
    86  func (c *ConnWriter) writeHeader() error {
    87  	buffer := buf.StackNew()
    88  	defer buffer.Release()
    89  
    90  	command := commandTCP
    91  	if c.Target.Network == net.Network_UDP {
    92  		command = commandUDP
    93  	} else if c.Flow == XRD {
    94  		command = commandXRD
    95  	} else if c.Flow == XRO {
    96  		command = commandXRO
    97  	}
    98  
    99  	if _, err := buffer.Write(c.Account.Key); err != nil {
   100  		return err
   101  	}
   102  	if _, err := buffer.Write(crlf); err != nil {
   103  		return err
   104  	}
   105  	if err := buffer.WriteByte(command); err != nil {
   106  		return err
   107  	}
   108  	if err := addrParser.WriteAddressPort(&buffer, c.Target.Address, c.Target.Port); err != nil {
   109  		return err
   110  	}
   111  	if _, err := buffer.Write(crlf); err != nil {
   112  		return err
   113  	}
   114  
   115  	_, err := c.Writer.Write(buffer.Bytes())
   116  	if err == nil {
   117  		c.headerSent = true
   118  	}
   119  
   120  	return err
   121  }
   122  
// PacketWriter is the UDP Connection Writer Wrapper for trojan protocol.
// It frames each buffer as a trojan UDP packet on the embedded Writer.
type PacketWriter struct {
	io.Writer
	// Target is the default destination, used for buffers that carry no UDP destination of their own.
	Target net.Destination
}
   128  
   129  // WriteMultiBuffer implements buf.Writer
   130  func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
   131  	for {
   132  		mb2, b := buf.SplitFirst(mb)
   133  		mb = mb2
   134  		if b == nil {
   135  			break
   136  		}
   137  		target := &w.Target
   138  		if b.UDP != nil {
   139  			target = b.UDP
   140  		}
   141  		if _, err := w.writePacket(b.Bytes(), *target); err != nil {
   142  			buf.ReleaseMulti(mb)
   143  			return err
   144  		}
   145  	}
   146  	return nil
   147  }
   148  
   149  func (w *PacketWriter) writePacket(payload []byte, dest net.Destination) (int, error) {
   150  	buffer := buf.StackNew()
   151  	defer buffer.Release()
   152  
   153  	length := len(payload)
   154  	lengthBuf := [2]byte{}
   155  	binary.BigEndian.PutUint16(lengthBuf[:], uint16(length))
   156  	if err := addrParser.WriteAddressPort(&buffer, dest.Address, dest.Port); err != nil {
   157  		return 0, err
   158  	}
   159  	if _, err := buffer.Write(lengthBuf[:]); err != nil {
   160  		return 0, err
   161  	}
   162  	if _, err := buffer.Write(crlf); err != nil {
   163  		return 0, err
   164  	}
   165  	if _, err := buffer.Write(payload); err != nil {
   166  		return 0, err
   167  	}
   168  	_, err := w.Write(buffer.Bytes())
   169  	if err != nil {
   170  		return 0, err
   171  	}
   172  
   173  	return length, nil
   174  }
   175  
// ConnReader is TCP Connection Reader Wrapper for trojan protocol.
// It parses the trojan request header (see ParseHeader) from the embedded
// Reader before returning any payload bytes.
type ConnReader struct {
	io.Reader
	// Target is filled in by ParseHeader from the header's command, address and port.
	Target       net.Destination
	// Flow is set by ParseHeader to XRD or XRO when the header carries an XTLS
	// command; otherwise ParseHeader leaves it unchanged.
	Flow         string
	// headerParsed records that ParseHeader has completed successfully.
	headerParsed bool
}
   183  
   184  // ParseHeader parses the trojan protocol header
   185  func (c *ConnReader) ParseHeader() error {
   186  	var crlf [2]byte
   187  	var command [1]byte
   188  	var hash [56]byte
   189  	if _, err := io.ReadFull(c.Reader, hash[:]); err != nil {
   190  		return newError("failed to read user hash").Base(err)
   191  	}
   192  
   193  	if _, err := io.ReadFull(c.Reader, crlf[:]); err != nil {
   194  		return newError("failed to read crlf").Base(err)
   195  	}
   196  
   197  	if _, err := io.ReadFull(c.Reader, command[:]); err != nil {
   198  		return newError("failed to read command").Base(err)
   199  	}
   200  
   201  	network := net.Network_TCP
   202  	if command[0] == commandUDP {
   203  		network = net.Network_UDP
   204  	} else if command[0] == commandXRD {
   205  		c.Flow = XRD
   206  	} else if command[0] == commandXRO {
   207  		c.Flow = XRO
   208  	}
   209  
   210  	addr, port, err := addrParser.ReadAddressPort(nil, c.Reader)
   211  	if err != nil {
   212  		return newError("failed to read address and port").Base(err)
   213  	}
   214  	c.Target = net.Destination{Network: network, Address: addr, Port: port}
   215  
   216  	if _, err := io.ReadFull(c.Reader, crlf[:]); err != nil {
   217  		return newError("failed to read crlf").Base(err)
   218  	}
   219  
   220  	c.headerParsed = true
   221  	return nil
   222  }
   223  
   224  // Read implements io.Reader
   225  func (c *ConnReader) Read(p []byte) (int, error) {
   226  	if !c.headerParsed {
   227  		if err := c.ParseHeader(); err != nil {
   228  			return 0, err
   229  		}
   230  	}
   231  
   232  	return c.Reader.Read(p)
   233  }
   234  
   235  // ReadMultiBuffer implements buf.Reader
   236  func (c *ConnReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
   237  	b := buf.New()
   238  	_, err := b.ReadFrom(c)
   239  	return buf.MultiBuffer{b}, err
   240  }
   241  
// PacketReader is UDP Connection Reader Wrapper for trojan protocol.
// Each ReadMultiBuffer call decodes one length-prefixed UDP frame from the
// embedded Reader.
type PacketReader struct {
	io.Reader
}
   246  
   247  // ReadMultiBuffer implements buf.Reader
   248  func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
   249  	addr, port, err := addrParser.ReadAddressPort(nil, r)
   250  	if err != nil {
   251  		return nil, newError("failed to read address and port").Base(err)
   252  	}
   253  
   254  	var lengthBuf [2]byte
   255  	if _, err := io.ReadFull(r, lengthBuf[:]); err != nil {
   256  		return nil, newError("failed to read payload length").Base(err)
   257  	}
   258  
   259  	remain := int(binary.BigEndian.Uint16(lengthBuf[:]))
   260  	if remain > maxLength {
   261  		return nil, newError("oversize payload")
   262  	}
   263  
   264  	var crlf [2]byte
   265  	if _, err := io.ReadFull(r, crlf[:]); err != nil {
   266  		return nil, newError("failed to read crlf").Base(err)
   267  	}
   268  
   269  	dest := net.UDPDestination(addr, port)
   270  	var mb buf.MultiBuffer
   271  	for remain > 0 {
   272  		length := buf.Size
   273  		if remain < length {
   274  			length = remain
   275  		}
   276  
   277  		b := buf.New()
   278  		b.UDP = &dest
   279  		mb = append(mb, b)
   280  		n, err := b.ReadFullFrom(r, int32(length))
   281  		if err != nil {
   282  			buf.ReleaseMulti(mb)
   283  			return nil, newError("failed to read payload").Base(err)
   284  		}
   285  
   286  		remain -= int(n)
   287  	}
   288  
   289  	return mb, nil
   290  }
   291  
// ReadV copies data from reader to writer until an error occurs. It also
// handles the switch away from the normal reader once conn.DirectIn is true
// (NOTE(review): DirectIn is presumably set by the XTLS layer; not visible here).
//
// When the switch happens:
//   - If sctx carries an inbound whose Conn, after unwrapping a
//     *stat.CounterConnection and an *xtls.Conn, is a *net.TCPConn, the rest of
//     conn.NetConn() is copied straight into that TCP connection with
//     TCPConn.ReadFrom ("Splice"), bypassing writer. ReadV then returns the
//     result of that copy. If the unwrapped inbound Conn is not a
//     *net.TCPConn, ReadV panics.
//   - Otherwise (sctx is nil, or it has no inbound or no inbound.Conn), reader
//     is replaced with buf.NewReadVReader(conn.NetConn(), rawConn, nil), and
//     from then on the length of every MultiBuffer read is added to counter.
//
// Every non-empty MultiBuffer read refreshes timer and is passed to writer.
// A write error ends the loop immediately. When the loop ends with an error
// whose cause is io.EOF, ReadV returns nil; any other error is returned as is.
func ReadV(reader buf.Reader, writer buf.Writer, timer signal.ActivityUpdater, conn *xtls.Conn, rawConn syscall.RawConn, counter stats.Counter, sctx context.Context) error {
	err := func() error {
		// ct stays nil until the switch to the ReadV reader, so bytes read
		// through the original reader are not added to counter here.
		var ct stats.Counter
		for {
			if conn.DirectIn {
				// Clear the flag so the switch is handled only once.
				conn.DirectIn = false
				if sctx != nil {
					if inbound := session.InboundFromContext(sctx); inbound != nil && inbound.Conn != nil {
						// Unwrap the stats wrapper and then the XTLS wrapper to
						// reach the underlying inbound connection.
						iConn := inbound.Conn
						statConn, ok := iConn.(*stat.CounterConnection)
						if ok {
							iConn = statConn.Connection
						}
						if xc, ok := iConn.(*xtls.Conn); ok {
							iConn = xc.NetConn()
						}
						if tc, ok := iConn.(*net.TCPConn); ok {
							if conn.SHOW {
								fmt.Println(conn.MARK, "Splice")
							}
							// NOTE(review): yields the processor before the blocking
							// copy; the original comment says it is necessary but
							// does not say why.
							runtime.Gosched() // necessary
							// Copy the remainder of conn's underlying connection
							// directly into the inbound TCP connection; w is the
							// number of bytes copied.
							w, err := tc.ReadFrom(conn.NetConn())
							if counter != nil {
								counter.Add(w)
							}
							if statConn != nil && statConn.WriteCounter != nil {
								statConn.WriteCounter.Add(w)
							}
							return err
						} else {
							panic("XTLS Splice: not TCP inbound")
						}
					} else {
						// No usable inbound: fall through to the ReadV reader below.
						// panic("XTLS Splice: nil inbound or nil inbound.Conn")
					}
				}
				// Read directly from conn's underlying connection from now on,
				// and start counting bytes.
				reader = buf.NewReadVReader(conn.NetConn(), rawConn, nil)
				ct = counter
				if conn.SHOW {
					fmt.Println(conn.MARK, "ReadV")
				}
			}
			buffer, err := reader.ReadMultiBuffer()
			if !buffer.IsEmpty() {
				if ct != nil {
					ct.Add(int64(buffer.Len()))
				}
				timer.Update()
				if werr := writer.WriteMultiBuffer(buffer); werr != nil {
					return werr
				}
			}
			// Data read together with an error is forwarded above before the
			// error ends the loop.
			if err != nil {
				return err
			}
		}
	}()
	// io.EOF (by cause) is a normal end of stream and is reported as success.
	if err != nil && errors.Cause(err) != io.EOF {
		return err
	}
	return nil
}