github.com/keltia/go-ipfs@v0.3.8-0.20150909044612-210793031c63/p2p/crypto/secio/rw.go (about)

     1  package secio
     2  
     3  import (
     4  	"crypto/cipher"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"sync"
     9  
    10  	"crypto/hmac"
    11  
    12  	proto "github.com/ipfs/go-ipfs/Godeps/_workspace/src/github.com/gogo/protobuf/proto"
    13  	msgio "github.com/ipfs/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-msgio"
    14  	mpool "github.com/ipfs/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-msgio/mpool"
    15  	context "github.com/ipfs/go-ipfs/Godeps/_workspace/src/golang.org/x/net/context"
    16  )
    17  
    18  const MaxMsgSize = 8 * 1024 * 1024
    19  
    20  var ErrMaxMessageSize = errors.New("attempted to read message larger than max size")
    21  
    22  // ErrMACInvalid signals that a MAC verification failed
    23  var ErrMACInvalid = errors.New("MAC verification failed")
    24  
// bufPool is the source of scratch buffers for outgoing encrypted messages.
// We need separate buffers because (sadly) we cannot encrypt in place: the
// caller of Write/WriteMsg gets its plaintext buffer back untouched.
//
// NOTE(review): this value is assigned into every etmWriter.pool field by
// NewETMWriter. If mpool.Pool is a struct type, each writer holds its own
// copy rather than sharing one pool — confirm against the mpool package.
var bufPool = mpool.ByteSlicePool
    28  
    29  type etmWriter struct {
    30  	// params
    31  	pool mpool.Pool        // for the buffers with encrypted data
    32  	msg  msgio.WriteCloser // msgio for knowing where boundaries lie
    33  	str  cipher.Stream     // the stream cipher to encrypt with
    34  	mac  HMAC              // the mac to authenticate data with
    35  
    36  	sync.Mutex
    37  }
    38  
    39  // NewETMWriter Encrypt-Then-MAC
    40  func NewETMWriter(w io.Writer, s cipher.Stream, mac HMAC) msgio.WriteCloser {
    41  	return &etmWriter{msg: msgio.NewWriter(w), str: s, mac: mac, pool: bufPool}
    42  }
    43  
    44  // Write writes passed in buffer as a single message.
    45  func (w *etmWriter) Write(b []byte) (int, error) {
    46  	if err := w.WriteMsg(b); err != nil {
    47  		return 0, err
    48  	}
    49  	return len(b), nil
    50  }
    51  
    52  // WriteMsg writes the msg in the passed in buffer.
    53  func (w *etmWriter) WriteMsg(b []byte) error {
    54  	w.Lock()
    55  	defer w.Unlock()
    56  
    57  	// encrypt.
    58  	data := w.pool.Get(uint32(len(b))).([]byte)
    59  	data = data[:len(b)] // the pool's buffer may be larger
    60  	w.str.XORKeyStream(data, b)
    61  
    62  	// log.Debugf("ENC plaintext (%d): %s %v", len(b), b, b)
    63  	// log.Debugf("ENC ciphertext (%d): %s %v", len(data), data, data)
    64  
    65  	// then, mac.
    66  	if _, err := w.mac.Write(data); err != nil {
    67  		return err
    68  	}
    69  
    70  	// Sum appends.
    71  	data = w.mac.Sum(data)
    72  	w.mac.Reset()
    73  	// it's sad to append here. our buffers are -- hopefully -- coming from
    74  	// a shared buffer pool, so the append may not actually cause allocation
    75  	// one can only hope. i guess we'll see.
    76  
    77  	return w.msg.WriteMsg(data)
    78  }
    79  
    80  func (w *etmWriter) Close() error {
    81  	return w.msg.Close()
    82  }
    83  
    84  type etmReader struct {
    85  	msgio.Reader
    86  	io.Closer
    87  
    88  	// buffer
    89  	buf []byte
    90  
    91  	// params
    92  	msg msgio.ReadCloser // msgio for knowing where boundaries lie
    93  	str cipher.Stream    // the stream cipher to encrypt with
    94  	mac HMAC             // the mac to authenticate data with
    95  
    96  	sync.Mutex
    97  }
    98  
    99  // NewETMReader Encrypt-Then-MAC
   100  func NewETMReader(r io.Reader, s cipher.Stream, mac HMAC) msgio.ReadCloser {
   101  	return &etmReader{msg: msgio.NewReader(r), str: s, mac: mac}
   102  }
   103  
   104  func (r *etmReader) NextMsgLen() (int, error) {
   105  	return r.msg.NextMsgLen()
   106  }
   107  
   108  func (r *etmReader) drainBuf(buf []byte) int {
   109  	if r.buf == nil {
   110  		return 0
   111  	}
   112  
   113  	n := copy(buf, r.buf)
   114  	r.buf = r.buf[n:]
   115  	return n
   116  }
   117  
   118  func (r *etmReader) Read(buf []byte) (int, error) {
   119  	r.Lock()
   120  	defer r.Unlock()
   121  
   122  	// first, check if we have anything in the buffer
   123  	copied := r.drainBuf(buf)
   124  	buf = buf[copied:]
   125  	if copied > 0 {
   126  		return copied, nil
   127  		// return here to avoid complicating the rest...
   128  		// user can call io.ReadFull.
   129  	}
   130  
   131  	// check the buffer has enough space for the next msg
   132  	fullLen, err := r.msg.NextMsgLen()
   133  	if err != nil {
   134  		return 0, err
   135  	}
   136  
   137  	if fullLen > MaxMsgSize {
   138  		return 0, ErrMaxMessageSize
   139  	}
   140  
   141  	buf2 := buf
   142  	changed := false
   143  	// if not enough space, allocate a new buffer.
   144  	if cap(buf) < fullLen {
   145  		buf2 = make([]byte, fullLen)
   146  		changed = true
   147  	}
   148  	buf2 = buf2[:fullLen]
   149  
   150  	n, err := io.ReadFull(r.msg, buf2)
   151  	if err != nil {
   152  		return n, err
   153  	}
   154  
   155  	m, err := r.macCheckThenDecrypt(buf2)
   156  	if err != nil {
   157  		return 0, err
   158  	}
   159  	buf2 = buf2[:m]
   160  	if !changed {
   161  		return m, nil
   162  	}
   163  
   164  	n = copy(buf, buf2)
   165  	if len(buf2) > len(buf) {
   166  		r.buf = buf2[len(buf):] // had some left over? save it.
   167  	}
   168  	return n, nil
   169  }
   170  
   171  func (r *etmReader) ReadMsg() ([]byte, error) {
   172  	r.Lock()
   173  	defer r.Unlock()
   174  
   175  	msg, err := r.msg.ReadMsg()
   176  	if err != nil {
   177  		return nil, err
   178  	}
   179  
   180  	n, err := r.macCheckThenDecrypt(msg)
   181  	if err != nil {
   182  		return nil, err
   183  	}
   184  	return msg[:n], nil
   185  }
   186  
   187  func (r *etmReader) macCheckThenDecrypt(m []byte) (int, error) {
   188  	l := len(m)
   189  	if l < r.mac.size {
   190  		return 0, fmt.Errorf("buffer (%d) shorter than MAC size (%d)", l, r.mac.size)
   191  	}
   192  
   193  	mark := l - r.mac.size
   194  	data := m[:mark]
   195  	macd := m[mark:]
   196  
   197  	r.mac.Write(data)
   198  	expected := r.mac.Sum(nil)
   199  	r.mac.Reset()
   200  
   201  	// check mac. if failed, return error.
   202  	if !hmac.Equal(macd, expected) {
   203  		log.Debug("MAC Invalid:", expected, "!=", macd)
   204  		return 0, ErrMACInvalid
   205  	}
   206  
   207  	// ok seems good. decrypt. (can decrypt in place, yay!)
   208  	// log.Debugf("DEC ciphertext (%d): %s %v", len(data), data, data)
   209  	r.str.XORKeyStream(data, data)
   210  	// log.Debugf("DEC plaintext (%d): %s %v", len(data), data, data)
   211  
   212  	return mark, nil
   213  }
   214  
   215  func (w *etmReader) Close() error {
   216  	return w.msg.Close()
   217  }
   218  
   219  // ReleaseMsg signals a buffer can be reused.
   220  func (r *etmReader) ReleaseMsg(b []byte) {
   221  	r.msg.ReleaseMsg(b)
   222  }
   223  
   224  // writeMsgCtx is used by the
   225  func writeMsgCtx(ctx context.Context, w msgio.Writer, msg proto.Message) ([]byte, error) {
   226  	enc, err := proto.Marshal(msg)
   227  	if err != nil {
   228  		return nil, err
   229  	}
   230  
   231  	// write in a goroutine so we can exit when our context is cancelled.
   232  	done := make(chan error)
   233  	go func(m []byte) {
   234  		err := w.WriteMsg(m)
   235  		select {
   236  		case done <- err:
   237  		case <-ctx.Done():
   238  		}
   239  	}(enc)
   240  
   241  	select {
   242  	case <-ctx.Done():
   243  		return nil, ctx.Err()
   244  	case e := <-done:
   245  		return enc, e
   246  	}
   247  }
   248  
   249  func readMsgCtx(ctx context.Context, r msgio.Reader, p proto.Message) ([]byte, error) {
   250  	var msg []byte
   251  
   252  	// read in a goroutine so we can exit when our context is cancelled.
   253  	done := make(chan error)
   254  	go func() {
   255  		var err error
   256  		msg, err = r.ReadMsg()
   257  		select {
   258  		case done <- err:
   259  		case <-ctx.Done():
   260  		}
   261  	}()
   262  
   263  	select {
   264  	case <-ctx.Done():
   265  		return nil, ctx.Err()
   266  	case e := <-done:
   267  		if e != nil {
   268  			return nil, e
   269  		}
   270  	}
   271  
   272  	return msg, proto.Unmarshal(msg, p)
   273  }