github.com/pion/dtls/v2@v2.2.12/pkg/crypto/ciphersuite/gcm.go (about)

     1  // SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
     2  // SPDX-License-Identifier: MIT
     3  
     4  package ciphersuite
     5  
     6  import (
     7  	"crypto/aes"
     8  	"crypto/cipher"
     9  	"crypto/rand"
    10  	"encoding/binary"
    11  	"fmt"
    12  
    13  	"github.com/pion/dtls/v2/pkg/protocol"
    14  	"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
    15  )
    16  
    17  const (
    18  	gcmTagLength   = 16
    19  	gcmNonceLength = 12
    20  )
    21  
    22  // GCM Provides an API to Encrypt/Decrypt DTLS 1.2 Packets
    23  type GCM struct {
    24  	localGCM, remoteGCM         cipher.AEAD
    25  	localWriteIV, remoteWriteIV []byte
    26  }
    27  
    28  // NewGCM creates a DTLS GCM Cipher
    29  func NewGCM(localKey, localWriteIV, remoteKey, remoteWriteIV []byte) (*GCM, error) {
    30  	localBlock, err := aes.NewCipher(localKey)
    31  	if err != nil {
    32  		return nil, err
    33  	}
    34  	localGCM, err := cipher.NewGCM(localBlock)
    35  	if err != nil {
    36  		return nil, err
    37  	}
    38  
    39  	remoteBlock, err := aes.NewCipher(remoteKey)
    40  	if err != nil {
    41  		return nil, err
    42  	}
    43  	remoteGCM, err := cipher.NewGCM(remoteBlock)
    44  	if err != nil {
    45  		return nil, err
    46  	}
    47  
    48  	return &GCM{
    49  		localGCM:      localGCM,
    50  		localWriteIV:  localWriteIV,
    51  		remoteGCM:     remoteGCM,
    52  		remoteWriteIV: remoteWriteIV,
    53  	}, nil
    54  }
    55  
    56  // Encrypt encrypt a DTLS RecordLayer message
    57  func (g *GCM) Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) {
    58  	payload := raw[recordlayer.HeaderSize:]
    59  	raw = raw[:recordlayer.HeaderSize]
    60  
    61  	nonce := make([]byte, gcmNonceLength)
    62  	copy(nonce, g.localWriteIV[:4])
    63  	if _, err := rand.Read(nonce[4:]); err != nil {
    64  		return nil, err
    65  	}
    66  
    67  	additionalData := generateAEADAdditionalData(&pkt.Header, len(payload))
    68  	encryptedPayload := g.localGCM.Seal(nil, nonce, payload, additionalData)
    69  	r := make([]byte, len(raw)+len(nonce[4:])+len(encryptedPayload))
    70  	copy(r, raw)
    71  	copy(r[len(raw):], nonce[4:])
    72  	copy(r[len(raw)+len(nonce[4:]):], encryptedPayload)
    73  
    74  	// Update recordLayer size to include explicit nonce
    75  	binary.BigEndian.PutUint16(r[recordlayer.HeaderSize-2:], uint16(len(r)-recordlayer.HeaderSize))
    76  	return r, nil
    77  }
    78  
    79  // Decrypt decrypts a DTLS RecordLayer message
    80  func (g *GCM) Decrypt(in []byte) ([]byte, error) {
    81  	var h recordlayer.Header
    82  	err := h.Unmarshal(in)
    83  	switch {
    84  	case err != nil:
    85  		return nil, err
    86  	case h.ContentType == protocol.ContentTypeChangeCipherSpec:
    87  		// Nothing to encrypt with ChangeCipherSpec
    88  		return in, nil
    89  	case len(in) <= (8 + recordlayer.HeaderSize):
    90  		return nil, errNotEnoughRoomForNonce
    91  	}
    92  
    93  	nonce := make([]byte, 0, gcmNonceLength)
    94  	nonce = append(append(nonce, g.remoteWriteIV[:4]...), in[recordlayer.HeaderSize:recordlayer.HeaderSize+8]...)
    95  	out := in[recordlayer.HeaderSize+8:]
    96  
    97  	additionalData := generateAEADAdditionalData(&h, len(out)-gcmTagLength)
    98  	out, err = g.remoteGCM.Open(out[:0], nonce, out, additionalData)
    99  	if err != nil {
   100  		return nil, fmt.Errorf("%w: %v", errDecryptPacket, err) //nolint:errorlint
   101  	}
   102  	return append(in[:recordlayer.HeaderSize], out...), nil
   103  }