github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/cloud/ec2util/ec2util.go (about)

     1  // Copyright 2018 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     4  
// Package ec2util contains a few helper functions related to EC2 (validating
// an Instance Identity Document, extracting an Amazon Resource Name, etc.).
     7  //
     8  // Some of the code from this file comes from a Hashicorp Vault
     9  // (covered by Mozilla Public License, version 2.0) file:
    10  // https://github.com/hashicorp/vault/blob/2500218a9cbd833057145aefec1802e6dd5ec8cc/builtin/credential/aws-ec2/path_config_certificate.go
    11  
    12  package ec2util
    13  
    14  import (
    15  	"bytes"
    16  	"crypto/x509"
    17  	"encoding/json"
    18  	"encoding/pem"
    19  	"fmt"
    20  	"regexp"
    21  	"strings"
    22  	"time"
    23  
    24  	"github.com/aws/aws-sdk-go/service/ec2"
    25  	"v.io/x/lib/vlog"
    26  	"go.mozilla.org/pkcs7"
    27  )
    28  
// IdentityDocument holds the subset of an EC2 instance identity document's
// fields that this package uses. ParseAndVerifyIdentityDocument decodes it
// from the signed JSON content of the document; fields not listed here are
// ignored by the decoder.
//
// NOTE(review): encoding/json never treats a struct value such as time.Time
// as "empty", so omitempty has no effect on PendingTime when marshaling. When
// decoding, time.Time expects an RFC 3339 string.
type IdentityDocument struct {
	InstanceID  string    `json:"instanceId,omitempty"`
	AccountID   string    `json:"accountId,omitempty"`
	Region      string    `json:"region,omitempty"`
	PendingTime time.Time `json:"pendingTime,omitempty"`
}
    35  
    36  var (
    37  	// TODO(razvanm): replace this with a proper parsing of ARNs.
    38  	// Potential source of inspiration: https://github.com/gigawattio/awsarn/blob/master/awsarn.go.
    39  	roleRE                = regexp.MustCompile("^arn:aws:iam::([0-9]*):instance-profile/(.*)$")
    40  	awsPublicCertificates []*x509.Certificate
    41  )
    42  
    43  func init() {
    44  	cert, err := DecodePEMAndParseCertificate(awsPublicCertificatePEM)
    45  	if err != nil {
    46  		panic(err)
    47  	}
    48  	awsPublicCertificates = []*x509.Certificate{cert}
    49  }
    50  
    51  func GetInstance(output *ec2.DescribeInstancesOutput) (*ec2.Instance, error) {
    52  	if len(output.Reservations) != 1 {
    53  		return nil, fmt.Errorf("unexpected number of Reservations (want 1): %+v", output)
    54  	}
    55  
    56  	reservation := output.Reservations[0]
    57  	if len(reservation.Instances) != 1 {
    58  		return nil, fmt.Errorf("unexpected number of Instances (want 1): %+v", output)
    59  	}
    60  
    61  	instance := reservation.Instances[0]
    62  	if instance.IamInstanceProfile == nil {
    63  		return nil, fmt.Errorf("non-nil IamInstanceProfile is required: %+v", output)
    64  	}
    65  
    66  	return instance, nil
    67  }
    68  
    69  // GetIamInstanceProfileARN extracts the ARN from the `instance` output of a call to
    70  // DescribeInstances. The ARN is expected to be non-empty.
    71  func GetIamInstanceProfileARN(instance *ec2.Instance) (string, error) {
    72  	if instance == nil {
    73  		return "", fmt.Errorf("non-nil instance is required: %+v", instance)
    74  	}
    75  
    76  	if instance.IamInstanceProfile == nil {
    77  		return "", fmt.Errorf("non-nil IamInstanceProfile is required: %+v", instance)
    78  	}
    79  
    80  	profile := instance.IamInstanceProfile
    81  	if profile.Arn == nil {
    82  		return "", fmt.Errorf("non-nil Arn is required: %+v", instance)
    83  	}
    84  
    85  	if len(*profile.Arn) == 0 {
    86  		return "", fmt.Errorf("non-empty Arn is required: %+v", instance)
    87  	}
    88  
    89  	return *profile.Arn, nil
    90  }
    91  
    92  // GetPublicIPAddress extracts the public IP address from the output of a call
    93  // to DescribeInstances Instance. The response is expected to be non-empty if the
    94  // instance has a public IP and empty ("") if the instance is private.
    95  func GetPublicIPAddress(instance *ec2.Instance) (string, error) {
    96  	if instance == nil {
    97  		return "", fmt.Errorf("non-nil instance is required: %+v", instance)
    98  	}
    99  
   100  	if instance.PublicIpAddress == nil || len(*instance.PublicIpAddress) == 0 {
   101  		return "", nil
   102  	}
   103  
   104  	return *instance.PublicIpAddress, nil
   105  }
   106  
   107  // GetPrivateIPAddress extracts the private IP address from the output of a call
   108  // to DescribeInstances Instance. The response is expected to be the first private IP
   109  // attached to the instance.
   110  // If the instances no attached interfaces, the value is empty ("")
   111  func GetPrivateIPAddress(instance *ec2.Instance) (string, error) {
   112  	if instance == nil {
   113  		return "", fmt.Errorf("non-nil instance is required: %+v", instance)
   114  	}
   115  
   116  	if instance.PrivateIpAddress == nil || len(*instance.PrivateIpAddress) == 0 {
   117  		return "", nil
   118  	}
   119  
   120  	return *instance.PrivateIpAddress, nil
   121  }
   122  
   123  // GetTags returns a map of Key/Value pairs representing the tags
   124  func GetTags(instance *ec2.Instance) ([]*ec2.Tag, error) {
   125  	if instance == nil {
   126  		return nil, fmt.Errorf("non-nil instance is required: %+v", instance)
   127  	}
   128  
   129  	if instance.Tags == nil || len(instance.Tags) == 0 {
   130  		return nil, nil
   131  	}
   132  
   133  	return instance.Tags, nil
   134  }
   135  
   136  // GetInstanceId returns the instanceID from the output of a call
   137  // to DescribeInstances Instance.
   138  func GetInstanceId(instance *ec2.Instance) (string, error) {
   139  	if instance == nil {
   140  		return "", fmt.Errorf("non-nil instance is required: %+v", instance)
   141  	}
   142  
   143  	if instance.InstanceId == nil || len(*instance.InstanceId) == 0 {
   144  		return "", nil
   145  	}
   146  
   147  	return *instance.InstanceId, nil
   148  }
   149  
   150  // ValidateInstance checks if an EC2 instance exists and it has the expected
   151  // IP. It returns the name of the instance profile (the IAM role).
   152  //
   153  // Note that this validation will not work for NATed VMs.
   154  func ValidateInstance(output *ec2.DescribeInstancesOutput, doc IdentityDocument, remoteAddr string) (role string, err error) {
   155  	vlog.Infof("reservations:\n%+v", output.Reservations)
   156  
   157  	instance, err := GetInstance(output)
   158  	if err != nil {
   159  		return "", err
   160  	}
   161  
   162  	publicIP, err := GetPublicIPAddress(instance)
   163  	if err != nil {
   164  		return "", err
   165  	}
   166  
   167  	// Instances that do not have a public IP should be able to authenticate
   168  	// with ticket server. Connections from such instances are routed through a
   169  	// NAT gateway with an Elastic IP. The following check which ensures the
   170  	// remoteAddr from which the connection originates is same as the public IP
   171  	// of the instance is skipped for private instances.
   172  	if remoteAddr != "" && publicIP != "" {
   173  		if !strings.HasPrefix(remoteAddr, publicIP+":") {
   174  			return "", fmt.Errorf("mismatch between the real peer address (%s) and public IP of the instance (%s)", remoteAddr, publicIP)
   175  		}
   176  	}
   177  
   178  	arn, err := GetIamInstanceProfileARN(instance)
   179  	if err != nil {
   180  		return "", err
   181  	}
   182  	m := roleRE.FindStringSubmatch(arn)
   183  	if len(m) != 3 {
   184  		return "", fmt.Errorf("unexpected ARN format for %q", arn)
   185  	}
   186  	vlog.Infof("IAM role: %q parsed: %q", arn, m)
   187  
   188  	accountID, role := m[1], m[2]
   189  
   190  	if accountID != doc.AccountID {
   191  		return "", fmt.Errorf("mismatch between account ID in Identity Doc (%q) and role (%q): %q", doc.AccountID, accountID, arn)
   192  	}
   193  	return role, nil
   194  }
   195  
   196  // ParseAndVerifyIdentityDocument parses and checks and identity document in
   197  // PKCS#7 format. Only some relevant fields are returned.
   198  func ParseAndVerifyIdentityDocument(pkcs7b64 string) (*IdentityDocument, string, error) {
   199  	// Insert the header and footer for the signature to be able to pem decode it.
   200  	s := fmt.Sprintf("-----BEGIN PKCS7-----\n%s\n-----END PKCS7-----", pkcs7b64)
   201  
   202  	// Decode the PEM encoded signature.
   203  	pkcs7BER, pkcs7Rest := pem.Decode([]byte(s))
   204  	if len(pkcs7Rest) != 0 {
   205  		return nil, "", fmt.Errorf("failed to decode the PKCS#7 signature")
   206  	}
   207  
   208  	// Parse the signature from asn1 format into a struct.
   209  	pkcs7Data, err := pkcs7.Parse(pkcs7BER.Bytes)
   210  	if err != nil {
   211  		return nil, "", fmt.Errorf("failed to parse the BER encoded PKCS#7 signature: %s\n", err)
   212  	}
   213  
   214  	pkcs7Data.Certificates = awsPublicCertificates
   215  
   216  	// Verify extracts the authenticated attributes in the PKCS#7
   217  	// signature, and verifies the authenticity of the content using
   218  	// 'dsa.PublicKey' embedded in the public certificate.
   219  	if err := pkcs7Data.Verify(); err != nil {
   220  		return nil, "", fmt.Errorf("failed to verify the signature: %v", err)
   221  	}
   222  
   223  	// Check if the signature has content inside of it.
   224  	if len(pkcs7Data.Content) == 0 {
   225  		return nil, "", fmt.Errorf("instance identity document could not be found in the signature")
   226  	}
   227  
   228  	var identityDoc IdentityDocument
   229  	content := string(pkcs7Data.Content)
   230  	vlog.VI(1).Infof("%v", content)
   231  	decoder := json.NewDecoder(bytes.NewReader(pkcs7Data.Content))
   232  	decoder.UseNumber()
   233  	if err := decoder.Decode(&identityDoc); err != nil {
   234  		return nil, "", err
   235  	}
   236  
   237  	return &identityDoc, content, nil
   238  }
   239  
   240  // DecodePEMAndParseCertificate decodes the PEM encoded certificate and
   241  // parses it into a x509 cert.
   242  func DecodePEMAndParseCertificate(certificate string) (*x509.Certificate, error) {
   243  	// Decode the PEM block and error out if a block is not detected in
   244  	// the first attempt.
   245  	decodedPublicCert, rest := pem.Decode([]byte(certificate))
   246  	if len(rest) != 0 {
   247  		return nil, fmt.Errorf("invalid certificate; should be one PEM block only")
   248  	}
   249  
   250  	// Check if the certificate can be parsed.
   251  	publicCert, err := x509.ParseCertificate(decodedPublicCert.Bytes)
   252  	if err != nil {
   253  		return nil, err
   254  	}
   255  	if publicCert == nil {
   256  		return nil, fmt.Errorf("invalid certificate; failed to parse certificate")
   257  	}
   258  	return publicCert, nil
   259  }