github.com/cozy/cozy-stack@v0.0.0-20240603063001-31110fa4cae1/model/oauth/android_safety_net.go (about)

     1  package oauth
     2  
     3  import (
     4  	"crypto/rsa"
     5  	"crypto/x509"
     6  	"encoding/base64"
     7  	"errors"
     8  	"fmt"
     9  
    10  	"github.com/cozy/cozy-stack/model/instance"
    11  	"github.com/cozy/cozy-stack/pkg/config/config"
    12  	"github.com/cozy/cozy-stack/pkg/logger"
    13  	jwt "github.com/golang-jwt/jwt/v5"
    14  )
    15  
    16  // checkSafetyNetAttestation will check an attestation made by the SafetyNet API.
    17  // Cf https://developer.android.com/training/safetynet/attestation#use-response-server
    18  func (c *Client) checkSafetyNetAttestation(inst *instance.Instance, req AttestationRequest) error {
    19  	store := GetStore()
    20  	if ok := store.CheckAndClearChallenge(inst, c.ID(), req.Challenge); !ok {
    21  		return errors.New("invalid challenge")
    22  	}
    23  
    24  	token, err := jwt.Parse(req.Attestation, safetyNetKeyFunc)
    25  	if err != nil {
    26  		return fmt.Errorf("cannot parse attestation: %s", err)
    27  	}
    28  	claims, ok := token.Claims.(jwt.MapClaims)
    29  	if !ok {
    30  		return errors.New("invalid claims type")
    31  	}
    32  	inst.Logger().Debugf("checkSafetyNetAttestation claims = %#v", claims)
    33  
    34  	nonce, ok := claims["nonce"].(string)
    35  	if !ok || len(nonce) == 0 {
    36  		return errors.New("missing nonce")
    37  	}
    38  	if req.Challenge != nonce {
    39  		return errors.New("invalid nonce")
    40  	}
    41  
    42  	if err := checkSafetyNetPackageName(claims); err != nil {
    43  		return err
    44  	}
    45  	if err := checkSafetyNetCertificateDigest(claims); err != nil {
    46  		return err
    47  	}
    48  	return nil
    49  }
    50  
    51  func checkSafetyNetPackageName(claims jwt.MapClaims) error {
    52  	packageName, ok := claims["apkPackageName"].(string)
    53  	if !ok || len(packageName) == 0 {
    54  		return errors.New("missing apkPackageName")
    55  	}
    56  	names := config.GetConfig().Flagship.APKPackageNames
    57  	for _, name := range names {
    58  		if name == packageName {
    59  			return nil
    60  		}
    61  	}
    62  	return fmt.Errorf("%s is not the package name of the flagship app", packageName)
    63  }
    64  
    65  func checkSafetyNetCertificateDigest(claims jwt.MapClaims) error {
    66  	certDigest, ok := claims["apkCertificateDigestSha256"].([]interface{})
    67  	if !ok || len(certDigest) == 0 {
    68  		return errors.New("missing apkCertificateDigestSha256")
    69  	}
    70  	digests := config.GetConfig().Flagship.APKCertificateDigests
    71  	for _, digest := range digests {
    72  		if digest == certDigest[0] {
    73  			return nil
    74  		}
    75  	}
    76  	logger.WithNamespace("oauth").
    77  		Debugf("Invalid certificate digest, expected %s, got %s", digests[0], certDigest[0])
    78  	return errors.New("invalid certificate digest")
    79  }
    80  
    81  func safetyNetKeyFunc(token *jwt.Token) (interface{}, error) {
    82  	if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
    83  		return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
    84  	}
    85  	x5c, ok := token.Header["x5c"].([]interface{})
    86  	if !ok || len(x5c) == 0 {
    87  		return nil, errors.New("missing certification")
    88  	}
    89  
    90  	certs := make([]*x509.Certificate, 0, len(x5c))
    91  	for _, raw := range x5c {
    92  		rawStr, ok := raw.(string)
    93  		if !ok {
    94  			return nil, errors.New("missing certification")
    95  		}
    96  		buf, err := base64.StdEncoding.DecodeString(rawStr)
    97  		if err != nil {
    98  			return nil, fmt.Errorf("error decoding cert as base64: %s", err)
    99  		}
   100  		cert, err := x509.ParseCertificate(buf)
   101  		if err != nil {
   102  			return nil, fmt.Errorf("error parsing cert: %s", err)
   103  		}
   104  		certs = append(certs, cert)
   105  	}
   106  	intermediates := x509.NewCertPool()
   107  	for _, cert := range certs {
   108  		intermediates.AddCert(cert)
   109  	}
   110  
   111  	opts := x509.VerifyOptions{
   112  		DNSName:       "attest.android.com",
   113  		Intermediates: intermediates,
   114  	}
   115  	if _, err := certs[0].Verify(opts); err != nil {
   116  		return nil, err
   117  	}
   118  
   119  	rsaKey, ok := certs[0].PublicKey.(*rsa.PublicKey)
   120  	if !ok {
   121  		return nil, errors.New("invalid certification")
   122  	}
   123  	return rsaKey, nil
   124  }