github.com/pion/dtls/v2@v2.2.12/pkg/crypto/ciphersuite/cbc.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  package ciphersuite
     5  
     6  import ( //nolint:gci
     7  	"crypto/aes"
     8  	"crypto/cipher"
     9  	"crypto/hmac"
    10  	"crypto/rand"
    11  	"encoding/binary"
    12  	"hash"
    13  
    14  	"github.com/pion/dtls/v2/internal/util"
    15  	"github.com/pion/dtls/v2/pkg/crypto/prf"
    16  	"github.com/pion/dtls/v2/pkg/protocol"
    17  	"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
    18  )
    19  
    20  // block ciphers using cipher block chaining.
    21  type cbcMode interface {
    22  	cipher.BlockMode
    23  	SetIV([]byte)
    24  }
    25  
    26  // CBC Provides an API to Encrypt/Decrypt DTLS 1.2 Packets
    27  type CBC struct {
    28  	writeCBC, readCBC cbcMode
    29  	writeMac, readMac []byte
    30  	h                 prf.HashFunc
    31  }
    32  
    33  // NewCBC creates a DTLS CBC Cipher
    34  func NewCBC(localKey, localWriteIV, localMac, remoteKey, remoteWriteIV, remoteMac []byte, h prf.HashFunc) (*CBC, error) {
    35  	writeBlock, err := aes.NewCipher(localKey)
    36  	if err != nil {
    37  		return nil, err
    38  	}
    39  
    40  	readBlock, err := aes.NewCipher(remoteKey)
    41  	if err != nil {
    42  		return nil, err
    43  	}
    44  
    45  	writeCBC, ok := cipher.NewCBCEncrypter(writeBlock, localWriteIV).(cbcMode)
    46  	if !ok {
    47  		return nil, errFailedToCast
    48  	}
    49  
    50  	readCBC, ok := cipher.NewCBCDecrypter(readBlock, remoteWriteIV).(cbcMode)
    51  	if !ok {
    52  		return nil, errFailedToCast
    53  	}
    54  
    55  	return &CBC{
    56  		writeCBC: writeCBC,
    57  		writeMac: localMac,
    58  
    59  		readCBC: readCBC,
    60  		readMac: remoteMac,
    61  		h:       h,
    62  	}, nil
    63  }
    64  
    65  // Encrypt encrypt a DTLS RecordLayer message
    66  func (c *CBC) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) {
    67  	payload := raw[recordlayer.HeaderSize:]
    68  	raw = raw[:recordlayer.HeaderSize]
    69  	blockSize := c.writeCBC.BlockSize()
    70  
    71  	// Generate + Append MAC
    72  	h := pkt.Header
    73  
    74  	MAC, err := c.hmac(h.Epoch, h.SequenceNumber, h.ContentType, h.Version, payload, c.writeMac, c.h)
    75  	if err != nil {
    76  		return nil, err
    77  	}
    78  	payload = append(payload, MAC...)
    79  
    80  	// Generate + Append padding
    81  	padding := make([]byte, blockSize-len(payload)%blockSize)
    82  	paddingLen := len(padding)
    83  	for i := 0; i < paddingLen; i++ {
    84  		padding[i] = byte(paddingLen - 1)
    85  	}
    86  	payload = append(payload, padding...)
    87  
    88  	// Generate IV
    89  	iv := make([]byte, blockSize)
    90  	if _, err := rand.Read(iv); err != nil {
    91  		return nil, err
    92  	}
    93  
    94  	// Set IV + Encrypt + Prepend IV
    95  	c.writeCBC.SetIV(iv)
    96  	c.writeCBC.CryptBlocks(payload, payload)
    97  	payload = append(iv, payload...)
    98  
    99  	// Prepend unencrypte header with encrypted payload
   100  	raw = append(raw, payload...)
   101  
   102  	// Update recordLayer size to include IV+MAC+Padding
   103  	binary.BigEndian.PutUint16(raw[recordlayer.HeaderSize-2:], uint16(len(raw)-recordlayer.HeaderSize))
   104  
   105  	return raw, nil
   106  }
   107  
   108  // Decrypt decrypts a DTLS RecordLayer message
   109  func (c *CBC) Decrypt(in []byte) ([]byte, error) {
   110  	body := in[recordlayer.HeaderSize:]
   111  	blockSize := c.readCBC.BlockSize()
   112  	mac := c.h()
   113  
   114  	var h recordlayer.Header
   115  	err := h.Unmarshal(in)
   116  	switch {
   117  	case err != nil:
   118  		return nil, err
   119  	case h.ContentType == protocol.ContentTypeChangeCipherSpec:
   120  		// Nothing to encrypt with ChangeCipherSpec
   121  		return in, nil
   122  	case len(body)%blockSize != 0 || len(body) < blockSize+util.Max(mac.Size()+1, blockSize):
   123  		return nil, errNotEnoughRoomForNonce
   124  	}
   125  
   126  	// Set + remove per record IV
   127  	c.readCBC.SetIV(body[:blockSize])
   128  	body = body[blockSize:]
   129  
   130  	// Decrypt
   131  	c.readCBC.CryptBlocks(body, body)
   132  
   133  	// Padding+MAC needs to be checked in constant time
   134  	// Otherwise we reveal information about the level of correctness
   135  	paddingLen, paddingGood := examinePadding(body)
   136  	if paddingGood != 255 {
   137  		return nil, errInvalidMAC
   138  	}
   139  
   140  	macSize := mac.Size()
   141  	if len(body) < macSize {
   142  		return nil, errInvalidMAC
   143  	}
   144  
   145  	dataEnd := len(body) - macSize - paddingLen
   146  
   147  	expectedMAC := body[dataEnd : dataEnd+macSize]
   148  	actualMAC, err := c.hmac(h.Epoch, h.SequenceNumber, h.ContentType, h.Version, body[:dataEnd], c.readMac, c.h)
   149  
   150  	// Compute Local MAC and compare
   151  	if err != nil || !hmac.Equal(actualMAC, expectedMAC) {
   152  		return nil, errInvalidMAC
   153  	}
   154  
   155  	return append(in[:recordlayer.HeaderSize], body[:dataEnd]...), nil
   156  }
   157  
   158  func (c *CBC) hmac(epoch uint16, sequenceNumber uint64, contentType protocol.ContentType, protocolVersion protocol.Version, payload []byte, key []byte, hf func() hash.Hash) ([]byte, error) {
   159  	h := hmac.New(hf, key)
   160  
   161  	msg := make([]byte, 13)
   162  
   163  	binary.BigEndian.PutUint16(msg, epoch)
   164  	util.PutBigEndianUint48(msg[2:], sequenceNumber)
   165  	msg[8] = byte(contentType)
   166  	msg[9] = protocolVersion.Major
   167  	msg[10] = protocolVersion.Minor
   168  	binary.BigEndian.PutUint16(msg[11:], uint16(len(payload)))
   169  
   170  	if _, err := h.Write(msg); err != nil {
   171  		return nil, err
   172  	} else if _, err := h.Write(payload); err != nil {
   173  		return nil, err
   174  	}
   175  
   176  	return h.Sum(nil), nil
   177  }