github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/yuubinsya/crypto/handshake.go (about)

     1  package crypto
     2  
     3  import (
     4  	"crypto/cipher"
     5  	"crypto/ecdh"
     6  	"crypto/rand"
     7  	"encoding/binary"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"math"
    12  	"net"
    13  	"time"
    14  
    15  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    16  	"github.com/Asutorufa/yuhaiin/pkg/net/proxy/socks5/tools"
    17  	"github.com/Asutorufa/yuhaiin/pkg/net/proxy/yuubinsya/types"
    18  	"github.com/Asutorufa/yuhaiin/pkg/utils/pool"
    19  	"golang.org/x/crypto/chacha20"
    20  	"golang.org/x/crypto/hkdf"
    21  )
    22  
    23  type encryptedHandshaker struct {
    24  	server bool
    25  
    26  	signer   types.Signer
    27  	hash     types.Hash
    28  	aead     types.Aead
    29  	password []byte
    30  }
    31  
    32  func (t *encryptedHandshaker) EncodeHeader(net types.Protocol, buf types.Buffer, addr netapi.Address) {
    33  	_, _ = buf.Write([]byte{byte(net)})
    34  
    35  	if net == types.TCP {
    36  		tools.EncodeAddr(addr, buf)
    37  	}
    38  }
    39  
    40  func (t *encryptedHandshaker) DecodeHeader(c net.Conn) (types.Protocol, error) {
    41  	z := make([]byte, 1)
    42  
    43  	if _, err := io.ReadFull(c, z); err != nil {
    44  		return 0, fmt.Errorf("read net type failed: %w", err)
    45  	}
    46  	net := types.Protocol(z[0])
    47  
    48  	if net.Unknown() {
    49  		return 0, fmt.Errorf("unknown network")
    50  	}
    51  
    52  	return net, nil
    53  }
    54  
    55  func (h *encryptedHandshaker) Handshake(conn net.Conn) (net.Conn, error) {
    56  	if h.server {
    57  		return h.handshakeServer(conn)
    58  	}
    59  
    60  	return h.handshakeClient(conn)
    61  }
    62  
    63  func (h *encryptedHandshaker) handshakeClient(conn net.Conn) (net.Conn, error) {
    64  	header := newHeader(h)
    65  	defer header.Def()
    66  
    67  	salt := make([]byte, h.hash.Size())
    68  
    69  	pk, time1, err := h.send(header, conn, nil)
    70  	if err != nil {
    71  		return nil, err
    72  	}
    73  
    74  	copy(salt, header.salt()) // client salt
    75  
    76  	rpb, time2, err := h.receive(header, conn, salt)
    77  	if err != nil {
    78  		return nil, err
    79  	}
    80  
    81  	if pk.PublicKey().Equal(rpb) {
    82  		return nil, fmt.Errorf("look like replay attack")
    83  	}
    84  
    85  	cryptKey, err := pk.ECDH(rpb)
    86  	if err != nil {
    87  		return nil, err
    88  	}
    89  
    90  	raead, rnonce, err := h.newAead(cryptKey, salt, time1)
    91  	if err != nil {
    92  		return nil, err
    93  	}
    94  
    95  	waead, wnonce, err := h.newAead(cryptKey, salt, time2)
    96  	if err != nil {
    97  		return nil, err
    98  	}
    99  
   100  	return NewConn(conn, rnonce, wnonce, raead, waead), nil
   101  }
   102  
   103  func (h *encryptedHandshaker) handshakeServer(conn net.Conn) (net.Conn, error) {
   104  	header := newHeader(h)
   105  	defer header.Def()
   106  
   107  	salt := make([]byte, h.hash.Size())
   108  
   109  	rpb, time1, err := h.receive(header, conn, nil)
   110  	if err != nil {
   111  		return nil, err
   112  	}
   113  
   114  	copy(salt, header.salt()) // client salt
   115  
   116  	pk, time2, err := h.send(header, conn, salt)
   117  	if err != nil {
   118  		return nil, err
   119  	}
   120  
   121  	if pk.PublicKey().Equal(rpb) {
   122  		return nil, fmt.Errorf("look like replay attack")
   123  	}
   124  
   125  	cryptKey, err := pk.ECDH(rpb)
   126  	if err != nil {
   127  		return nil, err
   128  	}
   129  
   130  	raead, rnonce, err := h.newAead(cryptKey, salt, time1)
   131  	if err != nil {
   132  		return nil, err
   133  	}
   134  
   135  	waead, wnonce, err := h.newAead(cryptKey, salt, time2)
   136  	if err != nil {
   137  		return nil, err
   138  	}
   139  
   140  	return NewConn(conn, wnonce, rnonce, waead, raead), nil
   141  }
   142  
   143  func (h *encryptedHandshaker) newAead(cryptKey, salt, time []byte) (cipher.AEAD, []byte, error) {
   144  	keyNonce := make([]byte, h.aead.KeySize()+h.aead.NonceSize())
   145  	if _, err := io.ReadFull(hkdf.New(h.hash.New, cryptKey, salt, append(h.aead.Name(), time...)), keyNonce); err != nil {
   146  		return nil, nil, err
   147  	}
   148  	aead, err := h.aead.New(keyNonce[:h.aead.KeySize()])
   149  	if err != nil {
   150  		return nil, nil, err
   151  	}
   152  
   153  	return aead, keyNonce[h.aead.KeySize():], nil
   154  }
   155  
   156  func (h *encryptedHandshaker) receive(buf *header, conn net.Conn, salt []byte) (_ *ecdh.PublicKey, ttime []byte, _ error) {
   157  	_, err := io.ReadFull(conn, buf.Bytes())
   158  	if err != nil {
   159  		return nil, nil, err
   160  	}
   161  
   162  	if salt != nil {
   163  		copy(buf.salt(), salt) // client: verify signature with client salt
   164  	}
   165  
   166  	if !h.signer.Verify(buf.saltTimeSignature(), buf.signature()) {
   167  		return nil, nil, errors.New("can't verify signature")
   168  	}
   169  
   170  	ttime = make([]byte, 8)
   171  	if err = h.encryptTime(h.password, buf.salt(), ttime, buf.time()); err != nil {
   172  		return nil, nil, fmt.Errorf("decrypt time failed: %w", err)
   173  	}
   174  
   175  	if math.Abs(float64(time.Now().Unix()-int64(binary.BigEndian.Uint64(ttime)))) > 30 { // check time is in +-30s
   176  		return nil, nil, errors.New("bad timestamp")
   177  	}
   178  
   179  	pubkey, err := ecdh.P256().NewPublicKey(buf.publickey())
   180  	if err != nil {
   181  		return nil, nil, err
   182  	}
   183  
   184  	return pubkey, ttime, nil
   185  }
   186  
   187  func (h *encryptedHandshaker) send(buf *header, conn net.Conn, salt []byte) (_ *ecdh.PrivateKey, ttime []byte, _ error) {
   188  	pk, err := ecdh.P256().GenerateKey(rand.Reader)
   189  	if err != nil {
   190  		return nil, nil, err
   191  	}
   192  
   193  	if salt != nil {
   194  		copy(buf.salt(), salt) // server: sign with client salt
   195  	} else {
   196  		if _, err = rand.Read(buf.salt()); err != nil { // client: read random bytes to salt
   197  			return nil, nil, fmt.Errorf("read salt from rand failed: %w", err)
   198  		}
   199  	}
   200  
   201  	copy(buf.publickey(), pk.PublicKey().Bytes())
   202  
   203  	ttime = make([]byte, 8)
   204  	binary.BigEndian.PutUint64(ttime, uint64(time.Now().Unix()))
   205  
   206  	if err = h.encryptTime(h.password, buf.salt(), buf.time(), ttime); err != nil {
   207  		return nil, nil, fmt.Errorf("encrypt time failed: %w", err)
   208  	}
   209  
   210  	signature, err := h.signer.Sign(rand.Reader, buf.saltTimeSignature())
   211  	if err != nil {
   212  		return nil, nil, err
   213  	}
   214  
   215  	copy(buf.signature(), signature)
   216  
   217  	if salt != nil {
   218  		if _, err := rand.Read(buf.salt()); err != nil { // server: read random bytes to padding
   219  			return nil, nil, fmt.Errorf("read salt from rand failed: %w", err)
   220  		}
   221  	}
   222  
   223  	if _, err = conn.Write(buf.Bytes()); err != nil {
   224  		return nil, nil, err
   225  	}
   226  
   227  	return pk, ttime, nil
   228  }
   229  
   230  type header struct {
   231  	bytes *pool.Bytes
   232  	th    *encryptedHandshaker
   233  }
   234  
   235  func newHeader(h *encryptedHandshaker) *header {
   236  	return &header{pool.GetBytesBuffer(h.hash.Size() + 8 + h.signer.SignatureSize() + 65), h}
   237  }
   238  func (h *header) Bytes() []byte { return h.bytes.Bytes() }
   239  func (h *header) signature() []byte {
   240  	return h.Bytes()[:h.th.signer.SignatureSize()]
   241  }
   242  func (h *header) publickey() []byte {
   243  	return h.Bytes()[h.th.hash.Size()+8+h.th.signer.SignatureSize():]
   244  }
   245  func (h *header) time() []byte {
   246  	return h.Bytes()[h.th.hash.Size()+h.th.signer.SignatureSize() : h.th.hash.Size()+8+h.th.signer.SignatureSize()]
   247  }
   248  func (h *header) salt() []byte {
   249  	return h.Bytes()[h.th.signer.SignatureSize() : h.th.signer.SignatureSize()+h.th.hash.Size()]
   250  }
   251  func (h *header) saltTimeSignature() []byte {
   252  	return h.Bytes()[h.th.signer.SignatureSize():]
   253  }
   254  func (h *header) Def() { defer h.bytes.Free() }
   255  
   256  func (h *encryptedHandshaker) encryptTime(password, salt, dst, src []byte) error {
   257  	nonce := make([]byte, chacha20.NonceSize)
   258  	key := make([]byte, chacha20.KeySize)
   259  
   260  	kdf := hkdf.New(h.hash.New, password, salt, []byte{'t', 'i', 'm', 'e'})
   261  
   262  	if _, err := io.ReadFull(kdf, key); err != nil {
   263  		return err
   264  	}
   265  	if _, err := io.ReadFull(kdf, nonce); err != nil {
   266  		return err
   267  	}
   268  
   269  	cipher, err := chacha20.NewUnauthenticatedCipher(key, nonce)
   270  	if err != nil {
   271  		return err
   272  	}
   273  
   274  	cipher.XORKeyStream(dst, src)
   275  
   276  	return nil
   277  }
   278  
   279  func NewHandshaker(server bool, hash []byte, password []byte) *encryptedHandshaker {
   280  	// sha256-hkdf-ecdh-ed25519-chacha20poly1305
   281  	return &encryptedHandshaker{
   282  		signer:   NewEd25519(Sha256, hash),
   283  		hash:     Sha256,
   284  		aead:     Chacha20poly1305,
   285  		password: password,
   286  		server:   server,
   287  	}
   288  }