github.com/cozy/cozy-stack@v0.0.0-20240603063001-31110fa4cae1/model/oauth/android_play_integrity.go (about)

     1  package oauth
     2  
     3  import (
     4  	"crypto/aes"
     5  	"crypto/cipher"
     6  	"crypto/x509"
     7  	"encoding/base64"
     8  	"errors"
     9  	"fmt"
    10  	"strings"
    11  
    12  	"github.com/cozy/cozy-stack/model/instance"
    13  	"github.com/cozy/cozy-stack/pkg/config/config"
    14  	"github.com/cozy/cozy-stack/pkg/crypto"
    15  	"github.com/cozy/cozy-stack/pkg/logger"
    16  	jwt "github.com/golang-jwt/jwt/v5"
    17  )
    18  
    19  // checkPlayIntegrityAttestation will check an attestation made by the Play
    20  // Integrity API.
    21  // https://developer.android.com/google/play/integrity
    22  func (c *Client) checkPlayIntegrityAttestation(inst *instance.Instance, req AttestationRequest) error {
    23  	store := GetStore()
    24  	if ok := store.CheckAndClearChallenge(inst, c.ID(), req.Challenge); !ok {
    25  		return errors.New("invalid challenge")
    26  	}
    27  
    28  	token, err := decryptPlayIntegrityToken(req)
    29  	if err != nil {
    30  		inst.Logger().Debugf("cannot decrypt the play integrity token: %s", err)
    31  		return fmt.Errorf("cannot parse attestation: %s", err)
    32  	}
    33  	claims, ok := token.Claims.(jwt.MapClaims)
    34  	if !ok {
    35  		return errors.New("invalid claims type")
    36  	}
    37  	inst.Logger().Debugf("checkPlayIntegrityAttestation claims = %#v", claims)
    38  
    39  	nonce, ok := getFromClaims(claims, "requestDetails.nonce").(string)
    40  	if !ok || len(nonce) == 0 {
    41  		return errors.New("missing nonce")
    42  	}
    43  	if req.Challenge != nonce {
    44  		return errors.New("invalid nonce")
    45  	}
    46  
    47  	if err := checkPlayIntegrityPackageName(claims); err != nil {
    48  		return err
    49  	}
    50  	if err := checkPlayIntegrityCertificateDigest(claims); err != nil {
    51  		return err
    52  	}
    53  	return nil
    54  }
    55  
    56  // CheckPlayIntegrityAttestationForTestingPurpose is only used for testing
    57  // purpose. It is a simplified version of checkPlayIntegrityAttestation. In
    58  // particular, it doesn't return an error for invalid package name with a test
    59  // attestation.
    60  func CheckPlayIntegrityAttestationForTestingPurpose(req AttestationRequest) error {
    61  	token, err := decryptPlayIntegrityToken(req)
    62  	if err != nil {
    63  		return fmt.Errorf("cannot parse attestation: %s", err)
    64  	}
    65  	claims, ok := token.Claims.(jwt.MapClaims)
    66  	if !ok {
    67  		return errors.New("invalid claims type")
    68  	}
    69  
    70  	nonce, ok := getFromClaims(claims, "requestDetails.nonce").(string)
    71  	if !ok || len(nonce) == 0 {
    72  		return errors.New("missing nonce")
    73  	}
    74  	if req.Challenge != nonce {
    75  		return errors.New("invalid nonce")
    76  	}
    77  	return nil
    78  }
    79  
    80  func decryptPlayIntegrityToken(req AttestationRequest) (*jwt.Token, error) {
    81  	lastErr := errors.New("no decryption key")
    82  	for _, key := range config.GetConfig().Flagship.PlayIntegrityDecryptionKeys {
    83  		decrypted, err := decryptPlayIntegrityJWE(req.Attestation, key)
    84  		if err == nil {
    85  			return parsePlayIntegrityToken(decrypted)
    86  		}
    87  		lastErr = err
    88  	}
    89  	return nil, lastErr
    90  }
    91  
    92  func decryptPlayIntegrityJWE(attestation string, rawKey string) ([]byte, error) {
    93  	parts := strings.Split(attestation, ".")
    94  	if len(parts) != 5 {
    95  		return nil, errors.New("invalid integrity token")
    96  	}
    97  	header := []byte(parts[0])
    98  	encryptedKey, err := base64.RawURLEncoding.DecodeString(parts[1])
    99  	// AES Key wrap works with 64 bits block, and the wrapped version has n+1
   100  	// blocks (for integrity check). The kek key is 256bits, thus the
   101  	// encryptedKey is 320bits => 40bytes.
   102  	if err != nil || len(encryptedKey) != 40 {
   103  		return nil, fmt.Errorf("invalid encrypted key: %w", err)
   104  	}
   105  	initVector, err := base64.RawURLEncoding.DecodeString(parts[2])
   106  	if err != nil {
   107  		return nil, fmt.Errorf("invalid initialization vector: %w", err)
   108  	}
   109  	cipherText, err := base64.RawURLEncoding.DecodeString(parts[3])
   110  	if err != nil {
   111  		return nil, fmt.Errorf("invalid ciphertext: %w", err)
   112  	}
   113  	authTag, err := base64.RawURLEncoding.DecodeString(parts[4])
   114  	if err != nil || len(authTag) != 16 { // GCM uses 128bits => 16bytes
   115  		return nil, fmt.Errorf("invalid authentication tag: %w", err)
   116  	}
   117  
   118  	kek, err := base64.StdEncoding.DecodeString(rawKey) // kek means Key-encryption key, cf RFC-3394
   119  	if err != nil {
   120  		return nil, fmt.Errorf("invalid decryption key: %w", err)
   121  	}
   122  	block, err := aes.NewCipher(kek)
   123  	if err != nil {
   124  		return nil, fmt.Errorf("invalid decryption key: %w", err)
   125  	}
   126  	contentKey, err := crypto.UnwrapA256KW(block, encryptedKey)
   127  	if err != nil {
   128  		return nil, fmt.Errorf("cannot unwrap the key: %w", err)
   129  	}
   130  	if len(contentKey) != 32 { // AES256 means 256bits => 32bytes
   131  		return nil, fmt.Errorf("invalid encrypted key: %w", err)
   132  	}
   133  
   134  	cek, err := aes.NewCipher(contentKey)
   135  	if err != nil {
   136  		return nil, fmt.Errorf("cannot load the cek: %w", err)
   137  	}
   138  	aesgcm, err := cipher.NewGCM(cek)
   139  	if err != nil {
   140  		return nil, fmt.Errorf("cannot initialize AES-GCM: %w", err)
   141  	}
   142  	if len(initVector) != aesgcm.NonceSize() {
   143  		return nil, fmt.Errorf("invalid initialization vector: %w", err)
   144  	}
   145  	decrypted, err := aesgcm.Open(nil, initVector, append(cipherText, authTag...), header)
   146  	if err != nil {
   147  		return nil, fmt.Errorf("cannot decrypt: %w", err)
   148  	}
   149  
   150  	return decrypted, nil
   151  }
   152  
   153  func parsePlayIntegrityToken(decrypted []byte) (*jwt.Token, error) {
   154  	lastErr := errors.New("no verification key")
   155  	for _, key := range config.GetConfig().Flagship.PlayIntegrityVerificationKeys {
   156  		token, err := parsePlayIntegrityJWT(decrypted, key)
   157  		if err == nil {
   158  			return token, err
   159  		}
   160  		lastErr = err
   161  	}
   162  	return nil, lastErr
   163  }
   164  
   165  func parsePlayIntegrityJWT(decrypted []byte, rawKey string) (*jwt.Token, error) {
   166  	return jwt.Parse(string(decrypted), func(token *jwt.Token) (interface{}, error) {
   167  		if _, ok := token.Method.(*jwt.SigningMethodECDSA); !ok {
   168  			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
   169  		}
   170  		key, err := base64.StdEncoding.DecodeString(rawKey)
   171  		if err != nil {
   172  			return nil, fmt.Errorf("invalid verification key: %w", err)
   173  		}
   174  		pubKey, err := x509.ParsePKIXPublicKey(key)
   175  		if err != nil {
   176  			return nil, fmt.Errorf("invalid verification key: %w", err)
   177  		}
   178  		return pubKey, nil
   179  	})
   180  }
   181  
   182  func checkPlayIntegrityPackageName(claims jwt.MapClaims) error {
   183  	packageName, ok := getFromClaims(claims, "appIntegrity.packageName").(string)
   184  	if !ok || len(packageName) == 0 {
   185  		return errors.New("missing appIntegrity.packageName")
   186  	}
   187  	names := config.GetConfig().Flagship.APKPackageNames
   188  	for _, name := range names {
   189  		if name == packageName {
   190  			return nil
   191  		}
   192  	}
   193  	return fmt.Errorf("%s is not the package name of the flagship app", packageName)
   194  }
   195  
   196  func checkPlayIntegrityCertificateDigest(claims jwt.MapClaims) error {
   197  	certDigest, ok := getFromClaims(claims, "appIntegrity.certificateSha256Digest").([]interface{})
   198  	if !ok || len(certDigest) == 0 {
   199  		return errors.New("missing appIntegrity.certificateSha256Digest")
   200  	}
   201  	digests := config.GetConfig().Flagship.APKCertificateDigests
   202  	for _, digest := range digests {
   203  		if digest == certDigest[0] {
   204  			return nil
   205  		}
   206  		// XXX Google was using standard base64 for SafetyNet, but the safe-URL
   207  		// variant for Play Integrity...
   208  		urlSafeDigest := strings.TrimRight(digest, "=")
   209  		urlSafeDigest = strings.ReplaceAll(urlSafeDigest, "+", "-")
   210  		urlSafeDigest = strings.ReplaceAll(urlSafeDigest, "/", "_")
   211  		if urlSafeDigest == certDigest[0] {
   212  			return nil
   213  		}
   214  	}
   215  	logger.WithNamespace("oauth").
   216  		Debugf("Invalid certificate digest, expected %s, got %s", digests[0], certDigest[0])
   217  	return errors.New("invalid certificate digest")
   218  }
   219  
   220  func getFromClaims(claims jwt.MapClaims, key string) interface{} {
   221  	parts := strings.Split(key, ".")
   222  	var obj interface{} = map[string]interface{}(claims)
   223  	for _, part := range parts {
   224  		m, ok := obj.(map[string]interface{})
   225  		if !ok {
   226  			return nil
   227  		}
   228  		obj = m[part]
   229  	}
   230  	return obj
   231  }