
     1  // Copyright 2018 Keybase, Inc. All rights reserved. Use of
     2  // this source code is governed by the included BSD license.
     4  package kbcrypto
     6  // Code for encoding and decoding Keybase packet types.
     8  import (
     9  	"bytes"
    10  	"crypto/sha256"
    11  	"encoding/base64"
    12  	"errors"
    13  	"fmt"
    15  	""
    16  )
    18  type PacketVersion int
    20  const (
    21  	KeybasePacketV1 PacketVersion = 1
    22  )
    24  // PacketTag are tags for OpenPGP and Keybase packets. It is a uint to
    25  // be backwards compatible with older versions of codec that encoded
    26  // positive ints as uints.
    27  type PacketTag uint
    29  const (
    30  	TagP3skb      PacketTag = 513
    31  	TagSignature  PacketTag = 514
    32  	TagEncryption PacketTag = 515
    33  )
    35  func (t PacketTag) String() string {
    36  	switch t {
    37  	case TagP3skb:
    38  		return "PacketTag(P3skb)"
    39  	case TagSignature:
    40  		return "PacketTag(Signature)"
    41  	case TagEncryption:
    42  		return "PacketTag(Encryption)"
    43  	default:
    44  		return fmt.Sprintf("PacketTag(%d)", uint(t))
    45  	}
    46  }
// Packetable is implemented by any value that can travel as the body of a
// Keybase packet. GetTagAndVersion has two roles:
//   - When encoding, it reports the tag and version stamped onto the packet
//     envelope (see newKeybasePacket).
//   - When decoding, it reports the tag the caller expects to find. The
//     version is currently ignored on decode.
type Packetable interface {
	GetTagAndVersion() (PacketTag, PacketVersion)
}
    52  func EncodePacket(p Packetable, encoder *codec.Encoder) error {
    53  	packet, err := newKeybasePacket(p, true)
    54  	if err != nil {
    55  		return err
    56  	}
    57  	return encoder.Encode(packet)
    58  }
    60  func EncodePacketToBytes(p Packetable) ([]byte, error) {
    61  	packet, err := newKeybasePacket(p, true)
    62  	if err != nil {
    63  		return nil, err
    64  	}
    65  	return packet.encode()
    66  }
    68  func EncodePacketToBytesWithOptionalHash(p Packetable, doHash bool) ([]byte, error) {
    69  	packet, err := newKeybasePacket(p, doHash)
    70  	if err != nil {
    71  		return nil, err
    72  	}
    73  	return packet.encode()
    74  }
    76  func EncodePacketToArmoredString(p Packetable) (string, error) {
    77  	packet, err := newKeybasePacket(p, true)
    78  	if err != nil {
    79  		return "", err
    80  	}
    81  	return packet.armoredEncode()
    82  }
    84  type UnmarshalError struct {
    85  	ExpectedTag PacketTag
    86  	Tag         PacketTag
    87  }
    89  func (u UnmarshalError) Error() string {
    90  	return fmt.Sprintf("Expected %s packet, got %s packet", u.ExpectedTag, u.Tag)
    91  }
    93  func DecodePacket(decoder *codec.Decoder, body Packetable) error {
    94  	// TODO: Do something with the version too?
    95  	tag, _ := body.GetTagAndVersion()
    96  	p := keybasePacket{
    97  		Body: body,
    98  	}
    99  	err := decoder.Decode(&p)
   100  	if err != nil {
   101  		return err
   102  	}
   104  	if p.Tag != tag {
   105  		return UnmarshalError{ExpectedTag: p.Tag, Tag: tag}
   106  	}
   108  	// TODO: Figure out a way to do the same reencode check as in
   109  	// DecodePacketFromBytes.
   111  	return p.checkHash()
   112  }
   114  func DecodePacketFromBytes(data []byte, body Packetable) error {
   115  	ch := CodecHandle()
   116  	decoder := codec.NewDecoderBytes(data, ch)
   118  	// TODO: Do something with the version too?
   119  	tag, _ := body.GetTagAndVersion()
   120  	p := keybasePacket{
   121  		Body: body,
   122  	}
   123  	err := decoder.Decode(&p)
   124  	if err != nil {
   125  		return err
   126  	}
   128  	if decoder.NumBytesRead() != len(data) {
   129  		return fmt.Errorf("Did not consume entire buffer: %d byte(s) left", len(data)-decoder.NumBytesRead())
   130  	}
   132  	if p.Tag != tag {
   133  		return UnmarshalError{ExpectedTag: p.Tag, Tag: tag}
   134  	}
   136  	// Test for nonstandard msgpack data (which could be maliciously crafted)
   137  	// by re-encoding and making sure we get the same thing.
   138  	//
   139  	//
   140  	// Ideally this should be done at a lower level, but our
   141  	// msgpack library doesn't sort maps the way we expect. See
   142  	//
   143  	if reencoded, err := p.encode(); err != nil {
   144  		return err
   145  	} else if !bytes.Equal(reencoded, data) {
   146  		return FishyMsgpackError{data, reencoded}
   147  	}
   149  	return p.checkHash()
   150  }
   152  type FishyMsgpackError struct {
   153  	original  []byte
   154  	reencoded []byte
   155  }
   157  func (e FishyMsgpackError) Error() string {
   158  	return fmt.Sprintf("Original msgpack data didn't match re-encoded version: reencoded=%x != original=%x", e.reencoded, e.original)
   159  }
   161  func CodecHandle() *codec.MsgpackHandle {
   162  	var mh codec.MsgpackHandle
   163  	mh.WriteExt = true
   164  	return &mh
   165  }
// SHA256Code is the keybasePacketHash.Type value that denotes a SHA-256
// digest. It is the only hash type this file produces or accepts. The value
// matches the OpenPGP hash-algorithm ID for SHA-256 (RFC 4880, section 9.4).
const SHA256Code = 8
// keybasePacketHash is the optional integrity hash embedded in a
// keybasePacket. Its fields are:
//   - Type: the digest algorithm. Only SHA256Code is produced or accepted in
//     this file.
//   - Value: the digest bytes.
//
// NOTE(review): field order and codec tags are part of the wire format. The
// fields appear to be declared in sorted order of their codec names (type,
// value). Reordering them would presumably change the encoded bytes, and
// with them every packet hash. Confirm before touching.
type keybasePacketHash struct {
	Type  int    `codec:"type"`
	Value []byte `codec:"value"`
}
// keybasePacket is the envelope that every Keybase packet is encoded as. It
// holds the caller's Body, its Tag and Version, and an optional integrity
// Hash.
//
// Hash is nil for packets built without hashing (newKeybasePacket with
// doHash=false). It is then omitted from the encoding entirely (omitempty).
// When present, Hash.Value is the SHA-256 of this same struct encoded with an
// empty Hash.Value (see hashToBytes).
//
// When decoding, callers pre-set Body to the value to decode into.
//
// NOTE(review): the fields are declared in sorted order of their codec names
// (body, hash, tag, version). The encoder presumably writes them in
// declaration order. Reordering them would then change the encoded bytes and
// break both the packet hash and the re-encode check in
// DecodePacketFromBytes. Confirm before touching.
type keybasePacket struct {
	Body    Packetable         `codec:"body"`
	Hash    *keybasePacketHash `codec:"hash,omitempty"`
	Tag     PacketTag          `codec:"tag"`
	Version PacketVersion      `codec:"version"`
}
   181  // newKeybasePacket creates a new keybase crypto packet, optionally computing a
   182  // hash over all data in the packet (via doHash). Every client 1.0.17 and above
   183  // provides this flag (implicitly, since before it wasn't optional).  Some 1.0.16
   184  // clients did this, and no clients 1.0.15 and earlier did it. We use the flag
   185  // so that we can generate the legacy hashes for old 1.0.16
   186  func newKeybasePacket(body Packetable, doHash bool) (*keybasePacket, error) {
   187  	tag, version := body.GetTagAndVersion()
   188  	ret := keybasePacket{
   189  		Body:    body,
   190  		Tag:     tag,
   191  		Version: version,
   192  	}
   193  	if doHash {
   194  		ret.Hash = &keybasePacketHash{
   195  			Type:  SHA256Code,
   196  			Value: []byte{},
   197  		}
   198  		hashBytes, hashErr := ret.hashSum()
   199  		if hashErr != nil {
   200  			return nil, hashErr
   201  		}
   202  		ret.Hash.Value = hashBytes
   203  	}
   204  	return &ret, nil
   205  }
   207  func (p *keybasePacket) hashToBytes() ([]byte, error) {
   208  	// We don't include the Hash field in the encoded bytes that we hash,
   209  	// because if we did then the result wouldn't be stable. To work around
   210  	// that, we make a copy of the packet and overwrite the Hash field with
   211  	// an empty slice.
   212  	packetCopy := *p
   213  	packetCopy.Hash = &keybasePacketHash{
   214  		Type:  SHA256Code,
   215  		Value: []byte{},
   216  	}
   217  	return packetCopy.hashSum()
   218  }
   220  func (p *keybasePacket) hashSum() ([]byte, error) {
   221  	if len(p.Hash.Value) != 0 {
   222  		return nil, errors.New("cannot compute hash with Value present")
   223  	}
   224  	encoded, err := p.encode()
   225  	if err != nil {
   226  		return nil, err
   227  	}
   228  	ret := sha256.Sum256(encoded)
   229  	return ret[:], nil
   230  }
   232  func (p *keybasePacket) checkHash() error {
   233  	var gotten []byte
   234  	var err error
   235  	if p.Hash == nil {
   236  		return nil
   237  	}
   238  	given := p.Hash.Value
   239  	if p.Hash.Type != SHA256Code {
   240  		err = fmt.Errorf("Bad hash code: %d", p.Hash.Type)
   241  	} else if gotten, err = p.hashToBytes(); err != nil {
   243  	} else if !FastByteArrayEq(gotten, given) {
   244  		err = fmt.Errorf("Bad packet hash")
   245  	}
   246  	return err
   247  }
   249  func (p *keybasePacket) encode() ([]byte, error) {
   250  	var encoded []byte
   251  	err := codec.NewEncoderBytes(&encoded, CodecHandle()).Encode(p)
   252  	return encoded, err
   253  }
   255  func (p *keybasePacket) armoredEncode() (string, error) {
   256  	var buf bytes.Buffer
   257  	err := func() (err error) {
   258  		b64 := base64.NewEncoder(base64.StdEncoding, &buf)
   259  		defer func() {
   260  			closeErr := b64.Close()
   261  			if err == nil {
   262  				err = closeErr
   263  			}
   264  		}()
   265  		encoder := codec.NewEncoder(b64, CodecHandle())
   266  		return encoder.Encode(p)
   267  	}()
   268  	if err != nil {
   269  		return "", err
   270  	}
   271  	return buf.String(), nil
   272  }