github.com/anacrolix/torrent@v1.61.0/mse/mse.go (about)

     1  // https://wiki.vuze.com/w/Message_Stream_Encryption
     2  
     3  package mse
     4  
     5  import (
     6  	"bytes"
     7  	"context"
     8  	"crypto/rand"
     9  	"crypto/rc4"
    10  	"crypto/sha1"
    11  	"encoding/binary"
    12  	"errors"
    13  	"expvar"
    14  	"fmt"
    15  	"io"
    16  	"math"
    17  	"math/big"
    18  	"strconv"
    19  	"sync"
    20  
    21  	"github.com/anacrolix/torrent/internal/ctxrw"
    22  )
    23  
    24  const (
    25  	maxPadLen = 512
    26  
    27  	CryptoMethodPlaintext CryptoMethod = 1 // After header obfuscation, drop into plaintext
    28  	CryptoMethodRC4       CryptoMethod = 2 // After header obfuscation, use RC4 for the rest of the stream
    29  	AllSupportedCrypto                 = CryptoMethodPlaintext | CryptoMethodRC4
    30  )
    31  
    32  type CryptoMethod uint32
    33  
    34  var (
    35  	// Prime P according to the spec, and G, the generator.
    36  	p, specG big.Int
    37  	// The rand.Int max arg for use in newPadLen()
    38  	newPadLenMax big.Int
    39  	// For use in initer's hashes
    40  	req1 = []byte("req1")
    41  	req2 = []byte("req2")
    42  	req3 = []byte("req3")
    43  	// Verification constant "VC" which is all zeroes in the bittorrent
    44  	// implementation.
    45  	vc [8]byte
    46  	// Zero padding
    47  	zeroPad [512]byte
    48  	// Tracks counts of received crypto_provides
    49  	cryptoProvidesCount = expvar.NewMap("mseCryptoProvides")
    50  )
    51  
    52  func init() {
    53  	p.SetString("0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A63A36210000000000090563", 0)
    54  	specG.SetInt64(2)
    55  	newPadLenMax.SetInt64(maxPadLen + 1)
    56  }
    57  
    58  func hash(parts ...[]byte) []byte {
    59  	h := sha1.New()
    60  	for _, p := range parts {
    61  		n, err := h.Write(p)
    62  		if err != nil {
    63  			panic(err)
    64  		}
    65  		if n != len(p) {
    66  			panic(n)
    67  		}
    68  	}
    69  	return h.Sum(nil)
    70  }
    71  
    72  func newEncrypt(initer bool, s, skey []byte) (c *rc4.Cipher) {
    73  	c, err := rc4.NewCipher(hash([]byte(func() string {
    74  		if initer {
    75  			return "keyA"
    76  		} else {
    77  			return "keyB"
    78  		}
    79  	}()), s, skey))
    80  	if err != nil {
    81  		panic(err)
    82  	}
    83  	var burnSrc, burnDst [1024]byte
    84  	c.XORKeyStream(burnDst[:], burnSrc[:])
    85  	return
    86  }
    87  
    88  type cipherReader struct {
    89  	c  *rc4.Cipher
    90  	r  io.Reader
    91  	be []byte
    92  }
    93  
    94  func (cr *cipherReader) Read(b []byte) (n int, err error) {
    95  	if cap(cr.be) < len(b) {
    96  		cr.be = make([]byte, len(b))
    97  	}
    98  	n, err = cr.r.Read(cr.be[:len(b)])
    99  	cr.c.XORKeyStream(b[:n], cr.be[:n])
   100  	return
   101  }
   102  
   103  func newCipherReader(c *rc4.Cipher, r io.Reader) io.Reader {
   104  	return &cipherReader{c: c, r: r}
   105  }
   106  
   107  type cipherWriter struct {
   108  	c *rc4.Cipher
   109  	w io.Writer
   110  	b []byte
   111  }
   112  
   113  func (cr *cipherWriter) Write(b []byte) (n int, err error) {
   114  	be := func() []byte {
   115  		if len(cr.b) < len(b) {
   116  			return make([]byte, len(b))
   117  		} else {
   118  			ret := cr.b
   119  			cr.b = nil
   120  			return ret
   121  		}
   122  	}()
   123  	cr.c.XORKeyStream(be, b)
   124  	n, err = cr.w.Write(be[:len(b)])
   125  	if n != len(b) {
   126  		// The cipher will have advanced beyond the callers stream position.
   127  		// We can't use the cipher anymore.
   128  		cr.c = nil
   129  	}
   130  	if len(be) > len(cr.b) {
   131  		cr.b = be
   132  	}
   133  	return
   134  }
   135  
   136  func newX() big.Int {
   137  	var X big.Int
   138  	X.SetBytes(func() []byte {
   139  		var b [20]byte
   140  		_, err := rand.Read(b[:])
   141  		if err != nil {
   142  			panic(err)
   143  		}
   144  		return b[:]
   145  	}())
   146  	return X
   147  }
   148  
   149  func paddedLeft(b []byte, _len int) []byte {
   150  	if len(b) == _len {
   151  		return b
   152  	}
   153  	ret := make([]byte, _len)
   154  	if n := copy(ret[_len-len(b):], b); n != len(b) {
   155  		panic(n)
   156  	}
   157  	return ret
   158  }
   159  
   160  // Calculate, and send Y, our public key.
   161  func (h *handshake) postY(x *big.Int) error {
   162  	var y big.Int
   163  	y.Exp(&specG, x, &p)
   164  	return h.postWrite(paddedLeft(y.Bytes(), 96))
   165  }
   166  
   167  func (h *handshake) establishS() error {
   168  	x := newX()
   169  	h.postY(&x)
   170  	var b [96]byte
   171  	_, err := io.ReadFull(h.ctxConn, b[:])
   172  	if err != nil {
   173  		return fmt.Errorf("error reading Y: %w", err)
   174  	}
   175  	var Y, S big.Int
   176  	Y.SetBytes(b[:])
   177  	S.Exp(&Y, &x, &p)
   178  	sBytes := S.Bytes()
   179  	copy(h.s[96-len(sBytes):96], sBytes)
   180  	return nil
   181  }
   182  
   183  func newPadLen() int64 {
   184  	i, err := rand.Int(rand.Reader, &newPadLenMax)
   185  	if err != nil {
   186  		panic(err)
   187  	}
   188  	ret := i.Int64()
   189  	if ret < 0 || ret > maxPadLen {
   190  		panic(ret)
   191  	}
   192  	return ret
   193  }
   194  
// Manages state for both initiating and receiving handshakes.
//
// Writes to the peer are not made inline: they are queued with postWrite and
// performed in order by the writer goroutine started in Do. Presumably this
// keeps a handshake from blocking on a write while the peer is also writing
// (NOTE(review): confirm intent).
type handshake struct {
	// The underlying stream. The ReadWriter returned from a successful
	// handshake uses this directly, not ctxConn.
	conn io.ReadWriter
	// The conn with Reads and Writes wrapped to the handshake's context (the
	// same context that is passed to handshake.Do).
	ctxConn io.ReadWriter
	// The shared secret S, big-endian, left-padded with zeros to 96 bytes.
	s       [96]byte
	initer  bool          // Whether we're initiating or receiving.
	skeys   SecretKeyIter // Skeys we'll accept if receiving.
	skey    []byte        // Skey we're initiating with; on the receiver, the skey that matched.
	ia      []byte        // Initial payload: sent by the initiator; on the receiver, the payload received.
	// Receiver only: given the initiator's crypto_provide bits, return the
	// crypto method to select.
	chooseMethod CryptoSelector
	// Sent to the receiver as crypto_provide.
	cryptoProvides CryptoMethod

	// writeMu guards writes, writeErr and writeClose. writeCond uses writeMu
	// (set in Do) and wakes the writer goroutine when data is queued or
	// writeClose is set. writes holds queued buffers in send order; writeErr
	// is the writer's first error, after which postWrite fails; writeClose
	// tells the writer to exit once writes is drained.
	writeMu    sync.Mutex
	writes     [][]byte
	writeErr   error
	writeCond  sync.Cond
	writeClose bool

	// writerDone is set, under writerMu, when the writer goroutine exits;
	// finishWriting waits for it on writerCond.
	writerMu   sync.Mutex
	writerCond sync.Cond
	writerDone bool
}
   220  
   221  func (h *handshake) finishWriting() {
   222  	h.writeMu.Lock()
   223  	h.writeClose = true
   224  	h.writeCond.Broadcast()
   225  	h.writeMu.Unlock()
   226  
   227  	h.writerMu.Lock()
   228  	for !h.writerDone {
   229  		h.writerCond.Wait()
   230  	}
   231  	h.writerMu.Unlock()
   232  }
   233  
// writer runs as a goroutine (started by Do) that drains the write queue,
// writing each queued buffer to ctxConn in order. It returns once writeClose
// is set and the queue is empty, or after the first write error, which it
// records in writeErr. On return it sets writerDone so finishWriting can
// proceed.
func (h *handshake) writer() {
	defer func() {
		h.writerMu.Lock()
		h.writerDone = true
		h.writerCond.Broadcast()
		h.writerMu.Unlock()
	}()
	for {
		h.writeMu.Lock()
		// Wait for queued data or a close request. Queued data is checked
		// first, so everything posted before finishWriting is still written.
		for {
			if len(h.writes) != 0 {
				break
			}
			if h.writeClose {
				h.writeMu.Unlock()
				return
			}
			h.writeCond.Wait()
		}
		b := h.writes[0]
		h.writes = h.writes[1:]
		// Write without holding writeMu so postWrite can keep queuing.
		h.writeMu.Unlock()
		_, err := h.ctxConn.Write(b)
		if err != nil {
			// Any buffers still queued are abandoned; postWrite reports
			// writeErr from now on.
			h.writeMu.Lock()
			h.writeErr = err
			h.writeMu.Unlock()
			return
		}
	}
}
   265  
   266  func (h *handshake) postWrite(b []byte) error {
   267  	h.writeMu.Lock()
   268  	defer h.writeMu.Unlock()
   269  	if h.writeErr != nil {
   270  		return h.writeErr
   271  	}
   272  	h.writes = append(h.writes, b)
   273  	h.writeCond.Signal()
   274  	return nil
   275  }
   276  
   277  func xor(a, b []byte) (ret []byte) {
   278  	max := len(a)
   279  	if max > len(b) {
   280  		max = len(b)
   281  	}
   282  	ret = make([]byte, max)
   283  	xorInPlace(ret, a, b)
   284  	return
   285  }
   286  
   287  func xorInPlace(dst, a, b []byte) {
   288  	for i := range dst {
   289  		dst[i] = a[i] ^ b[i]
   290  	}
   291  }
   292  
   293  func marshal(w io.Writer, data ...interface{}) (err error) {
   294  	for _, data := range data {
   295  		err = binary.Write(w, binary.BigEndian, data)
   296  		if err != nil {
   297  			break
   298  		}
   299  	}
   300  	return
   301  }
   302  
   303  func unmarshal(r io.Reader, data ...interface{}) (err error) {
   304  	for _, data := range data {
   305  		err = binary.Read(r, binary.BigEndian, data)
   306  		if err != nil {
   307  			break
   308  		}
   309  	}
   310  	return
   311  }
   312  
   313  // Looking for b at the end of a.
   314  func suffixMatchLen(a, b []byte) int {
   315  	if len(b) > len(a) {
   316  		b = b[:len(a)]
   317  	}
   318  	// i is how much of b to try to match
   319  	for i := len(b); i > 0; i-- {
   320  		// j is how many chars we've compared
   321  		j := 0
   322  		for ; j < i; j++ {
   323  			if b[i-1-j] != a[len(a)-1-j] {
   324  				goto shorter
   325  			}
   326  		}
   327  		return j
   328  	shorter:
   329  	}
   330  	return 0
   331  }
   332  
   333  // Reads from r until b has been seen. Keeps the minimum amount of data in
   334  // memory.
   335  func readUntil(r io.Reader, b []byte) error {
   336  	b1 := make([]byte, len(b))
   337  	i := 0
   338  	for {
   339  		_, err := io.ReadFull(r, b1[i:])
   340  		if err != nil {
   341  			return err
   342  		}
   343  		i = suffixMatchLen(b1, b)
   344  		if i == len(b) {
   345  			break
   346  		}
   347  		if copy(b1, b1[len(b1)-i:]) != i {
   348  			panic("wat")
   349  		}
   350  	}
   351  	return nil
   352  }
   353  
// readWriter pairs an independent io.Reader and io.Writer into a single
// io.ReadWriter, so the stream returned after the handshake can transform
// reads and writes separately.
type readWriter struct {
	io.Reader
	io.Writer
}
   358  
   359  func (h *handshake) newEncrypt(initer bool) *rc4.Cipher {
   360  	return newEncrypt(initer, h.s[:], h.skey)
   361  }
   362  
   363  func (h *handshake) initerSteps(ctx context.Context) (ret io.ReadWriter, selected CryptoMethod, err error) {
   364  	h.postWrite(hash(req1, h.s[:]))
   365  	h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
   366  	buf := &bytes.Buffer{}
   367  	padLen := uint16(newPadLen())
   368  	if len(h.ia) > math.MaxUint16 {
   369  		err = errors.New("initial payload too large")
   370  		return
   371  	}
   372  	err = marshal(buf, vc[:], h.cryptoProvides, padLen, zeroPad[:padLen], uint16(len(h.ia)), h.ia)
   373  	if err != nil {
   374  		return
   375  	}
   376  	e := h.newEncrypt(true)
   377  	be := make([]byte, buf.Len())
   378  	e.XORKeyStream(be, buf.Bytes())
   379  	h.postWrite(be)
   380  	bC := h.newEncrypt(false)
   381  	var eVC [8]byte
   382  	bC.XORKeyStream(eVC[:], vc[:])
   383  	// Read until the all zero VC. At this point we've only read the 96 byte
   384  	// public key, Y. There is potentially 512 byte padding, between us and
   385  	// the 8 byte verification constant.
   386  	err = readUntil(io.LimitReader(h.ctxConn, 520), eVC[:])
   387  	if err != nil {
   388  		if err == io.EOF {
   389  			err = errors.New("failed to synchronize on VC")
   390  		} else {
   391  			err = fmt.Errorf("error reading until VC: %w", err)
   392  		}
   393  		return
   394  	}
   395  	ctxReader := newCipherReader(bC, h.ctxConn)
   396  	var method CryptoMethod
   397  	err = unmarshal(ctxReader, &method, &padLen)
   398  	if err != nil {
   399  		return
   400  	}
   401  	_, err = io.CopyN(io.Discard, ctxReader, int64(padLen))
   402  	if err != nil {
   403  		return
   404  	}
   405  	selected = method & h.cryptoProvides
   406  	switch selected {
   407  	case CryptoMethodRC4:
   408  		ret = readWriter{
   409  			newCipherReader(bC, h.conn),
   410  			&cipherWriter{e, h.conn, nil},
   411  		}
   412  	case CryptoMethodPlaintext:
   413  		ret = h.conn
   414  	default:
   415  		err = fmt.Errorf("receiver chose unsupported method: %x", method)
   416  	}
   417  	return
   418  }
   419  
// ErrNoSecretKeyMatch is returned by the receiving side of a handshake when
// none of the secret keys offered by its SecretKeyIter match the one the
// initiator used.
var ErrNoSecretKeyMatch = errors.New("no skey matched")
   421  
   422  func (h *handshake) receiverSteps(ctx context.Context) (ret io.ReadWriter, chosen CryptoMethod, err error) {
   423  	// There is up to 512 bytes of padding, then the 20 byte hash.
   424  	err = readUntil(io.LimitReader(h.ctxConn, 532), hash(req1, h.s[:]))
   425  	if err != nil {
   426  		if err == io.EOF {
   427  			err = errors.New("failed to synchronize on S hash")
   428  		}
   429  		return
   430  	}
   431  	var b [20]byte
   432  	_, err = io.ReadFull(h.ctxConn, b[:])
   433  	if err != nil {
   434  		return
   435  	}
   436  	expectedHash := hash(req3, h.s[:])
   437  	eachHash := sha1.New()
   438  	var sum, xored [sha1.Size]byte
   439  	err = ErrNoSecretKeyMatch
   440  	h.skeys(func(skey []byte) bool {
   441  		eachHash.Reset()
   442  		eachHash.Write(req2)
   443  		eachHash.Write(skey)
   444  		eachHash.Sum(sum[:0])
   445  		xorInPlace(xored[:], sum[:], expectedHash)
   446  		if bytes.Equal(xored[:], b[:]) {
   447  			h.skey = skey
   448  			err = nil
   449  			return false
   450  		}
   451  		return true
   452  	})
   453  	if err != nil {
   454  		return
   455  	}
   456  	cipher := newEncrypt(true, h.s[:], h.skey)
   457  	ctxReader := newCipherReader(cipher, h.ctxConn)
   458  	var (
   459  		vc       [8]byte
   460  		provides CryptoMethod
   461  		padLen   uint16
   462  	)
   463  
   464  	err = unmarshal(ctxReader, vc[:], &provides, &padLen)
   465  	if err != nil {
   466  		return
   467  	}
   468  	cryptoProvidesCount.Add(strconv.FormatUint(uint64(provides), 16), 1)
   469  	chosen = h.chooseMethod(provides)
   470  	_, err = io.CopyN(io.Discard, ctxReader, int64(padLen))
   471  	if err != nil {
   472  		return
   473  	}
   474  	var lenIA uint16
   475  	unmarshal(ctxReader, &lenIA)
   476  	if lenIA != 0 {
   477  		h.ia = make([]byte, lenIA)
   478  		unmarshal(ctxReader, h.ia)
   479  	}
   480  	buf := &bytes.Buffer{}
   481  	w := cipherWriter{h.newEncrypt(false), buf, nil}
   482  	padLen = uint16(newPadLen())
   483  	err = marshal(&w, &vc, uint32(chosen), padLen, zeroPad[:padLen])
   484  	if err != nil {
   485  		return
   486  	}
   487  	err = h.postWrite(buf.Bytes())
   488  	if err != nil {
   489  		return
   490  	}
   491  	switch chosen {
   492  	case CryptoMethodRC4:
   493  		ret = readWriter{
   494  			io.MultiReader(bytes.NewReader(h.ia), newCipherReader(cipher, h.conn)),
   495  			&cipherWriter{w.c, h.conn, nil},
   496  		}
   497  	case CryptoMethodPlaintext:
   498  		ret = readWriter{
   499  			io.MultiReader(bytes.NewReader(h.ia), h.conn),
   500  			h.conn,
   501  		}
   502  	default:
   503  		err = errors.New("chosen crypto method is not supported")
   504  	}
   505  	return
   506  }
   507  
   508  func (h *handshake) Do(ctx context.Context) (ret io.ReadWriter, method CryptoMethod, err error) {
   509  	h.writeCond.L = &h.writeMu
   510  	h.writerCond.L = &h.writerMu
   511  	go h.writer()
   512  	defer func() {
   513  		h.finishWriting()
   514  		if err == nil {
   515  			err = h.writeErr
   516  		}
   517  	}()
   518  	err = h.establishS()
   519  	if err != nil {
   520  		err = fmt.Errorf("error while establishing secret: %w", err)
   521  		return
   522  	}
   523  	pad := make([]byte, newPadLen())
   524  	io.ReadFull(rand.Reader, pad)
   525  	err = h.postWrite(pad)
   526  	if err != nil {
   527  		return
   528  	}
   529  	if h.initer {
   530  		ret, method, err = h.initerSteps(ctx)
   531  	} else {
   532  		ret, method, err = h.receiverSteps(ctx)
   533  	}
   534  	return
   535  }
   536  
   537  func InitiateHandshake(
   538  	rw io.ReadWriter,
   539  	skey, initialPayload []byte,
   540  	cryptoProvides CryptoMethod,
   541  ) (
   542  	ret io.ReadWriter, method CryptoMethod, err error,
   543  ) {
   544  	return InitiateHandshakeContext(context.TODO(), rw, skey, initialPayload, cryptoProvides)
   545  }
   546  
   547  func InitiateHandshakeContext(
   548  	ctx context.Context,
   549  	rw io.ReadWriter,
   550  	skey, initialPayload []byte,
   551  	cryptoProvides CryptoMethod,
   552  ) (
   553  	ret io.ReadWriter, method CryptoMethod, err error,
   554  ) {
   555  	h := handshake{
   556  		conn:           rw,
   557  		ctxConn:        ctxrw.WrapReadWriter(ctx, rw),
   558  		initer:         true,
   559  		skey:           skey,
   560  		ia:             initialPayload,
   561  		cryptoProvides: cryptoProvides,
   562  	}
   563  	return h.Do(ctx)
   564  }
   565  
// HandshakeResult is returned by ReceiveHandshakeEx. It embeds the stream to
// use after the handshake, the selected CryptoMethod and any error.
type HandshakeResult struct {
	io.ReadWriter
	CryptoMethod
	error
	// SecretKey is the skey that matched the initiator's, or nil if none did.
	SecretKey []byte
}
   572  
   573  func ReceiveHandshake(
   574  	ctx context.Context,
   575  	rw io.ReadWriter,
   576  	skeys SecretKeyIter,
   577  	selectCrypto CryptoSelector,
   578  ) (io.ReadWriter, CryptoMethod, error) {
   579  	res := ReceiveHandshakeEx(ctx, rw, skeys, selectCrypto)
   580  	return res.ReadWriter, res.CryptoMethod, res.error
   581  }
   582  
   583  func ReceiveHandshakeEx(
   584  	ctx context.Context,
   585  	rw io.ReadWriter,
   586  	skeys SecretKeyIter,
   587  	selectCrypto CryptoSelector,
   588  ) (ret HandshakeResult) {
   589  	h := handshake{
   590  		conn:         rw,
   591  		ctxConn:      ctxrw.WrapReadWriter(ctx, rw),
   592  		initer:       false,
   593  		skeys:        skeys,
   594  		chooseMethod: selectCrypto,
   595  	}
   596  	ret.ReadWriter, ret.CryptoMethod, ret.error = h.Do(ctx)
   597  	ret.SecretKey = h.skey
   598  	return
   599  }
   600  
// SecretKeyIter yields candidate secret keys. It calls callback with each
// key in turn until callback returns false or the keys are exhausted.
type SecretKeyIter func(callback func(skey []byte) (more bool))
   604  
   605  func DefaultCryptoSelector(provided CryptoMethod) CryptoMethod {
   606  	// We prefer plaintext for performance reasons.
   607  	if provided&CryptoMethodPlaintext != 0 {
   608  		return CryptoMethodPlaintext
   609  	}
   610  	return CryptoMethodRC4
   611  }
   612  
// CryptoSelector is called by the receiving side of a handshake with the
// initiator's crypto_provide bits and returns the crypto method to select.
// The handshake fails unless the result is CryptoMethodPlaintext or
// CryptoMethodRC4.
type CryptoSelector func(CryptoMethod) CryptoMethod