github.com/nginxinc/kubernetes-ingress@v1.12.5/cmd/nginx-ingress/aws.go (about)

     1  // +build aws
     2  
     3  package main
     4  
     5  import (
     6  	"context"
     7  	"crypto/rand"
     8  	"encoding/base64"
     9  	"errors"
    10  	"fmt"
    11  	"math/big"
    12  	"time"
    13  
    14  	"github.com/aws/aws-sdk-go-v2/config"
    15  	"github.com/aws/aws-sdk-go-v2/service/marketplacemetering"
    16  	"github.com/aws/aws-sdk-go-v2/service/marketplacemetering/types"
    17  
    18  	"github.com/dgrijalva/jwt-go/v4"
    19  )
    20  
    21  var (
    22  	productCode   string
    23  	pubKeyVersion int32 = 1
    24  	pubKeyString  string
    25  	nonce         string
    26  )
    27  
    28  func init() {
    29  	startupCheckFn = checkAWSEntitlement
    30  }
    31  
    32  func checkAWSEntitlement() error {
    33  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    34  	defer cancel()
    35  
    36  	nonce, err := generateRandomString(255)
    37  	if err != nil {
    38  		return err
    39  	}
    40  
    41  	cfg, err := config.LoadDefaultConfig(ctx)
    42  	if err != nil {
    43  		return fmt.Errorf("error loading AWS configuration: %w", err)
    44  	}
    45  
    46  	mpm := marketplacemetering.New(marketplacemetering.Options{Region: cfg.Region, Credentials: cfg.Credentials})
    47  
    48  	out, err := mpm.RegisterUsage(ctx, &marketplacemetering.RegisterUsageInput{ProductCode: &productCode, PublicKeyVersion: &pubKeyVersion, Nonce: &nonce})
    49  	if err != nil {
    50  		var notEnt *types.CustomerNotEntitledException
    51  		if errors.As(err, &notEnt) {
    52  			return fmt.Errorf("user not entitled, code: %v, message: %v, fault: %v", notEnt.ErrorCode(), notEnt.ErrorMessage(), notEnt.ErrorFault().String())
    53  		}
    54  		return err
    55  	}
    56  
    57  	pk, err := base64.StdEncoding.DecodeString(pubKeyString)
    58  	if err != nil {
    59  		return fmt.Errorf("error decoding Public Key string: %w", err)
    60  	}
    61  	pubKey, err := jwt.ParseRSAPublicKeyFromPEM(pk)
    62  	if err != nil {
    63  		return fmt.Errorf("error parsing Public Key: %w", err)
    64  	}
    65  
    66  	token, err := jwt.ParseWithClaims(*out.Signature, &claims{}, jwt.KnownKeyfunc(jwt.SigningMethodPS256, pubKey))
    67  	if err != nil {
    68  		return fmt.Errorf("error parsing the JWT token: %w", err)
    69  	}
    70  
    71  	if claims, ok := token.Claims.(*claims); ok && token.Valid {
    72  		if claims.ProductCode != productCode || claims.PublicKeyVersion != pubKeyVersion || claims.Nonce != nonce {
    73  			return fmt.Errorf("the claims in the JWT token don't match the request")
    74  		}
    75  	} else {
    76  		return fmt.Errorf("something is wrong with the JWT token")
    77  	}
    78  
    79  	return nil
    80  }
    81  
    82  type claims struct {
    83  	ProductCode      string    `json:"productCode,omitempty"`
    84  	PublicKeyVersion int32     `json:"publicKeyVersion,omitempty"`
    85  	IssuedAt         *jwt.Time `json:"iat,omitempty"`
    86  	Nonce            string    `json:"nonce,omitempty"`
    87  }
    88  
    89  func (c claims) Valid(h *jwt.ValidationHelper) error {
    90  	if c.Nonce == "" {
    91  		return &jwt.InvalidClaimsError{Message: "the JWT token doesn't include the Nonce"}
    92  	}
    93  	if c.ProductCode == "" {
    94  		return &jwt.InvalidClaimsError{Message: "the JWT token doesn't include the ProductCode"}
    95  	}
    96  	if c.PublicKeyVersion == 0 {
    97  		return &jwt.InvalidClaimsError{Message: "the JWT token doesn't include the PublicKeyVersion"}
    98  	}
    99  	if err := h.ValidateNotBefore(c.IssuedAt); err != nil {
   100  		return err
   101  	}
   102  
   103  	return nil
   104  }
   105  
   106  func generateRandomString(n int) (string, error) {
   107  	const letters = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz-"
   108  	ret := make([]byte, n)
   109  	for i := 0; i < n; i++ {
   110  		num, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
   111  		if err != nil {
   112  			return "", err
   113  		}
   114  		ret[i] = letters[num.Int64()]
   115  	}
   116  
   117  	return string(ret), nil
   118  }