github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/vendor_skip/go.mongodb.org/mongo-driver/mongo/bulk_write.go (about)

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package mongo
     8  
     9  import (
    10  	"context"
    11  
    12  	"go.mongodb.org/mongo-driver/bson/bsoncodec"
    13  	"go.mongodb.org/mongo-driver/bson/primitive"
    14  	"go.mongodb.org/mongo-driver/mongo/description"
    15  	"go.mongodb.org/mongo-driver/mongo/options"
    16  	"go.mongodb.org/mongo-driver/mongo/writeconcern"
    17  	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
    18  	"go.mongodb.org/mongo-driver/x/mongo/driver"
    19  	"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
    20  	"go.mongodb.org/mongo-driver/x/mongo/driver/session"
    21  )
    22  
// bulkWriteBatch is a group of write models that runBatch sends to the server
// as a single write command.
type bulkWriteBatch struct {
	// models holds the batch's write models; runBatch selects the command to
	// run from the type of models[0], so all models share one command kind.
	models   []WriteModel
	// canRetry reports whether the batch may be sent with
	// driver.RetryOncePerCommand when the client has retryWrites enabled. The
	// batching functions leave it false for batches containing DeleteManyModel
	// or UpdateManyModel.
	canRetry bool
	// indexes[i] is the position of models[i] in the model slice given to the
	// bulk write; it is used to report write errors and upserted IDs against
	// the caller's original positions.
	indexes  []int
}
    28  
// bulkWrite performs a bulk write operation: it holds the write models and
// per-operation options for one bulk write and accumulates the combined result
// while its batches are executed.
type bulkWrite struct {
	// comment, when non-nil, is marshaled and attached to every insert,
	// delete, and update command that is sent.
	comment                  interface{}
	// ordered is treated as true when nil. It selects ordered batching with
	// stop-on-first-error behavior, and is forwarded to each command when set.
	ordered                  *bool
	// bypassDocumentValidation is forwarded to insert and update commands only
	// when it is non-nil and true.
	bypassDocumentValidation *bool
	// models are the write models to execute, in caller order.
	models                   []WriteModel
	session                  *session.Client
	// collection supplies the target database/collection names plus the
	// client, BSON options, and registry used to build and send commands.
	collection               *Collection
	selector                 description.ServerSelector
	writeConcern             *writeconcern.WriteConcern
	// result accumulates counts and upserted IDs across batches; execute
	// resets it at the start of each run.
	result                   BulkWriteResult
	// let, when non-nil, is marshaled and attached to delete and update
	// commands (insert commands do not receive it).
	let                      interface{}
}
    42  
    43  func (bw *bulkWrite) execute(ctx context.Context) error {
    44  	ordered := true
    45  	if bw.ordered != nil {
    46  		ordered = *bw.ordered
    47  	}
    48  
    49  	batches := createBatches(bw.models, ordered)
    50  	bw.result = BulkWriteResult{
    51  		UpsertedIDs: make(map[int64]interface{}),
    52  	}
    53  
    54  	bwErr := BulkWriteException{
    55  		WriteErrors: make([]BulkWriteError, 0),
    56  	}
    57  
    58  	var lastErr error
    59  	continueOnError := !ordered
    60  	for _, batch := range batches {
    61  		if len(batch.models) == 0 {
    62  			continue
    63  		}
    64  
    65  		batchRes, batchErr, err := bw.runBatch(ctx, batch)
    66  
    67  		bw.mergeResults(batchRes)
    68  
    69  		bwErr.WriteConcernError = batchErr.WriteConcernError
    70  		bwErr.Labels = append(bwErr.Labels, batchErr.Labels...)
    71  
    72  		bwErr.WriteErrors = append(bwErr.WriteErrors, batchErr.WriteErrors...)
    73  
    74  		commandErrorOccurred := err != nil && err != driver.ErrUnacknowledgedWrite
    75  		writeErrorOccurred := len(batchErr.WriteErrors) > 0 || batchErr.WriteConcernError != nil
    76  		if !continueOnError && (commandErrorOccurred || writeErrorOccurred) {
    77  			if err != nil {
    78  				return err
    79  			}
    80  
    81  			return bwErr
    82  		}
    83  
    84  		if err != nil {
    85  			lastErr = err
    86  		}
    87  	}
    88  
    89  	bw.result.MatchedCount -= bw.result.UpsertedCount
    90  	if lastErr != nil {
    91  		_, lastErr = processWriteError(lastErr)
    92  		return lastErr
    93  	}
    94  	if len(bwErr.WriteErrors) > 0 || bwErr.WriteConcernError != nil {
    95  		return bwErr
    96  	}
    97  	return nil
    98  }
    99  
   100  func (bw *bulkWrite) runBatch(ctx context.Context, batch bulkWriteBatch) (BulkWriteResult, BulkWriteException, error) {
   101  	batchRes := BulkWriteResult{
   102  		UpsertedIDs: make(map[int64]interface{}),
   103  	}
   104  	batchErr := BulkWriteException{}
   105  
   106  	var writeErrors []driver.WriteError
   107  	switch batch.models[0].(type) {
   108  	case *InsertOneModel:
   109  		res, err := bw.runInsert(ctx, batch)
   110  		if err != nil {
   111  			writeErr, ok := err.(driver.WriteCommandError)
   112  			if !ok {
   113  				return BulkWriteResult{}, batchErr, err
   114  			}
   115  			writeErrors = writeErr.WriteErrors
   116  			batchErr.Labels = writeErr.Labels
   117  			batchErr.WriteConcernError = convertDriverWriteConcernError(writeErr.WriteConcernError)
   118  		}
   119  		batchRes.InsertedCount = res.N
   120  	case *DeleteOneModel, *DeleteManyModel:
   121  		res, err := bw.runDelete(ctx, batch)
   122  		if err != nil {
   123  			writeErr, ok := err.(driver.WriteCommandError)
   124  			if !ok {
   125  				return BulkWriteResult{}, batchErr, err
   126  			}
   127  			writeErrors = writeErr.WriteErrors
   128  			batchErr.Labels = writeErr.Labels
   129  			batchErr.WriteConcernError = convertDriverWriteConcernError(writeErr.WriteConcernError)
   130  		}
   131  		batchRes.DeletedCount = res.N
   132  	case *ReplaceOneModel, *UpdateOneModel, *UpdateManyModel:
   133  		res, err := bw.runUpdate(ctx, batch)
   134  		if err != nil {
   135  			writeErr, ok := err.(driver.WriteCommandError)
   136  			if !ok {
   137  				return BulkWriteResult{}, batchErr, err
   138  			}
   139  			writeErrors = writeErr.WriteErrors
   140  			batchErr.Labels = writeErr.Labels
   141  			batchErr.WriteConcernError = convertDriverWriteConcernError(writeErr.WriteConcernError)
   142  		}
   143  		batchRes.MatchedCount = res.N
   144  		batchRes.ModifiedCount = res.NModified
   145  		batchRes.UpsertedCount = int64(len(res.Upserted))
   146  		for _, upsert := range res.Upserted {
   147  			batchRes.UpsertedIDs[int64(batch.indexes[upsert.Index])] = upsert.ID
   148  		}
   149  	}
   150  
   151  	batchErr.WriteErrors = make([]BulkWriteError, 0, len(writeErrors))
   152  	convWriteErrors := writeErrorsFromDriverWriteErrors(writeErrors)
   153  	for _, we := range convWriteErrors {
   154  		request := batch.models[we.Index]
   155  		we.Index = batch.indexes[we.Index]
   156  		batchErr.WriteErrors = append(batchErr.WriteErrors, BulkWriteError{
   157  			WriteError: we,
   158  			Request:    request,
   159  		})
   160  	}
   161  	return batchRes, batchErr, nil
   162  }
   163  
   164  func (bw *bulkWrite) runInsert(ctx context.Context, batch bulkWriteBatch) (operation.InsertResult, error) {
   165  	docs := make([]bsoncore.Document, len(batch.models))
   166  	var i int
   167  	for _, model := range batch.models {
   168  		converted := model.(*InsertOneModel)
   169  		doc, err := marshal(converted.Document, bw.collection.bsonOpts, bw.collection.registry)
   170  		if err != nil {
   171  			return operation.InsertResult{}, err
   172  		}
   173  		doc, _, err = ensureID(doc, primitive.NewObjectID(), bw.collection.bsonOpts, bw.collection.registry)
   174  		if err != nil {
   175  			return operation.InsertResult{}, err
   176  		}
   177  
   178  		docs[i] = doc
   179  		i++
   180  	}
   181  
   182  	op := operation.NewInsert(docs...).
   183  		Session(bw.session).WriteConcern(bw.writeConcern).CommandMonitor(bw.collection.client.monitor).
   184  		ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock).
   185  		Database(bw.collection.db.name).Collection(bw.collection.name).
   186  		Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.cryptFLE).
   187  		ServerAPI(bw.collection.client.serverAPI).Timeout(bw.collection.client.timeout).
   188  		Logger(bw.collection.client.logger)
   189  	if bw.comment != nil {
   190  		comment, err := marshalValue(bw.comment, bw.collection.bsonOpts, bw.collection.registry)
   191  		if err != nil {
   192  			return op.Result(), err
   193  		}
   194  		op.Comment(comment)
   195  	}
   196  	if bw.bypassDocumentValidation != nil && *bw.bypassDocumentValidation {
   197  		op = op.BypassDocumentValidation(*bw.bypassDocumentValidation)
   198  	}
   199  	if bw.ordered != nil {
   200  		op = op.Ordered(*bw.ordered)
   201  	}
   202  
   203  	retry := driver.RetryNone
   204  	if bw.collection.client.retryWrites && batch.canRetry {
   205  		retry = driver.RetryOncePerCommand
   206  	}
   207  	op = op.Retry(retry)
   208  
   209  	err := op.Execute(ctx)
   210  
   211  	return op.Result(), err
   212  }
   213  
   214  func (bw *bulkWrite) runDelete(ctx context.Context, batch bulkWriteBatch) (operation.DeleteResult, error) {
   215  	docs := make([]bsoncore.Document, len(batch.models))
   216  	var i int
   217  	var hasHint bool
   218  
   219  	for _, model := range batch.models {
   220  		var doc bsoncore.Document
   221  		var err error
   222  
   223  		switch converted := model.(type) {
   224  		case *DeleteOneModel:
   225  			doc, err = createDeleteDoc(
   226  				converted.Filter,
   227  				converted.Collation,
   228  				converted.Hint,
   229  				true,
   230  				bw.collection.bsonOpts,
   231  				bw.collection.registry)
   232  			hasHint = hasHint || (converted.Hint != nil)
   233  		case *DeleteManyModel:
   234  			doc, err = createDeleteDoc(
   235  				converted.Filter,
   236  				converted.Collation,
   237  				converted.Hint,
   238  				false,
   239  				bw.collection.bsonOpts,
   240  				bw.collection.registry)
   241  			hasHint = hasHint || (converted.Hint != nil)
   242  		}
   243  
   244  		if err != nil {
   245  			return operation.DeleteResult{}, err
   246  		}
   247  
   248  		docs[i] = doc
   249  		i++
   250  	}
   251  
   252  	op := operation.NewDelete(docs...).
   253  		Session(bw.session).WriteConcern(bw.writeConcern).CommandMonitor(bw.collection.client.monitor).
   254  		ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock).
   255  		Database(bw.collection.db.name).Collection(bw.collection.name).
   256  		Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.cryptFLE).Hint(hasHint).
   257  		ServerAPI(bw.collection.client.serverAPI).Timeout(bw.collection.client.timeout).
   258  		Logger(bw.collection.client.logger)
   259  	if bw.comment != nil {
   260  		comment, err := marshalValue(bw.comment, bw.collection.bsonOpts, bw.collection.registry)
   261  		if err != nil {
   262  			return op.Result(), err
   263  		}
   264  		op.Comment(comment)
   265  	}
   266  	if bw.let != nil {
   267  		let, err := marshal(bw.let, bw.collection.bsonOpts, bw.collection.registry)
   268  		if err != nil {
   269  			return operation.DeleteResult{}, err
   270  		}
   271  		op = op.Let(let)
   272  	}
   273  	if bw.ordered != nil {
   274  		op = op.Ordered(*bw.ordered)
   275  	}
   276  	retry := driver.RetryNone
   277  	if bw.collection.client.retryWrites && batch.canRetry {
   278  		retry = driver.RetryOncePerCommand
   279  	}
   280  	op = op.Retry(retry)
   281  
   282  	err := op.Execute(ctx)
   283  
   284  	return op.Result(), err
   285  }
   286  
   287  func createDeleteDoc(
   288  	filter interface{},
   289  	collation *options.Collation,
   290  	hint interface{},
   291  	deleteOne bool,
   292  	bsonOpts *options.BSONOptions,
   293  	registry *bsoncodec.Registry,
   294  ) (bsoncore.Document, error) {
   295  	f, err := marshal(filter, bsonOpts, registry)
   296  	if err != nil {
   297  		return nil, err
   298  	}
   299  
   300  	var limit int32
   301  	if deleteOne {
   302  		limit = 1
   303  	}
   304  	didx, doc := bsoncore.AppendDocumentStart(nil)
   305  	doc = bsoncore.AppendDocumentElement(doc, "q", f)
   306  	doc = bsoncore.AppendInt32Element(doc, "limit", limit)
   307  	if collation != nil {
   308  		doc = bsoncore.AppendDocumentElement(doc, "collation", collation.ToDocument())
   309  	}
   310  	if hint != nil {
   311  		if isUnorderedMap(hint) {
   312  			return nil, ErrMapForOrderedArgument{"hint"}
   313  		}
   314  		hintVal, err := marshalValue(hint, bsonOpts, registry)
   315  		if err != nil {
   316  			return nil, err
   317  		}
   318  		doc = bsoncore.AppendValueElement(doc, "hint", hintVal)
   319  	}
   320  	doc, _ = bsoncore.AppendDocumentEnd(doc, didx)
   321  
   322  	return doc, nil
   323  }
   324  
   325  func (bw *bulkWrite) runUpdate(ctx context.Context, batch bulkWriteBatch) (operation.UpdateResult, error) {
   326  	docs := make([]bsoncore.Document, len(batch.models))
   327  	var hasHint bool
   328  	var hasArrayFilters bool
   329  	for i, model := range batch.models {
   330  		var doc bsoncore.Document
   331  		var err error
   332  
   333  		switch converted := model.(type) {
   334  		case *ReplaceOneModel:
   335  			doc, err = createUpdateDoc(
   336  				converted.Filter,
   337  				converted.Replacement,
   338  				converted.Hint,
   339  				nil,
   340  				converted.Collation,
   341  				converted.Upsert,
   342  				false,
   343  				false,
   344  				bw.collection.bsonOpts,
   345  				bw.collection.registry)
   346  			hasHint = hasHint || (converted.Hint != nil)
   347  		case *UpdateOneModel:
   348  			doc, err = createUpdateDoc(
   349  				converted.Filter,
   350  				converted.Update,
   351  				converted.Hint,
   352  				converted.ArrayFilters,
   353  				converted.Collation,
   354  				converted.Upsert,
   355  				false,
   356  				true,
   357  				bw.collection.bsonOpts,
   358  				bw.collection.registry)
   359  			hasHint = hasHint || (converted.Hint != nil)
   360  			hasArrayFilters = hasArrayFilters || (converted.ArrayFilters != nil)
   361  		case *UpdateManyModel:
   362  			doc, err = createUpdateDoc(
   363  				converted.Filter,
   364  				converted.Update,
   365  				converted.Hint,
   366  				converted.ArrayFilters,
   367  				converted.Collation,
   368  				converted.Upsert,
   369  				true,
   370  				true,
   371  				bw.collection.bsonOpts,
   372  				bw.collection.registry)
   373  			hasHint = hasHint || (converted.Hint != nil)
   374  			hasArrayFilters = hasArrayFilters || (converted.ArrayFilters != nil)
   375  		}
   376  		if err != nil {
   377  			return operation.UpdateResult{}, err
   378  		}
   379  
   380  		docs[i] = doc
   381  	}
   382  
   383  	op := operation.NewUpdate(docs...).
   384  		Session(bw.session).WriteConcern(bw.writeConcern).CommandMonitor(bw.collection.client.monitor).
   385  		ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock).
   386  		Database(bw.collection.db.name).Collection(bw.collection.name).
   387  		Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.cryptFLE).Hint(hasHint).
   388  		ArrayFilters(hasArrayFilters).ServerAPI(bw.collection.client.serverAPI).
   389  		Timeout(bw.collection.client.timeout).Logger(bw.collection.client.logger)
   390  	if bw.comment != nil {
   391  		comment, err := marshalValue(bw.comment, bw.collection.bsonOpts, bw.collection.registry)
   392  		if err != nil {
   393  			return op.Result(), err
   394  		}
   395  		op.Comment(comment)
   396  	}
   397  	if bw.let != nil {
   398  		let, err := marshal(bw.let, bw.collection.bsonOpts, bw.collection.registry)
   399  		if err != nil {
   400  			return operation.UpdateResult{}, err
   401  		}
   402  		op = op.Let(let)
   403  	}
   404  	if bw.ordered != nil {
   405  		op = op.Ordered(*bw.ordered)
   406  	}
   407  	if bw.bypassDocumentValidation != nil && *bw.bypassDocumentValidation {
   408  		op = op.BypassDocumentValidation(*bw.bypassDocumentValidation)
   409  	}
   410  	retry := driver.RetryNone
   411  	if bw.collection.client.retryWrites && batch.canRetry {
   412  		retry = driver.RetryOncePerCommand
   413  	}
   414  	op = op.Retry(retry)
   415  
   416  	err := op.Execute(ctx)
   417  
   418  	return op.Result(), err
   419  }
   420  
   421  func createUpdateDoc(
   422  	filter interface{},
   423  	update interface{},
   424  	hint interface{},
   425  	arrayFilters *options.ArrayFilters,
   426  	collation *options.Collation,
   427  	upsert *bool,
   428  	multi bool,
   429  	checkDollarKey bool,
   430  	bsonOpts *options.BSONOptions,
   431  	registry *bsoncodec.Registry,
   432  ) (bsoncore.Document, error) {
   433  	f, err := marshal(filter, bsonOpts, registry)
   434  	if err != nil {
   435  		return nil, err
   436  	}
   437  
   438  	uidx, updateDoc := bsoncore.AppendDocumentStart(nil)
   439  	updateDoc = bsoncore.AppendDocumentElement(updateDoc, "q", f)
   440  
   441  	u, err := marshalUpdateValue(update, bsonOpts, registry, checkDollarKey)
   442  	if err != nil {
   443  		return nil, err
   444  	}
   445  
   446  	updateDoc = bsoncore.AppendValueElement(updateDoc, "u", u)
   447  
   448  	if multi {
   449  		updateDoc = bsoncore.AppendBooleanElement(updateDoc, "multi", multi)
   450  	}
   451  
   452  	if arrayFilters != nil {
   453  		reg := registry
   454  		if arrayFilters.Registry != nil {
   455  			reg = arrayFilters.Registry
   456  		}
   457  		arr, err := marshalValue(arrayFilters.Filters, bsonOpts, reg)
   458  		if err != nil {
   459  			return nil, err
   460  		}
   461  		updateDoc = bsoncore.AppendArrayElement(updateDoc, "arrayFilters", arr.Data)
   462  	}
   463  
   464  	if collation != nil {
   465  		updateDoc = bsoncore.AppendDocumentElement(updateDoc, "collation", bsoncore.Document(collation.ToDocument()))
   466  	}
   467  
   468  	if upsert != nil {
   469  		updateDoc = bsoncore.AppendBooleanElement(updateDoc, "upsert", *upsert)
   470  	}
   471  
   472  	if hint != nil {
   473  		if isUnorderedMap(hint) {
   474  			return nil, ErrMapForOrderedArgument{"hint"}
   475  		}
   476  		hintVal, err := marshalValue(hint, bsonOpts, registry)
   477  		if err != nil {
   478  			return nil, err
   479  		}
   480  		updateDoc = bsoncore.AppendValueElement(updateDoc, "hint", hintVal)
   481  	}
   482  
   483  	updateDoc, _ = bsoncore.AppendDocumentEnd(updateDoc, uidx)
   484  	return updateDoc, nil
   485  }
   486  
   487  func createBatches(models []WriteModel, ordered bool) []bulkWriteBatch {
   488  	if ordered {
   489  		return createOrderedBatches(models)
   490  	}
   491  
   492  	batches := make([]bulkWriteBatch, 5)
   493  	batches[insertCommand].canRetry = true
   494  	batches[deleteOneCommand].canRetry = true
   495  	batches[updateOneCommand].canRetry = true
   496  
   497  	// TODO(GODRIVER-1157): fix batching once operation retryability is fixed
   498  	for i, model := range models {
   499  		switch model.(type) {
   500  		case *InsertOneModel:
   501  			batches[insertCommand].models = append(batches[insertCommand].models, model)
   502  			batches[insertCommand].indexes = append(batches[insertCommand].indexes, i)
   503  		case *DeleteOneModel:
   504  			batches[deleteOneCommand].models = append(batches[deleteOneCommand].models, model)
   505  			batches[deleteOneCommand].indexes = append(batches[deleteOneCommand].indexes, i)
   506  		case *DeleteManyModel:
   507  			batches[deleteManyCommand].models = append(batches[deleteManyCommand].models, model)
   508  			batches[deleteManyCommand].indexes = append(batches[deleteManyCommand].indexes, i)
   509  		case *ReplaceOneModel, *UpdateOneModel:
   510  			batches[updateOneCommand].models = append(batches[updateOneCommand].models, model)
   511  			batches[updateOneCommand].indexes = append(batches[updateOneCommand].indexes, i)
   512  		case *UpdateManyModel:
   513  			batches[updateManyCommand].models = append(batches[updateManyCommand].models, model)
   514  			batches[updateManyCommand].indexes = append(batches[updateManyCommand].indexes, i)
   515  		}
   516  	}
   517  
   518  	return batches
   519  }
   520  
   521  func createOrderedBatches(models []WriteModel) []bulkWriteBatch {
   522  	var batches []bulkWriteBatch
   523  	var prevKind writeCommandKind = -1
   524  	i := -1 // batch index
   525  
   526  	for ind, model := range models {
   527  		var createNewBatch bool
   528  		var canRetry bool
   529  		var newKind writeCommandKind
   530  
   531  		// TODO(GODRIVER-1157): fix batching once operation retryability is fixed
   532  		switch model.(type) {
   533  		case *InsertOneModel:
   534  			createNewBatch = prevKind != insertCommand
   535  			canRetry = true
   536  			newKind = insertCommand
   537  		case *DeleteOneModel:
   538  			createNewBatch = prevKind != deleteOneCommand
   539  			canRetry = true
   540  			newKind = deleteOneCommand
   541  		case *DeleteManyModel:
   542  			createNewBatch = prevKind != deleteManyCommand
   543  			newKind = deleteManyCommand
   544  		case *ReplaceOneModel, *UpdateOneModel:
   545  			createNewBatch = prevKind != updateOneCommand
   546  			canRetry = true
   547  			newKind = updateOneCommand
   548  		case *UpdateManyModel:
   549  			createNewBatch = prevKind != updateManyCommand
   550  			newKind = updateManyCommand
   551  		}
   552  
   553  		if createNewBatch {
   554  			batches = append(batches, bulkWriteBatch{
   555  				models:   []WriteModel{model},
   556  				canRetry: canRetry,
   557  				indexes:  []int{ind},
   558  			})
   559  			i++
   560  		} else {
   561  			batches[i].models = append(batches[i].models, model)
   562  			if !canRetry {
   563  				batches[i].canRetry = false // don't make it true if it was already false
   564  			}
   565  			batches[i].indexes = append(batches[i].indexes, ind)
   566  		}
   567  
   568  		prevKind = newKind
   569  	}
   570  
   571  	return batches
   572  }
   573  
   574  func (bw *bulkWrite) mergeResults(newResult BulkWriteResult) {
   575  	bw.result.InsertedCount += newResult.InsertedCount
   576  	bw.result.MatchedCount += newResult.MatchedCount
   577  	bw.result.ModifiedCount += newResult.ModifiedCount
   578  	bw.result.DeletedCount += newResult.DeletedCount
   579  	bw.result.UpsertedCount += newResult.UpsertedCount
   580  
   581  	for index, upsertID := range newResult.UpsertedIDs {
   582  		bw.result.UpsertedIDs[index] = upsertID
   583  	}
   584  }
   585  
// writeCommandKind is the type of server write command a WriteModel maps to.
// createBatches and createOrderedBatches use it to group models into batches;
// createOrderedBatches uses -1 as a "no previous kind" sentinel.
type writeCommandKind int8

// These constants represent the valid types of write commands. createBatches
// uses their values directly as slice indexes for unordered batching, so the
// declaration order also determines the order in which unordered batches run.
const (
	insertCommand writeCommandKind = iota
	updateOneCommand
	updateManyCommand
	deleteOneCommand
	deleteManyCommand
)