github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/vendor_skip/go.mongodb.org/mongo-driver/x/mongo/driver/crypt.go (about)

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package driver
     8  
import (
	"context"
	"crypto/tls"
	"fmt"
	"io"
	"net"
	"strings"
	"time"

	"go.mongodb.org/mongo-driver/bson/bsontype"
	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
	"go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt"
	"go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/options"
)
    22  
    23  const (
    24  	defaultKmsPort    = 443
    25  	defaultKmsTimeout = 10 * time.Second
    26  )
    27  
    28  // CollectionInfoFn is a callback used to retrieve collection information.
    29  type CollectionInfoFn func(ctx context.Context, db string, filter bsoncore.Document) (bsoncore.Document, error)
    30  
    31  // KeyRetrieverFn is a callback used to retrieve keys from the key vault.
    32  type KeyRetrieverFn func(ctx context.Context, filter bsoncore.Document) ([]bsoncore.Document, error)
    33  
    34  // MarkCommandFn is a callback used to add encryption markings to a command.
    35  type MarkCommandFn func(ctx context.Context, db string, cmd bsoncore.Document) (bsoncore.Document, error)
    36  
// CryptOptions specifies options to configure a Crypt instance.
//
// NewCrypt copies every field except BypassQueryAnalysis into the crypt it
// returns. NOTE(review): BypassQueryAnalysis is not read anywhere in this file;
// presumably it is consumed when MongoCrypt itself is configured — verify.
type CryptOptions struct {
	// MongoCrypt is the handle from which encryption, decryption, data-key and
	// rewrap contexts are created. Crypt.Close closes it.
	MongoCrypt           *mongocrypt.MongoCrypt
	// CollInfoFn supplies collection information when the state machine asks for it.
	CollInfoFn           CollectionInfoFn
	// KeyFn fetches key documents from the key vault when the state machine asks for them.
	KeyFn                KeyRetrieverFn
	// MarkFn adds encryption markings to a command when the state machine asks for them.
	MarkFn               MarkCommandFn
	// TLSConfig maps a KMS provider name to the TLS configuration used when
	// connecting to that provider. Providers without an entry get a default
	// configuration requiring TLS 1.2 or newer.
	TLSConfig            map[string]*tls.Config
	// BypassAutoEncryption makes Encrypt return commands unmodified.
	BypassAutoEncryption bool
	BypassQueryAnalysis  bool
}
    47  
    48  // Crypt is an interface implemented by types that can encrypt and decrypt instances of
    49  // bsoncore.Document.
    50  //
    51  // Users should rely on the driver's crypt type (used by default) for encryption and decryption
    52  // unless they are perfectly confident in another implementation of Crypt.
    53  type Crypt interface {
    54  	// Encrypt encrypts the given command.
    55  	Encrypt(ctx context.Context, db string, cmd bsoncore.Document) (bsoncore.Document, error)
    56  	// Decrypt decrypts the given command response.
    57  	Decrypt(ctx context.Context, cmdResponse bsoncore.Document) (bsoncore.Document, error)
    58  	// CreateDataKey creates a data key using the given KMS provider and options.
    59  	CreateDataKey(ctx context.Context, kmsProvider string, opts *options.DataKeyOptions) (bsoncore.Document, error)
    60  	// EncryptExplicit encrypts the given value with the given options.
    61  	EncryptExplicit(ctx context.Context, val bsoncore.Value, opts *options.ExplicitEncryptionOptions) (byte, []byte, error)
    62  	// EncryptExplicitExpression encrypts the given expression with the given options.
    63  	EncryptExplicitExpression(ctx context.Context, val bsoncore.Document, opts *options.ExplicitEncryptionOptions) (bsoncore.Document, error)
    64  	// DecryptExplicit decrypts the given encrypted value.
    65  	DecryptExplicit(ctx context.Context, subtype byte, data []byte) (bsoncore.Value, error)
    66  	// Close cleans up any resources associated with the Crypt instance.
    67  	Close()
    68  	// BypassAutoEncryption returns true if auto-encryption should be bypassed.
    69  	BypassAutoEncryption() bool
    70  	// RewrapDataKey attempts to rewrap the document data keys matching the filter, preparing the re-wrapped documents
    71  	// to be returned as a slice of bsoncore.Document.
    72  	RewrapDataKey(ctx context.Context, filter []byte, opts *options.RewrapManyDataKeyOptions) ([]bsoncore.Document, error)
    73  }
    74  
// crypt consumes the libmongocrypt.MongoCrypt type to iterate the mongocrypt state machine and perform encryption
// and decryption.
//
// States that need data from outside libmongocrypt are served by the callbacks
// and TLS settings held here (see executeStateMachine for the dispatch).
type crypt struct {
	// mongoCrypt creates every context this type drives; Close closes it.
	mongoCrypt *mongocrypt.MongoCrypt
	// collInfoFn answers NeedMongoCollInfo states.
	collInfoFn CollectionInfoFn
	// keyFn answers NeedMongoKeys states.
	keyFn      KeyRetrieverFn
	// markFn answers NeedMongoMarkings states.
	markFn     MarkCommandFn
	// tlsConfig is keyed by KMS provider name and consulted when dialing a KMS host.
	tlsConfig  map[string]*tls.Config

	// bypassAutoEncryption makes Encrypt a no-op that returns the command as given.
	bypassAutoEncryption bool
}
    86  
    87  // NewCrypt creates a new Crypt instance configured with the given AutoEncryptionOptions.
    88  func NewCrypt(opts *CryptOptions) Crypt {
    89  	c := &crypt{
    90  		mongoCrypt:           opts.MongoCrypt,
    91  		collInfoFn:           opts.CollInfoFn,
    92  		keyFn:                opts.KeyFn,
    93  		markFn:               opts.MarkFn,
    94  		tlsConfig:            opts.TLSConfig,
    95  		bypassAutoEncryption: opts.BypassAutoEncryption,
    96  	}
    97  	return c
    98  }
    99  
   100  // Encrypt encrypts the given command.
   101  func (c *crypt) Encrypt(ctx context.Context, db string, cmd bsoncore.Document) (bsoncore.Document, error) {
   102  	if c.bypassAutoEncryption {
   103  		return cmd, nil
   104  	}
   105  
   106  	cryptCtx, err := c.mongoCrypt.CreateEncryptionContext(db, cmd)
   107  	if err != nil {
   108  		return nil, err
   109  	}
   110  	defer cryptCtx.Close()
   111  
   112  	return c.executeStateMachine(ctx, cryptCtx, db)
   113  }
   114  
   115  // Decrypt decrypts the given command response.
   116  func (c *crypt) Decrypt(ctx context.Context, cmdResponse bsoncore.Document) (bsoncore.Document, error) {
   117  	cryptCtx, err := c.mongoCrypt.CreateDecryptionContext(cmdResponse)
   118  	if err != nil {
   119  		return nil, err
   120  	}
   121  	defer cryptCtx.Close()
   122  
   123  	return c.executeStateMachine(ctx, cryptCtx, "")
   124  }
   125  
   126  // CreateDataKey creates a data key using the given KMS provider and options.
   127  func (c *crypt) CreateDataKey(ctx context.Context, kmsProvider string, opts *options.DataKeyOptions) (bsoncore.Document, error) {
   128  	cryptCtx, err := c.mongoCrypt.CreateDataKeyContext(kmsProvider, opts)
   129  	if err != nil {
   130  		return nil, err
   131  	}
   132  	defer cryptCtx.Close()
   133  
   134  	return c.executeStateMachine(ctx, cryptCtx, "")
   135  }
   136  
   137  // RewrapDataKey attempts to rewrap the document data keys matching the filter, preparing the re-wrapped documents to
   138  // be returned as a slice of bsoncore.Document.
   139  func (c *crypt) RewrapDataKey(ctx context.Context, filter []byte,
   140  	opts *options.RewrapManyDataKeyOptions) ([]bsoncore.Document, error) {
   141  
   142  	cryptCtx, err := c.mongoCrypt.RewrapDataKeyContext(filter, opts)
   143  	if err != nil {
   144  		return nil, err
   145  	}
   146  	defer cryptCtx.Close()
   147  
   148  	rewrappedBSON, err := c.executeStateMachine(ctx, cryptCtx, "")
   149  	if err != nil {
   150  		return nil, err
   151  	}
   152  	if rewrappedBSON == nil {
   153  		return nil, nil
   154  	}
   155  
   156  	// mongocrypt_ctx_rewrap_many_datakey_init wraps the documents in a BSON of the form { "v": [(BSON document), ...] }
   157  	// where each BSON document in the slice is a document containing a rewrapped datakey.
   158  	rewrappedDocumentBytes, err := rewrappedBSON.LookupErr("v")
   159  	if err != nil {
   160  		return nil, err
   161  	}
   162  
   163  	// Parse the resulting BSON as individual documents.
   164  	rewrappedDocsArray, ok := rewrappedDocumentBytes.ArrayOK()
   165  	if !ok {
   166  		return nil, fmt.Errorf("expected results from mongocrypt_ctx_rewrap_many_datakey_init to be an array")
   167  	}
   168  
   169  	rewrappedDocumentValues, err := rewrappedDocsArray.Values()
   170  	if err != nil {
   171  		return nil, err
   172  	}
   173  
   174  	rewrappedDocuments := []bsoncore.Document{}
   175  	for _, rewrappedDocumentValue := range rewrappedDocumentValues {
   176  		if rewrappedDocumentValue.Type != bsontype.EmbeddedDocument {
   177  			// If a value in the document's array returned by mongocrypt is anything other than an embedded document,
   178  			// then something is wrong and we should terminate the routine.
   179  			return nil, fmt.Errorf("expected value of type %q, got: %q",
   180  				bsontype.EmbeddedDocument.String(),
   181  				rewrappedDocumentValue.Type.String())
   182  		}
   183  		rewrappedDocuments = append(rewrappedDocuments, rewrappedDocumentValue.Document())
   184  	}
   185  	return rewrappedDocuments, nil
   186  }
   187  
   188  // EncryptExplicit encrypts the given value with the given options.
   189  func (c *crypt) EncryptExplicit(ctx context.Context, val bsoncore.Value, opts *options.ExplicitEncryptionOptions) (byte, []byte, error) {
   190  	idx, doc := bsoncore.AppendDocumentStart(nil)
   191  	doc = bsoncore.AppendValueElement(doc, "v", val)
   192  	doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
   193  
   194  	cryptCtx, err := c.mongoCrypt.CreateExplicitEncryptionContext(doc, opts)
   195  	if err != nil {
   196  		return 0, nil, err
   197  	}
   198  	defer cryptCtx.Close()
   199  
   200  	res, err := c.executeStateMachine(ctx, cryptCtx, "")
   201  	if err != nil {
   202  		return 0, nil, err
   203  	}
   204  
   205  	sub, data := res.Lookup("v").Binary()
   206  	return sub, data, nil
   207  }
   208  
   209  // EncryptExplicitExpression encrypts the given expression with the given options.
   210  func (c *crypt) EncryptExplicitExpression(ctx context.Context, expr bsoncore.Document, opts *options.ExplicitEncryptionOptions) (bsoncore.Document, error) {
   211  	idx, doc := bsoncore.AppendDocumentStart(nil)
   212  	doc = bsoncore.AppendDocumentElement(doc, "v", expr)
   213  	doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
   214  
   215  	cryptCtx, err := c.mongoCrypt.CreateExplicitEncryptionExpressionContext(doc, opts)
   216  	if err != nil {
   217  		return nil, err
   218  	}
   219  	defer cryptCtx.Close()
   220  
   221  	res, err := c.executeStateMachine(ctx, cryptCtx, "")
   222  	if err != nil {
   223  		return nil, err
   224  	}
   225  
   226  	encryptedExpr := res.Lookup("v").Document()
   227  	return encryptedExpr, nil
   228  }
   229  
   230  // DecryptExplicit decrypts the given encrypted value.
   231  func (c *crypt) DecryptExplicit(ctx context.Context, subtype byte, data []byte) (bsoncore.Value, error) {
   232  	idx, doc := bsoncore.AppendDocumentStart(nil)
   233  	doc = bsoncore.AppendBinaryElement(doc, "v", subtype, data)
   234  	doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
   235  
   236  	cryptCtx, err := c.mongoCrypt.CreateExplicitDecryptionContext(doc)
   237  	if err != nil {
   238  		return bsoncore.Value{}, err
   239  	}
   240  	defer cryptCtx.Close()
   241  
   242  	res, err := c.executeStateMachine(ctx, cryptCtx, "")
   243  	if err != nil {
   244  		return bsoncore.Value{}, err
   245  	}
   246  
   247  	return res.Lookup("v"), nil
   248  }
   249  
   250  // Close cleans up any resources associated with the Crypt instance.
   251  func (c *crypt) Close() {
   252  	c.mongoCrypt.Close()
   253  }
   254  
   255  func (c *crypt) BypassAutoEncryption() bool {
   256  	return c.bypassAutoEncryption
   257  }
   258  
   259  func (c *crypt) executeStateMachine(ctx context.Context, cryptCtx *mongocrypt.Context, db string) (bsoncore.Document, error) {
   260  	var err error
   261  	for {
   262  		state := cryptCtx.State()
   263  		switch state {
   264  		case mongocrypt.NeedMongoCollInfo:
   265  			err = c.collectionInfo(ctx, cryptCtx, db)
   266  		case mongocrypt.NeedMongoMarkings:
   267  			err = c.markCommand(ctx, cryptCtx, db)
   268  		case mongocrypt.NeedMongoKeys:
   269  			err = c.retrieveKeys(ctx, cryptCtx)
   270  		case mongocrypt.NeedKms:
   271  			err = c.decryptKeys(cryptCtx)
   272  		case mongocrypt.Ready:
   273  			return cryptCtx.Finish()
   274  		case mongocrypt.Done:
   275  			return nil, nil
   276  		case mongocrypt.NeedKmsCredentials:
   277  			err = c.provideKmsProviders(ctx, cryptCtx)
   278  		default:
   279  			return nil, fmt.Errorf("invalid Crypt state: %v", state)
   280  		}
   281  		if err != nil {
   282  			return nil, err
   283  		}
   284  	}
   285  }
   286  
   287  func (c *crypt) collectionInfo(ctx context.Context, cryptCtx *mongocrypt.Context, db string) error {
   288  	op, err := cryptCtx.NextOperation()
   289  	if err != nil {
   290  		return err
   291  	}
   292  
   293  	collInfo, err := c.collInfoFn(ctx, db, op)
   294  	if err != nil {
   295  		return err
   296  	}
   297  	if collInfo != nil {
   298  		if err = cryptCtx.AddOperationResult(collInfo); err != nil {
   299  			return err
   300  		}
   301  	}
   302  
   303  	return cryptCtx.CompleteOperation()
   304  }
   305  
   306  func (c *crypt) markCommand(ctx context.Context, cryptCtx *mongocrypt.Context, db string) error {
   307  	op, err := cryptCtx.NextOperation()
   308  	if err != nil {
   309  		return err
   310  	}
   311  
   312  	markedCmd, err := c.markFn(ctx, db, op)
   313  	if err != nil {
   314  		return err
   315  	}
   316  	if err = cryptCtx.AddOperationResult(markedCmd); err != nil {
   317  		return err
   318  	}
   319  
   320  	return cryptCtx.CompleteOperation()
   321  }
   322  
   323  func (c *crypt) retrieveKeys(ctx context.Context, cryptCtx *mongocrypt.Context) error {
   324  	op, err := cryptCtx.NextOperation()
   325  	if err != nil {
   326  		return err
   327  	}
   328  
   329  	keys, err := c.keyFn(ctx, op)
   330  	if err != nil {
   331  		return err
   332  	}
   333  
   334  	for _, key := range keys {
   335  		if err = cryptCtx.AddOperationResult(key); err != nil {
   336  			return err
   337  		}
   338  	}
   339  
   340  	return cryptCtx.CompleteOperation()
   341  }
   342  
   343  func (c *crypt) decryptKeys(cryptCtx *mongocrypt.Context) error {
   344  	for {
   345  		kmsCtx := cryptCtx.NextKmsContext()
   346  		if kmsCtx == nil {
   347  			break
   348  		}
   349  
   350  		if err := c.decryptKey(kmsCtx); err != nil {
   351  			return err
   352  		}
   353  	}
   354  
   355  	return cryptCtx.FinishKmsContexts()
   356  }
   357  
   358  func (c *crypt) decryptKey(kmsCtx *mongocrypt.KmsContext) error {
   359  	host, err := kmsCtx.HostName()
   360  	if err != nil {
   361  		return err
   362  	}
   363  	msg, err := kmsCtx.Message()
   364  	if err != nil {
   365  		return err
   366  	}
   367  
   368  	// add a port to the address if it's not already present
   369  	addr := host
   370  	if idx := strings.IndexByte(host, ':'); idx == -1 {
   371  		addr = fmt.Sprintf("%s:%d", host, defaultKmsPort)
   372  	}
   373  
   374  	kmsProvider := kmsCtx.KMSProvider()
   375  	tlsCfg := c.tlsConfig[kmsProvider]
   376  	if tlsCfg == nil {
   377  		tlsCfg = &tls.Config{MinVersion: tls.VersionTLS12}
   378  	}
   379  	conn, err := tls.Dial("tcp", addr, tlsCfg)
   380  	if err != nil {
   381  		return err
   382  	}
   383  	defer func() {
   384  		_ = conn.Close()
   385  	}()
   386  
   387  	if err = conn.SetWriteDeadline(time.Now().Add(defaultKmsTimeout)); err != nil {
   388  		return err
   389  	}
   390  	if _, err = conn.Write(msg); err != nil {
   391  		return err
   392  	}
   393  
   394  	for {
   395  		bytesNeeded := kmsCtx.BytesNeeded()
   396  		if bytesNeeded == 0 {
   397  			return nil
   398  		}
   399  
   400  		res := make([]byte, bytesNeeded)
   401  		bytesRead, err := conn.Read(res)
   402  		if err != nil && err != io.EOF {
   403  			return err
   404  		}
   405  
   406  		if err = kmsCtx.FeedResponse(res[:bytesRead]); err != nil {
   407  			return err
   408  		}
   409  	}
   410  }
   411  
   412  func (c *crypt) provideKmsProviders(ctx context.Context, cryptCtx *mongocrypt.Context) error {
   413  	kmsProviders, err := c.mongoCrypt.GetKmsProviders(ctx)
   414  	if err != nil {
   415  		return err
   416  	}
   417  	return cryptCtx.ProvideKmsProviders(kmsProviders)
   418  }