github.com/geph-official/geph2@v0.22.6-0.20210211030601-f527cb59b0df/libs/cshirt2/legacyTransport.go (about)

     1  package cshirt2
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"crypto/cipher"
     7  	"crypto/subtle"
     8  	"encoding/binary"
     9  	"errors"
    10  	"io"
    11  	"log"
    12  	"net"
    13  	"sync/atomic"
    14  	"time"
    15  
    16  	"github.com/geph-official/geph2/libs/erand"
    17  	pool "github.com/libp2p/go-buffer-pool"
    18  	"golang.org/x/crypto/chacha20"
    19  )
    20  
    21  // generates padding, given a write size
    22  func generatePadding(wsize int) []byte {
    23  	// TODO improve
    24  	if wsize > 3000 {
    25  		return nil
    26  	}
    27  	return make([]byte, erand.Int(512))
    28  }
    29  
    30  type legacyTransport struct {
    31  	readMAC    []byte
    32  	readCrypt  cipher.Stream
    33  	writeMAC   []byte
    34  	writeCrypt cipher.Stream
    35  	wireBuf    *bufio.Reader
    36  	wire       net.Conn
    37  	readbuf    bytes.Buffer
    38  
    39  	readDeadline  atomic.Value
    40  	writeDeadline atomic.Value
    41  
    42  	buf [128]byte
    43  }
    44  
    45  func (tp *legacyTransport) Read(b []byte) (n int, err error) {
    46  	for tp.readbuf.Len() == 0 {
    47  		// read the mac
    48  		macBts := tp.buf[0:16]
    49  		_, err = io.ReadFull(tp.wireBuf, macBts)
    50  		if err != nil {
    51  			return
    52  		}
    53  		// read the *encrypted* payload length
    54  		cryptPayloadLenBts := tp.buf[16:][:2]
    55  		_, err = io.ReadFull(tp.wireBuf, cryptPayloadLenBts)
    56  		if err != nil {
    57  			return
    58  		}
    59  		plainPayloadLenBts := tp.buf[18:][:2]
    60  		tp.readCrypt.XORKeyStream(plainPayloadLenBts, cryptPayloadLenBts)
    61  		// read the encrypted payload
    62  		cryptInnerPayloadBts := pool.GlobalPool.Get(int(binary.BigEndian.Uint16(plainPayloadLenBts)))
    63  		defer pool.GlobalPool.Put(cryptInnerPayloadBts)
    64  		// short timeout
    65  		tp.wire.SetReadDeadline(time.Now().Add(time.Second * 2))
    66  		_, err = io.ReadFull(tp.wireBuf, cryptInnerPayloadBts)
    67  		if err != nil {
    68  			log.Println("could not read the", len(cryptInnerPayloadBts), "bytes requested", err.Error())
    69  			return
    70  		}
    71  		tp.wire.SetReadDeadline(time.Time{})
    72  		rdead := tp.readDeadline.Load()
    73  		if rdead != nil {
    74  			tp.wire.SetReadDeadline(rdead.(time.Time))
    75  		}
    76  		// verify the MAC
    77  		toMAC := pool.GlobalPool.Get(len(cryptPayloadLenBts) + len(cryptInnerPayloadBts))
    78  		defer pool.GlobalPool.Put(toMAC)
    79  		copy(toMAC, cryptPayloadLenBts)
    80  		copy(toMAC[len(cryptPayloadLenBts):], cryptInnerPayloadBts)
    81  		if subtle.ConstantTimeCompare(macBts, mac128(toMAC, tp.readMAC)) != 1 {
    82  			err = errors.New("MAC error")
    83  			return
    84  		}
    85  		tp.readMAC = mac256(tp.readMAC, nil)
    86  		// decrypt the payload itself
    87  		plainInnerPayloadBts := pool.GlobalPool.Get(len(cryptInnerPayloadBts))
    88  		defer pool.GlobalPool.Put(plainInnerPayloadBts)
    89  		tp.readCrypt.XORKeyStream(plainInnerPayloadBts, cryptInnerPayloadBts)
    90  		if len(plainInnerPayloadBts) < 2 {
    91  			err = errors.New("truncated payload")
    92  			return
    93  		}
    94  		// get the non-padding part
    95  		realLenBts := plainInnerPayloadBts[:2]
    96  		realBts := plainInnerPayloadBts[2:][:binary.BigEndian.Uint16(realLenBts)]
    97  		// stuff the payload into the read buffer
    98  		tp.readbuf.Write(realBts)
    99  	}
   100  	n, err = tp.readbuf.Read(b)
   101  	return
   102  }
   103  
   104  func (tp *legacyTransport) Write(b []byte) (n int, err error) {
   105  	if len(b) > 65535 {
   106  		panic("don't know what to do!")
   107  	}
   108  	// first generate the plaintext payload
   109  	plainBuf := new(bytes.Buffer)
   110  	padding := generatePadding(len(b))
   111  	binary.Write(plainBuf, binary.BigEndian, uint16(len(padding)+len(b)+2))
   112  	binary.Write(plainBuf, binary.BigEndian, uint16(len(b)))
   113  	plainBuf.Write(b)
   114  	plainBuf.Write(padding)
   115  	// then we encrypt the payload
   116  	cryptPayload := plainBuf.Bytes()
   117  	tp.writeCrypt.XORKeyStream(cryptPayload, cryptPayload)
   118  	// then we compute the MAC and ratchet forward the key
   119  	mac := mac128(cryptPayload, tp.writeMAC)
   120  	tp.writeMAC = mac256(tp.writeMAC, nil)
   121  	toWrite := pool.GlobalPool.Get(len(mac) + len(cryptPayload))
   122  	defer pool.GlobalPool.Put(toWrite)
   123  	copy(toWrite, mac)
   124  	copy(toWrite[len(mac):], cryptPayload)
   125  	// then we assemble everything
   126  	_, err = tp.wire.Write(toWrite)
   127  	if err != nil {
   128  		return
   129  	}
   130  	n = len(b)
   131  	return
   132  }
   133  
   134  func (tp *legacyTransport) Close() error {
   135  	return tp.wire.Close()
   136  }
   137  
   138  func (tp *legacyTransport) LocalAddr() net.Addr {
   139  	return tp.wire.LocalAddr()
   140  }
   141  
   142  func (tp *legacyTransport) RemoteAddr() net.Addr {
   143  	return tp.wire.RemoteAddr()
   144  }
   145  
   146  func (tp *legacyTransport) SetDeadline(t time.Time) error {
   147  	return tp.wire.SetDeadline(t)
   148  }
   149  
   150  func (tp *legacyTransport) SetReadDeadline(t time.Time) error {
   151  	tp.readDeadline.Store(t)
   152  	return tp.wire.SetReadDeadline(t)
   153  }
   154  
   155  func (tp *legacyTransport) SetWriteDeadline(t time.Time) error {
   156  	tp.writeDeadline.Store(t)
   157  	return tp.wire.SetWriteDeadline(t)
   158  }
   159  
   160  func newLegacyTransport(wire net.Conn, ss []byte, isServer bool) *legacyTransport {
   161  	tp := new(legacyTransport)
   162  	readKey := mac256(ss, []byte("c2s"))
   163  	writeKey := mac256(ss, []byte("c2c"))
   164  	if !isServer {
   165  		readKey, writeKey = writeKey, readKey
   166  	}
   167  	var err error
   168  	tp.readMAC = mac256(readKey, []byte("mac"))
   169  	tp.readCrypt, err = chacha20.NewUnauthenticatedCipher(mac256(readKey, []byte("crypt")), make([]byte, 12))
   170  	if err != nil {
   171  		panic(err)
   172  	}
   173  	tp.writeMAC = mac256(writeKey, []byte("mac"))
   174  	tp.writeCrypt, err = chacha20.NewUnauthenticatedCipher(mac256(writeKey, []byte("crypt")), make([]byte, 12))
   175  	if err != nil {
   176  		panic(err)
   177  	}
   178  	tp.wire = wire
   179  	tp.wireBuf = bufio.NewReader(wire)
   180  	return tp
   181  }