github.com/sagernet/sing-box@v1.2.7/transport/trojan/protocol.go (about)

     1  package trojan
     2  
     3  import (
     4  	"crypto/sha256"
     5  	"encoding/binary"
     6  	"encoding/hex"
     7  	"io"
     8  	"net"
     9  	"os"
    10  
    11  	"github.com/sagernet/sing/common"
    12  	"github.com/sagernet/sing/common/buf"
    13  	"github.com/sagernet/sing/common/bufio"
    14  	E "github.com/sagernet/sing/common/exceptions"
    15  	M "github.com/sagernet/sing/common/metadata"
    16  	N "github.com/sagernet/sing/common/network"
    17  	"github.com/sagernet/sing/common/rw"
    18  )
    19  
    20  const (
    21  	KeyLength  = 56
    22  	CommandTCP = 1
    23  	CommandUDP = 3
    24  	CommandMux = 0x7f
    25  )
    26  
    27  var CRLF = []byte{'\r', '\n'}
    28  
    29  var _ N.EarlyConn = (*ClientConn)(nil)
    30  
    31  type ClientConn struct {
    32  	N.ExtendedConn
    33  	key           [KeyLength]byte
    34  	destination   M.Socksaddr
    35  	headerWritten bool
    36  }
    37  
    38  func NewClientConn(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr) *ClientConn {
    39  	return &ClientConn{
    40  		ExtendedConn: bufio.NewExtendedConn(conn),
    41  		key:          key,
    42  		destination:  destination,
    43  	}
    44  }
    45  
    46  func (c *ClientConn) NeedHandshake() bool {
    47  	return !c.headerWritten
    48  }
    49  
    50  func (c *ClientConn) Write(p []byte) (n int, err error) {
    51  	if c.headerWritten {
    52  		return c.ExtendedConn.Write(p)
    53  	}
    54  	err = ClientHandshake(c.ExtendedConn, c.key, c.destination, p)
    55  	if err != nil {
    56  		return
    57  	}
    58  	n = len(p)
    59  	c.headerWritten = true
    60  	return
    61  }
    62  
    63  func (c *ClientConn) WriteBuffer(buffer *buf.Buffer) error {
    64  	if c.headerWritten {
    65  		return c.ExtendedConn.WriteBuffer(buffer)
    66  	}
    67  	err := ClientHandshakeBuffer(c.ExtendedConn, c.key, c.destination, buffer)
    68  	if err != nil {
    69  		return err
    70  	}
    71  	c.headerWritten = true
    72  	return nil
    73  }
    74  
    75  func (c *ClientConn) ReadFrom(r io.Reader) (n int64, err error) {
    76  	if !c.headerWritten {
    77  		return bufio.ReadFrom0(c, r)
    78  	}
    79  	return bufio.Copy(c.ExtendedConn, r)
    80  }
    81  
    82  func (c *ClientConn) WriteTo(w io.Writer) (n int64, err error) {
    83  	return bufio.Copy(w, c.ExtendedConn)
    84  }
    85  
    86  func (c *ClientConn) FrontHeadroom() int {
    87  	if !c.headerWritten {
    88  		return KeyLength + 5 + M.MaxSocksaddrLength
    89  	}
    90  	return 0
    91  }
    92  
    93  func (c *ClientConn) Upstream() any {
    94  	return c.ExtendedConn
    95  }
    96  
    97  type ClientPacketConn struct {
    98  	net.Conn
    99  	key           [KeyLength]byte
   100  	headerWritten bool
   101  }
   102  
   103  func NewClientPacketConn(conn net.Conn, key [KeyLength]byte) *ClientPacketConn {
   104  	return &ClientPacketConn{
   105  		Conn: conn,
   106  		key:  key,
   107  	}
   108  }
   109  
   110  func (c *ClientPacketConn) NeedHandshake() bool {
   111  	return !c.headerWritten
   112  }
   113  
   114  func (c *ClientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
   115  	return ReadPacket(c.Conn, buffer)
   116  }
   117  
   118  func (c *ClientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   119  	if !c.headerWritten {
   120  		err := ClientHandshakePacket(c.Conn, c.key, destination, buffer)
   121  		c.headerWritten = true
   122  		return err
   123  	}
   124  	return WritePacket(c.Conn, buffer, destination)
   125  }
   126  
   127  func (c *ClientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
   128  	buffer := buf.With(p)
   129  	destination, err := c.ReadPacket(buffer)
   130  	if err != nil {
   131  		return
   132  	}
   133  	n = buffer.Len()
   134  	if destination.IsFqdn() {
   135  		addr = destination
   136  	} else {
   137  		addr = destination.UDPAddr()
   138  	}
   139  	return
   140  }
   141  
   142  func (c *ClientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
   143  	return bufio.WritePacket(c, p, addr)
   144  }
   145  
   146  func (c *ClientPacketConn) Read(p []byte) (n int, err error) {
   147  	n, _, err = c.ReadFrom(p)
   148  	return
   149  }
   150  
   151  func (c *ClientPacketConn) Write(p []byte) (n int, err error) {
   152  	return 0, os.ErrInvalid
   153  }
   154  
   155  func (c *ClientPacketConn) FrontHeadroom() int {
   156  	if !c.headerWritten {
   157  		return KeyLength + 2*M.MaxSocksaddrLength + 9
   158  	}
   159  	return M.MaxSocksaddrLength + 4
   160  }
   161  
   162  func (c *ClientPacketConn) Upstream() any {
   163  	return c.Conn
   164  }
   165  
   166  func Key(password string) [KeyLength]byte {
   167  	var key [KeyLength]byte
   168  	hash := sha256.New224()
   169  	common.Must1(hash.Write([]byte(password)))
   170  	hex.Encode(key[:], hash.Sum(nil))
   171  	return key
   172  }
   173  
   174  func ClientHandshakeRaw(conn net.Conn, key [KeyLength]byte, command byte, destination M.Socksaddr, payload []byte) error {
   175  	_, err := conn.Write(key[:])
   176  	if err != nil {
   177  		return err
   178  	}
   179  	_, err = conn.Write(CRLF)
   180  	if err != nil {
   181  		return err
   182  	}
   183  	_, err = conn.Write([]byte{command})
   184  	if err != nil {
   185  		return err
   186  	}
   187  	err = M.SocksaddrSerializer.WriteAddrPort(conn, destination)
   188  	if err != nil {
   189  		return err
   190  	}
   191  	_, err = conn.Write(CRLF)
   192  	if err != nil {
   193  		return err
   194  	}
   195  	if len(payload) > 0 {
   196  		_, err = conn.Write(payload)
   197  		if err != nil {
   198  			return err
   199  		}
   200  	}
   201  	return nil
   202  }
   203  
   204  func ClientHandshake(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload []byte) error {
   205  	headerLen := KeyLength + M.SocksaddrSerializer.AddrPortLen(destination) + 5
   206  	var header *buf.Buffer
   207  	defer header.Release()
   208  	var writeHeader bool
   209  	if len(payload) > 0 && headerLen+len(payload) < 65535 {
   210  		buffer := buf.StackNewSize(headerLen + len(payload))
   211  		defer common.KeepAlive(buffer)
   212  		header = common.Dup(buffer)
   213  	} else {
   214  		buffer := buf.StackNewSize(headerLen)
   215  		defer common.KeepAlive(buffer)
   216  		header = common.Dup(buffer)
   217  		writeHeader = true
   218  	}
   219  	common.Must1(header.Write(key[:]))
   220  	common.Must1(header.Write(CRLF))
   221  	common.Must(header.WriteByte(CommandTCP))
   222  	common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
   223  	common.Must1(header.Write(CRLF))
   224  	if !writeHeader {
   225  		common.Must1(header.Write(payload))
   226  	}
   227  
   228  	_, err := conn.Write(header.Bytes())
   229  	if err != nil {
   230  		return E.Cause(err, "write request")
   231  	}
   232  
   233  	if writeHeader {
   234  		_, err = conn.Write(payload)
   235  		if err != nil {
   236  			return E.Cause(err, "write payload")
   237  		}
   238  	}
   239  	return nil
   240  }
   241  
   242  func ClientHandshakeBuffer(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload *buf.Buffer) error {
   243  	header := buf.With(payload.ExtendHeader(KeyLength + M.SocksaddrSerializer.AddrPortLen(destination) + 5))
   244  	common.Must1(header.Write(key[:]))
   245  	common.Must1(header.Write(CRLF))
   246  	common.Must(header.WriteByte(CommandTCP))
   247  	common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
   248  	common.Must1(header.Write(CRLF))
   249  
   250  	_, err := conn.Write(payload.Bytes())
   251  	if err != nil {
   252  		return E.Cause(err, "write request")
   253  	}
   254  	return nil
   255  }
   256  
   257  func ClientHandshakePacket(conn net.Conn, key [KeyLength]byte, destination M.Socksaddr, payload *buf.Buffer) error {
   258  	headerLen := KeyLength + 2*M.SocksaddrSerializer.AddrPortLen(destination) + 9
   259  	payloadLen := payload.Len()
   260  	var header *buf.Buffer
   261  	defer header.Release()
   262  	var writeHeader bool
   263  	if payload.Start() >= headerLen {
   264  		header = buf.With(payload.ExtendHeader(headerLen))
   265  	} else {
   266  		buffer := buf.StackNewSize(headerLen)
   267  		defer common.KeepAlive(buffer)
   268  		header = common.Dup(buffer)
   269  		writeHeader = true
   270  	}
   271  	common.Must1(header.Write(key[:]))
   272  	common.Must1(header.Write(CRLF))
   273  	common.Must(header.WriteByte(CommandUDP))
   274  	common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
   275  	common.Must1(header.Write(CRLF))
   276  	common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
   277  	common.Must(binary.Write(header, binary.BigEndian, uint16(payloadLen)))
   278  	common.Must1(header.Write(CRLF))
   279  
   280  	if writeHeader {
   281  		_, err := conn.Write(header.Bytes())
   282  		if err != nil {
   283  			return E.Cause(err, "write request")
   284  		}
   285  	}
   286  
   287  	_, err := conn.Write(payload.Bytes())
   288  	if err != nil {
   289  		return E.Cause(err, "write payload")
   290  	}
   291  	return nil
   292  }
   293  
   294  func ReadPacket(conn net.Conn, buffer *buf.Buffer) (M.Socksaddr, error) {
   295  	destination, err := M.SocksaddrSerializer.ReadAddrPort(conn)
   296  	if err != nil {
   297  		return M.Socksaddr{}, E.Cause(err, "read destination")
   298  	}
   299  
   300  	var length uint16
   301  	err = binary.Read(conn, binary.BigEndian, &length)
   302  	if err != nil {
   303  		return M.Socksaddr{}, E.Cause(err, "read chunk length")
   304  	}
   305  
   306  	err = rw.SkipN(conn, 2)
   307  	if err != nil {
   308  		return M.Socksaddr{}, E.Cause(err, "skip crlf")
   309  	}
   310  
   311  	_, err = buffer.ReadFullFrom(conn, int(length))
   312  	return destination, err
   313  }
   314  
   315  func WritePacket(conn net.Conn, buffer *buf.Buffer, destination M.Socksaddr) error {
   316  	defer buffer.Release()
   317  	bufferLen := buffer.Len()
   318  	header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination) + 4))
   319  	common.Must(M.SocksaddrSerializer.WriteAddrPort(header, destination))
   320  	common.Must(binary.Write(header, binary.BigEndian, uint16(bufferLen)))
   321  	common.Must1(header.Write(CRLF))
   322  	_, err := conn.Write(buffer.Bytes())
   323  	if err != nil {
   324  		return E.Cause(err, "write packet")
   325  	}
   326  	return nil
   327  }