github.com/igoogolx/clash@v1.19.8/transport/snell/snell.go (about)

     1  package snell
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"io"
     7  	"net"
     8  	"sync"
     9  
    10  	"github.com/igoogolx/clash/common/pool"
    11  	"github.com/igoogolx/clash/transport/shadowsocks/shadowaead"
    12  	"github.com/igoogolx/clash/transport/socks5"
    13  )
    14  
    // Snell protocol versions, plus the UDP payload size limit used by WritePacket.
const (
	Version1 = iota + 1
	Version2
	Version3

	// DefaultSnellVersion is the default protocol version (Version1).
	DefaultSnellVersion = Version1

	// maxLength is the largest payload WritePacket puts into a single snell
	// UDP packet; longer payloads are split into several packets.
	maxLength = 0x3FFF
)
    24  
    // Protocol bytes used by the snell client.
//
// The request commands follow Version in a request header (see WriteHeader
// and WriteUDPHeader). CommandUDPForward starts every UDP packet the client
// sends. The response codes are the first byte read back from the server;
// Snell.Read handles CommandTunnel and CommandError.
const (
	// Request commands.
	CommandPing      byte = 0
	CommandConnect   byte = 1
	CommandConnectV2 byte = 5
	CommandUDP       byte = 6

	// CommandUDPForward prefixes every snell UDP packet written by the client.
	CommandUDPForward byte = 1

	// CommondUDPForward is a misspelled alias of CommandUDPForward, kept so
	// existing callers continue to compile.
	//
	// Deprecated: use CommandUDPForward.
	CommondUDPForward = CommandUDPForward

	// Server response codes.
	CommandTunnel byte = 0
	CommandPong   byte = 1
	CommandError  byte = 2

	// Version is the protocol version byte that starts every request header.
	Version byte = 1
)
    38  
    // endSignal is the zero-length, non-nil payload HalfClose writes to signal
// the end of the client's outgoing data.
var endSignal = make([]byte, 0)
    40  
    // Snell is a client-side snell stream as built by StreamConn. It wraps the
// encrypted connection and remembers whether the one-byte server response
// code still has to be handled by Read.
//
// Field order is kept as-is on purpose; only documentation differs.
type Snell struct {
	// Conn is the encrypted stream (shadowaead.NewConn in StreamConn).
	net.Conn

	// buffer is scratch space for the single response bytes parsed by Read.
	buffer [1]byte
	// reply reports whether Read should pass data straight through to Conn;
	// HalfClose clears it so the next Read parses a response code again.
	reply bool
}
    46  
    47  func (s *Snell) Read(b []byte) (int, error) {
    48  	if s.reply {
    49  		return s.Conn.Read(b)
    50  	}
    51  
    52  	s.reply = true
    53  	if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil {
    54  		return 0, err
    55  	}
    56  
    57  	if s.buffer[0] == CommandTunnel {
    58  		return s.Conn.Read(b)
    59  	} else if s.buffer[0] != CommandError {
    60  		return 0, errors.New("command not support")
    61  	}
    62  
    63  	// CommandError
    64  	// 1 byte error code
    65  	if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil {
    66  		return 0, err
    67  	}
    68  	errcode := int(s.buffer[0])
    69  
    70  	// 1 byte error message length
    71  	if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil {
    72  		return 0, err
    73  	}
    74  	length := int(s.buffer[0])
    75  	msg := make([]byte, length)
    76  
    77  	if _, err := io.ReadFull(s.Conn, msg); err != nil {
    78  		return 0, err
    79  	}
    80  
    81  	return 0, fmt.Errorf("server reported code: %d, message: %s", errcode, string(msg))
    82  }
    83  
    84  func WriteHeader(conn net.Conn, host string, port uint, version int) error {
    85  	buf := pool.GetBytesBuffer()
    86  	defer pool.PutBytesBuffer(buf)
    87  	buf.PutUint8(Version)
    88  	if version == Version2 {
    89  		buf.PutUint8(CommandConnectV2)
    90  	} else {
    91  		buf.PutUint8(CommandConnect)
    92  	}
    93  
    94  	// clientID length & id
    95  	buf.PutUint8(0)
    96  
    97  	// host & port
    98  	buf.PutUint8(uint8(len(host)))
    99  	buf.PutString(host)
   100  	buf.PutUint16be(uint16(port))
   101  
   102  	if _, err := conn.Write(buf.Bytes()); err != nil {
   103  		return err
   104  	}
   105  
   106  	return nil
   107  }
   108  
   109  func WriteUDPHeader(conn net.Conn, version int) error {
   110  	if version < Version3 {
   111  		return errors.New("unsupport UDP version")
   112  	}
   113  
   114  	// version, command, clientID length
   115  	_, err := conn.Write([]byte{Version, CommandUDP, 0x00})
   116  	return err
   117  }
   118  
   119  // HalfClose works only on version2
   120  func HalfClose(conn net.Conn) error {
   121  	if _, err := conn.Write(endSignal); err != nil {
   122  		return err
   123  	}
   124  
   125  	if s, ok := conn.(*Snell); ok {
   126  		s.reply = false
   127  	}
   128  	return nil
   129  }
   130  
   131  func StreamConn(conn net.Conn, psk []byte, version int) *Snell {
   132  	var cipher shadowaead.Cipher
   133  	if version != Version1 {
   134  		cipher = NewAES128GCM(psk)
   135  	} else {
   136  		cipher = NewChacha20Poly1305(psk)
   137  	}
   138  	return &Snell{Conn: shadowaead.NewConn(conn, cipher)}
   139  }
   140  
   141  func PacketConn(conn net.Conn) net.PacketConn {
   142  	return &packetConn{
   143  		Conn: conn,
   144  	}
   145  }
   146  
   147  func writePacket(w io.Writer, socks5Addr, payload []byte) (int, error) {
   148  	buf := pool.GetBytesBuffer()
   149  	defer pool.PutBytesBuffer(buf)
   150  
   151  	// compose snell UDP address format (refer: icpz/snell-server-reversed)
   152  	// a brand new wheel to replace socks5 address format, well done Yachen
   153  	buf.PutUint8(CommondUDPForward)
   154  	switch socks5Addr[0] {
   155  	case socks5.AtypDomainName:
   156  		hostLen := socks5Addr[1]
   157  		buf.PutSlice(socks5Addr[1 : 1+1+hostLen+2])
   158  	case socks5.AtypIPv4:
   159  		buf.PutSlice([]byte{0x00, 0x04})
   160  		buf.PutSlice(socks5Addr[1 : 1+net.IPv4len+2])
   161  	case socks5.AtypIPv6:
   162  		buf.PutSlice([]byte{0x00, 0x06})
   163  		buf.PutSlice(socks5Addr[1 : 1+net.IPv6len+2])
   164  	}
   165  
   166  	buf.PutSlice(payload)
   167  	_, err := w.Write(buf.Bytes())
   168  	if err != nil {
   169  		return 0, err
   170  	}
   171  	return len(payload), nil
   172  }
   173  
   174  func WritePacket(w io.Writer, socks5Addr, payload []byte) (int, error) {
   175  	if len(payload) <= maxLength {
   176  		return writePacket(w, socks5Addr, payload)
   177  	}
   178  
   179  	offset := 0
   180  	total := len(payload)
   181  	for {
   182  		cursor := offset + maxLength
   183  		if cursor > total {
   184  			cursor = total
   185  		}
   186  
   187  		n, err := writePacket(w, socks5Addr, payload[offset:cursor])
   188  		if err != nil {
   189  			return offset + n, err
   190  		}
   191  
   192  		offset = cursor
   193  		if offset == total {
   194  			break
   195  		}
   196  	}
   197  
   198  	return total, nil
   199  }
   200  
   201  func ReadPacket(r io.Reader, payload []byte) (net.Addr, int, error) {
   202  	buf := pool.Get(pool.UDPBufferSize)
   203  	defer pool.Put(buf)
   204  
   205  	n, err := r.Read(buf)
   206  	headLen := 1
   207  	if err != nil {
   208  		return nil, 0, err
   209  	}
   210  	if n < headLen {
   211  		return nil, 0, errors.New("insufficient UDP length")
   212  	}
   213  
   214  	// parse snell UDP response address format
   215  	switch buf[0] {
   216  	case 0x04:
   217  		headLen += net.IPv4len + 2
   218  		if n < headLen {
   219  			err = errors.New("insufficient UDP length")
   220  			break
   221  		}
   222  		buf[0] = socks5.AtypIPv4
   223  	case 0x06:
   224  		headLen += net.IPv6len + 2
   225  		if n < headLen {
   226  			err = errors.New("insufficient UDP length")
   227  			break
   228  		}
   229  		buf[0] = socks5.AtypIPv6
   230  	default:
   231  		err = errors.New("ip version invalid")
   232  	}
   233  
   234  	if err != nil {
   235  		return nil, 0, err
   236  	}
   237  
   238  	addr := socks5.SplitAddr(buf[0:])
   239  	if addr == nil {
   240  		return nil, 0, errors.New("remote address invalid")
   241  	}
   242  	uAddr := addr.UDPAddr()
   243  	if uAddr == nil {
   244  		return nil, 0, errors.New("parse addr error")
   245  	}
   246  
   247  	length := len(payload)
   248  	if n-headLen < length {
   249  		length = n - headLen
   250  	}
   251  	copy(payload[:], buf[headLen:headLen+length])
   252  
   253  	return uAddr, length, nil
   254  }
   255  
   // packetConn implements net.PacketConn on top of a snell UDP stream.
// rMux serializes ReadFrom calls and wMux serializes WriteTo calls.
type packetConn struct {
	net.Conn
	rMux, wMux sync.Mutex
}
   261  
   262  func (pc *packetConn) WriteTo(b []byte, addr net.Addr) (int, error) {
   263  	pc.wMux.Lock()
   264  	defer pc.wMux.Unlock()
   265  
   266  	return WritePacket(pc, socks5.ParseAddr(addr.String()), b)
   267  }
   268  
   269  func (pc *packetConn) ReadFrom(b []byte) (int, net.Addr, error) {
   270  	pc.rMux.Lock()
   271  	defer pc.rMux.Unlock()
   272  
   273  	addr, n, err := ReadPacket(pc.Conn, b)
   274  	if err != nil {
   275  		return 0, nil, err
   276  	}
   277  
   278  	return n, addr, nil
   279  }