github.com/aavshr/aws-sdk-go@v1.41.3/service/s3/s3crypto/integration/main_test.go (about)

     1  //go:build integration && go1.14
     2  // +build integration,go1.14
     3  
     4  package integration
     5  
     6  import (
     7  	"bytes"
     8  	"crypto/rand"
     9  	"flag"
    10  	"io"
    11  	"io/ioutil"
    12  	"log"
    13  	"os"
    14  	"testing"
    15  
    16  	"github.com/aavshr/aws-sdk-go/aws"
    17  	"github.com/aavshr/aws-sdk-go/aws/session"
    18  	"github.com/aavshr/aws-sdk-go/awstesting/integration"
    19  	"github.com/aavshr/aws-sdk-go/service/kms"
    20  	"github.com/aavshr/aws-sdk-go/service/kms/kmsiface"
    21  	"github.com/aavshr/aws-sdk-go/service/s3"
    22  	"github.com/aavshr/aws-sdk-go/service/s3/s3crypto"
    23  	"github.com/aavshr/aws-sdk-go/service/s3/s3iface"
    24  )
    25  
    26  var config = &struct {
    27  	Enabled  bool
    28  	Region   string
    29  	KMSKeyID string
    30  	Bucket   string
    31  	Session  *session.Session
    32  	Clients  struct {
    33  		KMS kmsiface.KMSAPI
    34  		S3  s3iface.S3API
    35  	}
    36  }{}
    37  
    38  func init() {
    39  	flag.BoolVar(&config.Enabled, "enable", false, "enable integration testing")
    40  	flag.StringVar(&config.Region, "region", "us-west-2", "integration test region")
    41  	flag.StringVar(&config.KMSKeyID, "kms-key-id", "", "KMS CMK Key ID")
    42  	flag.StringVar(&config.Bucket, "bucket", "", "S3 Bucket Name")
    43  }
    44  
    45  func TestMain(m *testing.M) {
    46  	flag.Parse()
    47  	if !config.Enabled {
    48  		log.Println("skipping s3crypto integration tests")
    49  		os.Exit(0)
    50  	}
    51  
    52  	if len(config.Bucket) == 0 {
    53  		log.Fatal("bucket name must be provided")
    54  	}
    55  
    56  	if len(config.KMSKeyID) == 0 {
    57  		log.Fatal("kms cmk key id must be provided")
    58  	}
    59  
    60  	config.Session = session.Must(session.NewSession(&aws.Config{Region: &config.Region}))
    61  
    62  	config.Clients.KMS = kms.New(config.Session)
    63  	config.Clients.S3 = s3.New(config.Session)
    64  
    65  	m.Run()
    66  }
    67  
    68  func TestEncryptionV1_WithV2Interop(t *testing.T) {
    69  	kmsKeyGenerator := s3crypto.NewKMSKeyGenerator(config.Clients.KMS, config.KMSKeyID)
    70  
    71  	// 1020 is chosen here as it is not cleanly divisible by the AES-256 block size
    72  	testData := make([]byte, 1020)
    73  	_, err := rand.Read(testData)
    74  	if err != nil {
    75  		t.Fatalf("failed to read random data: %v", err)
    76  	}
    77  
    78  	v1DC := s3crypto.NewDecryptionClient(config.Session, func(client *s3crypto.DecryptionClient) {
    79  		client.S3Client = config.Clients.S3
    80  	})
    81  
    82  	cr := s3crypto.NewCryptoRegistry()
    83  	if err = s3crypto.RegisterKMSWrapWithAnyCMK(cr, config.Clients.KMS); err != nil {
    84  		t.Fatalf("expected no error, got %v", err)
    85  	}
    86  	if err = s3crypto.RegisterKMSContextWrapWithAnyCMK(cr, config.Clients.KMS); err != nil {
    87  		t.Fatalf("expected no error, got %v", err)
    88  	}
    89  	if err = s3crypto.RegisterAESGCMContentCipher(cr); err != nil {
    90  		t.Fatalf("expected no error, got %v", err)
    91  	}
    92  	if err = s3crypto.RegisterAESCBCContentCipher(cr, s3crypto.AESCBCPadder); err != nil {
    93  		t.Fatalf("expected no error, got %v", err)
    94  	}
    95  
    96  	v2DC, err := s3crypto.NewDecryptionClientV2(config.Session, cr, func(options *s3crypto.DecryptionClientOptions) {
    97  		options.S3Client = config.Clients.S3
    98  	})
    99  	if err != nil {
   100  		t.Fatalf("expected no error, got %v", err)
   101  	}
   102  
   103  	cases := map[string]s3crypto.ContentCipherBuilder{
   104  		"AES/GCM/NoPadding":    s3crypto.AESGCMContentCipherBuilder(kmsKeyGenerator),
   105  		"AES/CBC/PKCS5Padding": s3crypto.AESCBCContentCipherBuilder(kmsKeyGenerator, s3crypto.AESCBCPadder),
   106  	}
   107  
   108  	for name, ccb := range cases {
   109  		t.Run(name, func(t *testing.T) {
   110  			ec := s3crypto.NewEncryptionClient(config.Session, ccb, func(client *s3crypto.EncryptionClient) {
   111  				client.S3Client = config.Clients.S3
   112  			})
   113  			id := integration.UniqueID()
   114  			// PutObject with V1 Client
   115  			putObject(t, ec, id, bytes.NewReader(testData))
   116  			// Verify V1 Decryption Client
   117  			getObjectAndCompare(t, v1DC, id, testData)
   118  			// Verify V2 Decryption Client
   119  			getObjectAndCompare(t, v2DC, id, testData)
   120  		})
   121  	}
   122  }
   123  
   124  func TestEncryptionV2_WithV1Interop(t *testing.T) {
   125  	kmsKeyGenerator := s3crypto.NewKMSContextKeyGenerator(config.Clients.KMS, config.KMSKeyID, s3crypto.MaterialDescription{})
   126  	gcmContentCipherBuilder := s3crypto.AESGCMContentCipherBuilderV2(kmsKeyGenerator)
   127  
   128  	ec, err := s3crypto.NewEncryptionClientV2(config.Session, gcmContentCipherBuilder, func(options *s3crypto.EncryptionClientOptions) {
   129  		options.S3Client = config.Clients.S3
   130  	})
   131  	if err != nil {
   132  		t.Fatalf("failed to construct encryption decryptionClient: %v", err)
   133  	}
   134  
   135  	decryptionClient := s3crypto.NewDecryptionClient(config.Session, func(client *s3crypto.DecryptionClient) {
   136  		client.S3Client = config.Clients.S3
   137  	})
   138  
   139  	cr := s3crypto.NewCryptoRegistry()
   140  	if err = s3crypto.RegisterKMSContextWrapWithAnyCMK(cr, config.Clients.KMS); err != nil {
   141  		t.Fatalf("expected no error, got %v", err)
   142  	}
   143  	if err = s3crypto.RegisterAESGCMContentCipher(cr); err != nil {
   144  		t.Fatalf("expected no error, got %v", err)
   145  	}
   146  
   147  	decryptionClientV2, err := s3crypto.NewDecryptionClientV2(config.Session, cr, func(options *s3crypto.DecryptionClientOptions) {
   148  		options.S3Client = config.Clients.S3
   149  	})
   150  	if err != nil {
   151  		t.Fatalf("expected no error, got %v", err)
   152  	}
   153  
   154  	// 1020 is chosen here as it is not cleanly divisible by the AES-256 block size
   155  	testData := make([]byte, 1020)
   156  	_, err = rand.Read(testData)
   157  	if err != nil {
   158  		t.Fatalf("failed to read random data: %v", err)
   159  	}
   160  
   161  	keyId := integration.UniqueID()
   162  
   163  	// Upload V2 Objects with Encryption Client
   164  	putObject(t, ec, keyId, bytes.NewReader(testData))
   165  
   166  	// Verify V2 Object with V2 Decryption Client
   167  	getObjectAndCompare(t, decryptionClientV2, keyId, testData)
   168  
   169  	// Verify V2 Object with V1 Decryption Client
   170  	getObjectAndCompare(t, decryptionClient, keyId, testData)
   171  }
   172  
   173  type Encryptor interface {
   174  	PutObject(input *s3.PutObjectInput) (*s3.PutObjectOutput, error)
   175  }
   176  
   177  func putObject(t *testing.T, client Encryptor, key string, reader io.ReadSeeker) {
   178  	t.Helper()
   179  	_, err := client.PutObject(&s3.PutObjectInput{
   180  		Bucket: &config.Bucket,
   181  		Key:    &key,
   182  		Body:   reader,
   183  	})
   184  	if err != nil {
   185  		t.Fatalf("failed to upload object: %v", err)
   186  	}
   187  	t.Cleanup(doKeyCleanup(key))
   188  }
   189  
   190  type Decryptor interface {
   191  	GetObject(input *s3.GetObjectInput) (*s3.GetObjectOutput, error)
   192  }
   193  
   194  func getObjectAndCompare(t *testing.T, client Decryptor, key string, expected []byte) {
   195  	t.Helper()
   196  	output, err := client.GetObject(&s3.GetObjectInput{
   197  		Bucket: &config.Bucket,
   198  		Key:    &key,
   199  	})
   200  	if err != nil {
   201  		t.Fatalf("failed to get object: %v", err)
   202  	}
   203  
   204  	actual, err := ioutil.ReadAll(output.Body)
   205  	if err != nil {
   206  		t.Fatalf("failed to read body response: %v", err)
   207  	}
   208  
   209  	if bytes.Compare(expected, actual) != 0 {
   210  		t.Errorf("expected bytes did not match actual")
   211  	}
   212  }
   213  
   214  func doKeyCleanup(key string) func() {
   215  	return func() {
   216  		_, err := config.Clients.S3.DeleteObject(&s3.DeleteObjectInput{
   217  			Bucket: &config.Bucket,
   218  			Key:    &key,
   219  		})
   220  		if err != nil {
   221  			log.Printf("failed to delete %s: %v", key, err)
   222  		}
   223  	}
   224  }