github.com/pion/dtls/v2@v2.2.12/pkg/crypto/ciphersuite/ccm.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  package ciphersuite
     5  
     6  import (
     7  	"crypto/aes"
     8  	"crypto/rand"
     9  	"encoding/binary"
    10  	"fmt"
    11  
    12  	"github.com/pion/dtls/v2/pkg/crypto/ccm"
    13  	"github.com/pion/dtls/v2/pkg/protocol"
    14  	"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
    15  )
    16  
    17  // CCMTagLen is the length of Authentication Tag
    18  type CCMTagLen int
    19  
    20  // CCM Enums
    21  const (
    22  	CCMTagLength8  CCMTagLen = 8
    23  	CCMTagLength   CCMTagLen = 16
    24  	ccmNonceLength           = 12
    25  )
    26  
    27  // CCM Provides an API to Encrypt/Decrypt DTLS 1.2 Packets
    28  type CCM struct {
    29  	localCCM, remoteCCM         ccm.CCM
    30  	localWriteIV, remoteWriteIV []byte
    31  	tagLen                      CCMTagLen
    32  }
    33  
    34  // NewCCM creates a DTLS GCM Cipher
    35  func NewCCM(tagLen CCMTagLen, localKey, localWriteIV, remoteKey, remoteWriteIV []byte) (*CCM, error) {
    36  	localBlock, err := aes.NewCipher(localKey)
    37  	if err != nil {
    38  		return nil, err
    39  	}
    40  	localCCM, err := ccm.NewCCM(localBlock, int(tagLen), ccmNonceLength)
    41  	if err != nil {
    42  		return nil, err
    43  	}
    44  
    45  	remoteBlock, err := aes.NewCipher(remoteKey)
    46  	if err != nil {
    47  		return nil, err
    48  	}
    49  	remoteCCM, err := ccm.NewCCM(remoteBlock, int(tagLen), ccmNonceLength)
    50  	if err != nil {
    51  		return nil, err
    52  	}
    53  
    54  	return &CCM{
    55  		localCCM:      localCCM,
    56  		localWriteIV:  localWriteIV,
    57  		remoteCCM:     remoteCCM,
    58  		remoteWriteIV: remoteWriteIV,
    59  		tagLen:        tagLen,
    60  	}, nil
    61  }
    62  
    63  // Encrypt encrypt a DTLS RecordLayer message
    64  func (c *CCM) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) {
    65  	payload := raw[recordlayer.HeaderSize:]
    66  	raw = raw[:recordlayer.HeaderSize]
    67  
    68  	nonce := append(append([]byte{}, c.localWriteIV[:4]...), make([]byte, 8)...)
    69  	if _, err := rand.Read(nonce[4:]); err != nil {
    70  		return nil, err
    71  	}
    72  
    73  	additionalData := generateAEADAdditionalData(&pkt.Header, len(payload))
    74  	encryptedPayload := c.localCCM.Seal(nil, nonce, payload, additionalData)
    75  
    76  	encryptedPayload = append(nonce[4:], encryptedPayload...)
    77  	raw = append(raw, encryptedPayload...)
    78  
    79  	// Update recordLayer size to include explicit nonce
    80  	binary.BigEndian.PutUint16(raw[recordlayer.HeaderSize-2:], uint16(len(raw)-recordlayer.HeaderSize))
    81  	return raw, nil
    82  }
    83  
    84  // Decrypt decrypts a DTLS RecordLayer message
    85  func (c *CCM) Decrypt(in []byte) ([]byte, error) {
    86  	var h recordlayer.Header
    87  	err := h.Unmarshal(in)
    88  	switch {
    89  	case err != nil:
    90  		return nil, err
    91  	case h.ContentType == protocol.ContentTypeChangeCipherSpec:
    92  		// Nothing to encrypt with ChangeCipherSpec
    93  		return in, nil
    94  	case len(in) <= (8 + recordlayer.HeaderSize):
    95  		return nil, errNotEnoughRoomForNonce
    96  	}
    97  
    98  	nonce := append(append([]byte{}, c.remoteWriteIV[:4]...), in[recordlayer.HeaderSize:recordlayer.HeaderSize+8]...)
    99  	out := in[recordlayer.HeaderSize+8:]
   100  
   101  	additionalData := generateAEADAdditionalData(&h, len(out)-int(c.tagLen))
   102  	out, err = c.remoteCCM.Open(out[:0], nonce, out, additionalData)
   103  	if err != nil {
   104  		return nil, fmt.Errorf("%w: %v", errDecryptPacket, err) //nolint:errorlint
   105  	}
   106  	return append(in[:recordlayer.HeaderSize], out...), nil
   107  }