github.com/metacubex/mihomo@v1.18.5/transport/snell/snell.go (about)

     1  package snell
     2  
     3  import (
     4  	"encoding/binary"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"sync"
    10  
    11  	"github.com/metacubex/mihomo/common/pool"
    12  	"github.com/metacubex/mihomo/transport/shadowsocks/shadowaead"
    13  	"github.com/metacubex/mihomo/transport/socks5"
    14  )
    15  
    16  const (
    17  	Version1            = 1
    18  	Version2            = 2
    19  	Version3            = 3
    20  	DefaultSnellVersion = Version1
    21  
    22  	// max packet length
    23  	maxLength = 0x3FFF
    24  )
    25  
    26  const (
    27  	CommandPing       byte = 0
    28  	CommandConnect    byte = 1
    29  	CommandConnectV2  byte = 5
    30  	CommandUDP        byte = 6
    31  	CommondUDPForward byte = 1
    32  
    33  	CommandTunnel byte = 0
    34  	CommandPong   byte = 1
    35  	CommandError  byte = 2
    36  
    37  	Version byte = 1
    38  )
    39  
    40  var endSignal = []byte{}
    41  
    42  type Snell struct {
    43  	net.Conn
    44  	buffer [1]byte
    45  	reply  bool
    46  }
    47  
    48  func (s *Snell) Read(b []byte) (int, error) {
    49  	if s.reply {
    50  		return s.Conn.Read(b)
    51  	}
    52  
    53  	s.reply = true
    54  	if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil {
    55  		return 0, err
    56  	}
    57  
    58  	if s.buffer[0] == CommandTunnel {
    59  		return s.Conn.Read(b)
    60  	} else if s.buffer[0] != CommandError {
    61  		return 0, errors.New("command not support")
    62  	}
    63  
    64  	// CommandError
    65  	// 1 byte error code
    66  	if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil {
    67  		return 0, err
    68  	}
    69  	errcode := int(s.buffer[0])
    70  
    71  	// 1 byte error message length
    72  	if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil {
    73  		return 0, err
    74  	}
    75  	length := int(s.buffer[0])
    76  	msg := make([]byte, length)
    77  
    78  	if _, err := io.ReadFull(s.Conn, msg); err != nil {
    79  		return 0, err
    80  	}
    81  
    82  	return 0, fmt.Errorf("server reported code: %d, message: %s", errcode, string(msg))
    83  }
    84  
    85  func WriteHeader(conn net.Conn, host string, port uint, version int) error {
    86  	buf := pool.GetBuffer()
    87  	defer pool.PutBuffer(buf)
    88  	buf.WriteByte(Version)
    89  	if version == Version2 {
    90  		buf.WriteByte(CommandConnectV2)
    91  	} else {
    92  		buf.WriteByte(CommandConnect)
    93  	}
    94  
    95  	// clientID length & id
    96  	buf.WriteByte(0)
    97  
    98  	// host & port
    99  	buf.WriteByte(uint8(len(host)))
   100  	buf.WriteString(host)
   101  	binary.Write(buf, binary.BigEndian, uint16(port))
   102  
   103  	if _, err := conn.Write(buf.Bytes()); err != nil {
   104  		return err
   105  	}
   106  
   107  	return nil
   108  }
   109  
   110  func WriteUDPHeader(conn net.Conn, version int) error {
   111  	if version < Version3 {
   112  		return errors.New("unsupport UDP version")
   113  	}
   114  
   115  	// version, command, clientID length
   116  	_, err := conn.Write([]byte{Version, CommandUDP, 0x00})
   117  	return err
   118  }
   119  
   120  // HalfClose works only on version2
   121  func HalfClose(conn net.Conn) error {
   122  	if _, err := conn.Write(endSignal); err != nil {
   123  		return err
   124  	}
   125  
   126  	if s, ok := conn.(*Snell); ok {
   127  		s.reply = false
   128  	}
   129  	return nil
   130  }
   131  
   132  func StreamConn(conn net.Conn, psk []byte, version int) *Snell {
   133  	var cipher shadowaead.Cipher
   134  	if version != Version1 {
   135  		cipher = NewAES128GCM(psk)
   136  	} else {
   137  		cipher = NewChacha20Poly1305(psk)
   138  	}
   139  	return &Snell{Conn: shadowaead.NewConn(conn, cipher)}
   140  }
   141  
   142  func PacketConn(conn net.Conn) net.PacketConn {
   143  	return &packetConn{
   144  		Conn: conn,
   145  	}
   146  }
   147  
   148  func writePacket(w io.Writer, socks5Addr, payload []byte) (int, error) {
   149  	buf := pool.GetBuffer()
   150  	defer pool.PutBuffer(buf)
   151  
   152  	// compose snell UDP address format (refer: icpz/snell-server-reversed)
   153  	// a brand new wheel to replace socks5 address format, well done Yachen
   154  	buf.WriteByte(CommondUDPForward)
   155  	switch socks5Addr[0] {
   156  	case socks5.AtypDomainName:
   157  		hostLen := socks5Addr[1]
   158  		buf.Write(socks5Addr[1 : 1+1+hostLen+2])
   159  	case socks5.AtypIPv4:
   160  		buf.Write([]byte{0x00, 0x04})
   161  		buf.Write(socks5Addr[1 : 1+net.IPv4len+2])
   162  	case socks5.AtypIPv6:
   163  		buf.Write([]byte{0x00, 0x06})
   164  		buf.Write(socks5Addr[1 : 1+net.IPv6len+2])
   165  	}
   166  
   167  	buf.Write(payload)
   168  	_, err := w.Write(buf.Bytes())
   169  	if err != nil {
   170  		return 0, err
   171  	}
   172  	return len(payload), nil
   173  }
   174  
   175  func WritePacket(w io.Writer, socks5Addr, payload []byte) (int, error) {
   176  	if len(payload) <= maxLength {
   177  		return writePacket(w, socks5Addr, payload)
   178  	}
   179  
   180  	offset := 0
   181  	total := len(payload)
   182  	for {
   183  		cursor := offset + maxLength
   184  		if cursor > total {
   185  			cursor = total
   186  		}
   187  
   188  		n, err := writePacket(w, socks5Addr, payload[offset:cursor])
   189  		if err != nil {
   190  			return offset + n, err
   191  		}
   192  
   193  		offset = cursor
   194  		if offset == total {
   195  			break
   196  		}
   197  	}
   198  
   199  	return total, nil
   200  }
   201  
   202  func ReadPacket(r io.Reader, payload []byte) (net.Addr, int, error) {
   203  	buf := pool.Get(pool.UDPBufferSize)
   204  	defer pool.Put(buf)
   205  
   206  	n, err := r.Read(buf)
   207  	headLen := 1
   208  	if err != nil {
   209  		return nil, 0, err
   210  	}
   211  	if n < headLen {
   212  		return nil, 0, errors.New("insufficient UDP length")
   213  	}
   214  
   215  	// parse snell UDP response address format
   216  	switch buf[0] {
   217  	case 0x04:
   218  		headLen += net.IPv4len + 2
   219  		if n < headLen {
   220  			err = errors.New("insufficient UDP length")
   221  			break
   222  		}
   223  		buf[0] = socks5.AtypIPv4
   224  	case 0x06:
   225  		headLen += net.IPv6len + 2
   226  		if n < headLen {
   227  			err = errors.New("insufficient UDP length")
   228  			break
   229  		}
   230  		buf[0] = socks5.AtypIPv6
   231  	default:
   232  		err = errors.New("ip version invalid")
   233  	}
   234  
   235  	if err != nil {
   236  		return nil, 0, err
   237  	}
   238  
   239  	addr := socks5.SplitAddr(buf[0:])
   240  	if addr == nil {
   241  		return nil, 0, errors.New("remote address invalid")
   242  	}
   243  	uAddr := addr.UDPAddr()
   244  	if uAddr == nil {
   245  		return nil, 0, errors.New("parse addr error")
   246  	}
   247  
   248  	length := len(payload)
   249  	if n-headLen < length {
   250  		length = n - headLen
   251  	}
   252  	copy(payload[:], buf[headLen:headLen+length])
   253  
   254  	return uAddr, length, nil
   255  }
   256  
   257  type packetConn struct {
   258  	net.Conn
   259  	rMux sync.Mutex
   260  	wMux sync.Mutex
   261  }
   262  
   263  func (pc *packetConn) WriteTo(b []byte, addr net.Addr) (int, error) {
   264  	pc.wMux.Lock()
   265  	defer pc.wMux.Unlock()
   266  
   267  	return WritePacket(pc, socks5.ParseAddr(addr.String()), b)
   268  }
   269  
   270  func (pc *packetConn) ReadFrom(b []byte) (int, net.Addr, error) {
   271  	pc.rMux.Lock()
   272  	defer pc.rMux.Unlock()
   273  
   274  	addr, n, err := ReadPacket(pc.Conn, b)
   275  	if err != nil {
   276  		return 0, nil, err
   277  	}
   278  
   279  	return n, addr, nil
   280  }