github.com/chwjbn/xclash@v0.2.0/transport/snell/snell.go (about)

     1  package snell
     2  
     3  import (
     4  	"encoding/binary"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  
    10  	"github.com/chwjbn/xclash/common/pool"
    11  
    12  	"github.com/Dreamacro/go-shadowsocks2/shadowaead"
    13  )
    14  
    15  const (
    16  	Version1            = 1
    17  	Version2            = 2
    18  	DefaultSnellVersion = Version1
    19  )
    20  
    21  const (
    22  	CommandPing      byte = 0
    23  	CommandConnect   byte = 1
    24  	CommandConnectV2 byte = 5
    25  
    26  	CommandTunnel byte = 0
    27  	CommandPong   byte = 1
    28  	CommandError  byte = 2
    29  
    30  	Version byte = 1
    31  )
    32  
    33  var endSignal = []byte{}
    34  
    35  type Snell struct {
    36  	net.Conn
    37  	buffer [1]byte
    38  	reply  bool
    39  }
    40  
    41  func (s *Snell) Read(b []byte) (int, error) {
    42  	if s.reply {
    43  		return s.Conn.Read(b)
    44  	}
    45  
    46  	s.reply = true
    47  	if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil {
    48  		return 0, err
    49  	}
    50  
    51  	if s.buffer[0] == CommandTunnel {
    52  		return s.Conn.Read(b)
    53  	} else if s.buffer[0] != CommandError {
    54  		return 0, errors.New("command not support")
    55  	}
    56  
    57  	// CommandError
    58  	// 1 byte error code
    59  	if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil {
    60  		return 0, err
    61  	}
    62  	errcode := int(s.buffer[0])
    63  
    64  	// 1 byte error message length
    65  	if _, err := io.ReadFull(s.Conn, s.buffer[:]); err != nil {
    66  		return 0, err
    67  	}
    68  	length := int(s.buffer[0])
    69  	msg := make([]byte, length)
    70  
    71  	if _, err := io.ReadFull(s.Conn, msg); err != nil {
    72  		return 0, err
    73  	}
    74  
    75  	return 0, fmt.Errorf("server reported code: %d, message: %s", errcode, string(msg))
    76  }
    77  
    78  func WriteHeader(conn net.Conn, host string, port uint, version int) error {
    79  	buf := pool.GetBuffer()
    80  	defer pool.PutBuffer(buf)
    81  	buf.WriteByte(Version)
    82  	if version == Version2 {
    83  		buf.WriteByte(CommandConnectV2)
    84  	} else {
    85  		buf.WriteByte(CommandConnect)
    86  	}
    87  
    88  	// clientID length & id
    89  	buf.WriteByte(0)
    90  
    91  	// host & port
    92  	buf.WriteByte(uint8(len(host)))
    93  	buf.WriteString(host)
    94  	binary.Write(buf, binary.BigEndian, uint16(port))
    95  
    96  	if _, err := conn.Write(buf.Bytes()); err != nil {
    97  		return err
    98  	}
    99  
   100  	return nil
   101  }
   102  
   103  // HalfClose works only on version2
   104  func HalfClose(conn net.Conn) error {
   105  	if _, err := conn.Write(endSignal); err != nil {
   106  		return err
   107  	}
   108  
   109  	if s, ok := conn.(*Snell); ok {
   110  		s.reply = false
   111  	}
   112  	return nil
   113  }
   114  
   115  func StreamConn(conn net.Conn, psk []byte, version int) *Snell {
   116  	var cipher shadowaead.Cipher
   117  	if version == Version2 {
   118  		cipher = NewAES128GCM(psk)
   119  	} else {
   120  		cipher = NewChacha20Poly1305(psk)
   121  	}
   122  	return &Snell{Conn: shadowaead.NewConn(conn, cipher)}
   123  }