github.com/SaurabhDubey-Groww/go-cloud@v0.0.0-20221124105541-b26c29285fd8/docstore/awsdynamodb/dynamo.go (about)

     1  // Copyright 2019 The Go Cloud Development Kit Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package awsdynamodb provides a docstore implementation backed by Amazon
    16  // DynamoDB.
    17  // Use OpenCollection to construct a *docstore.Collection.
    18  //
    19  // # URLs
    20  //
    21  // For docstore.OpenCollection, awsdynamodb registers for the scheme
    22  // "dynamodb". The default URL opener will use an AWS session with the default
    23  // credentials and configuration; see
    24  // https://docs.aws.amazon.com/sdk-for-go/api/aws/session/ for more details.
    25  // To customize the URL opener, or for more details on the URL format, see
    26  // URLOpener.
    27  // See https://gocloud.dev/concepts/urls/ for background information.
    28  //
    29  // # As
    30  //
    31  // awsdynamodb exposes the following types for As:
    32  //   - Collection.As: *dynamodb.DynamoDB
    33  //   - ActionList.BeforeDo: *dynamodb.BatchGetItemInput or *dynamodb.PutItemInput or *dynamodb.DeleteItemInput
    34  //     or *dynamodb.UpdateItemInput
    35  //   - Query.BeforeQuery: *dynamodb.QueryInput or *dynamodb.ScanInput
    36  //   - DocumentIterator: *dynamodb.QueryOutput or *dynamodb.ScanOutput
    37  //   - ErrorAs: awserr.Error
    38  package awsdynamodb
    39  
    40  import (
    41  	"context"
    42  	"fmt"
    43  	"reflect"
    44  	"strings"
    45  
    46  	"github.com/aws/aws-sdk-go/aws"
    47  	"github.com/aws/aws-sdk-go/aws/awserr"
    48  	dyn "github.com/aws/aws-sdk-go/service/dynamodb"
    49  	"github.com/aws/aws-sdk-go/service/dynamodb/expression"
    50  	"github.com/google/wire"
    51  	"gocloud.dev/docstore"
    52  	"gocloud.dev/docstore/driver"
    53  	"gocloud.dev/gcerrors"
    54  	"gocloud.dev/internal/gcerr"
    55  )
    56  
    57  // Set holds Wire providers for this package.
    58  var Set = wire.NewSet(
    59  	wire.Struct(new(URLOpener), "ConfigProvider"),
    60  )
    61  
    62  type collection struct {
    63  	db           *dyn.DynamoDB
    64  	table        string // DynamoDB table name
    65  	partitionKey string
    66  	sortKey      string
    67  	description  *dyn.TableDescription
    68  	opts         *Options
    69  }
    70  
    71  // FallbackFunc is a function for executing queries that cannot be run by the built-in
    72  // awsdynamodb logic. See Options.RunQueryFunc for details.
    73  type FallbackFunc func(context.Context, *driver.Query, RunQueryFunc) (driver.DocumentIterator, error)
    74  
    75  // Options holds various options.
    76  type Options struct {
    77  	// If false, queries that can only be executed by scanning the entire table
    78  	// return an error instead (with the exception of a query with no filters).
    79  	AllowScans bool
    80  
    81  	// The name of the field holding the document revision.
    82  	// Defaults to docstore.DefaultRevisionField.
    83  	RevisionField string
    84  
    85  	// If set, call this function on queries that we cannot execute at all (for
    86  	// example, a query with an OrderBy clause that lacks an equality filter on a
    87  	// partition key). The function should execute the query however it wishes, and
    88  	// return an iterator over the results. It can use the RunQueryFunc passed as its
    89  	// third argument to have the DynamoDB driver run a query, for instance a
    90  	// modified version of the original query.
    91  	//
    92  	// If RunQueryFallback is nil, queries that cannot be executed will fail with a
    93  	// error that has code Unimplemented.
    94  	RunQueryFallback FallbackFunc
    95  
    96  	// The maximum number of concurrent goroutines started for a single call to
    97  	// ActionList.Do. If less than 1, there is no limit.
    98  	MaxOutstandingActionRPCs int
    99  
   100  	// If true, a strongly consistent read is used whenever possible, including
   101  	// get, query, scan, etc.; default to false, where an eventually consistent
   102  	// read is used.
   103  	//
   104  	// Not all read operations support this mode however, such as querying against
   105  	// a global secondary index, the operation will return an InvalidArgument error
   106  	// in such case, please check the official DynamoDB documentation for more
   107  	// details.
   108  	//
   109  	// The native client for DynamoDB uses this option in a per-action basis, if
   110  	// you need the flexibility to run both modes on the same collection, create
   111  	// two collections with different mode.
   112  	ConsistentRead bool
   113  }
   114  
   115  // RunQueryFunc is the type of the function passed to RunQueryFallback.
   116  type RunQueryFunc func(context.Context, *driver.Query) (driver.DocumentIterator, error)
   117  
   118  // OpenCollection creates a *docstore.Collection representing a DynamoDB collection.
   119  func OpenCollection(db *dyn.DynamoDB, tableName, partitionKey, sortKey string, opts *Options) (*docstore.Collection, error) {
   120  	c, err := newCollection(db, tableName, partitionKey, sortKey, opts)
   121  	if err != nil {
   122  		return nil, err
   123  	}
   124  	return docstore.NewCollection(c), nil
   125  }
   126  
   127  func newCollection(db *dyn.DynamoDB, tableName, partitionKey, sortKey string, opts *Options) (*collection, error) {
   128  	out, err := db.DescribeTable(&dyn.DescribeTableInput{TableName: &tableName})
   129  	if err != nil {
   130  		return nil, err
   131  	}
   132  	if opts == nil {
   133  		opts = &Options{}
   134  	}
   135  	if opts.RevisionField == "" {
   136  		opts.RevisionField = docstore.DefaultRevisionField
   137  	}
   138  	return &collection{
   139  		db:           db,
   140  		table:        tableName,
   141  		partitionKey: partitionKey,
   142  		sortKey:      sortKey,
   143  		description:  out.Table,
   144  		opts:         opts,
   145  	}, nil
   146  }
   147  
   148  // Key returns a two-element array with the partition key and sort key, if any.
   149  func (c *collection) Key(doc driver.Document) (interface{}, error) {
   150  	pkey, err := doc.GetField(c.partitionKey)
   151  	if err != nil || pkey == nil || driver.IsEmptyValue(reflect.ValueOf(pkey)) {
   152  		return nil, nil // missing key is not an error
   153  	}
   154  	keys := [2]interface{}{pkey}
   155  	if c.sortKey != "" {
   156  		keys[1], _ = doc.GetField(c.sortKey) // ignore error since keys[1] is nil in that case
   157  	}
   158  	return keys, nil
   159  }
   160  
   161  func (c *collection) RevisionField() string { return c.opts.RevisionField }
   162  
   163  func (c *collection) RunActions(ctx context.Context, actions []*driver.Action, opts *driver.RunActionsOptions) driver.ActionListError {
   164  	errs := make([]error, len(actions))
   165  	beforeGets, gets, writes, afterGets := driver.GroupActions(actions)
   166  	c.runGets(ctx, beforeGets, errs, opts)
   167  	ch := make(chan struct{})
   168  	go func() { defer close(ch); c.runWrites(ctx, writes, errs, opts) }()
   169  	c.runGets(ctx, gets, errs, opts)
   170  	<-ch
   171  	c.runGets(ctx, afterGets, errs, opts)
   172  	return driver.NewActionListError(errs)
   173  }
   174  
   175  func (c *collection) runGets(ctx context.Context, actions []*driver.Action, errs []error, opts *driver.RunActionsOptions) {
   176  	const batchSize = 100
   177  	t := driver.NewThrottle(c.opts.MaxOutstandingActionRPCs)
   178  	for _, group := range driver.GroupByFieldPath(actions) {
   179  		n := len(group) / batchSize
   180  		for i := 0; i < n; i++ {
   181  			i := i
   182  			t.Acquire()
   183  			go func() {
   184  				defer t.Release()
   185  				c.batchGet(ctx, group, errs, opts, batchSize*i, batchSize*(i+1)-1)
   186  			}()
   187  		}
   188  		if n*batchSize < len(group) {
   189  			t.Acquire()
   190  			go func() {
   191  				defer t.Release()
   192  				c.batchGet(ctx, group, errs, opts, batchSize*n, len(group)-1)
   193  			}()
   194  		}
   195  	}
   196  	t.Wait()
   197  }
   198  
   199  func (c *collection) batchGet(ctx context.Context, gets []*driver.Action, errs []error, opts *driver.RunActionsOptions, start, end int) {
   200  	// errors need to be mapped to the actions' indices.
   201  	setErr := func(err error) {
   202  		for i := start; i <= end; i++ {
   203  			errs[gets[i].Index] = err
   204  		}
   205  	}
   206  
   207  	keys := make([]map[string]*dyn.AttributeValue, 0, end-start+1)
   208  	for i := start; i <= end; i++ {
   209  		av, err := encodeDocKeyFields(gets[i].Doc, c.partitionKey, c.sortKey)
   210  		if err != nil {
   211  			errs[gets[i].Index] = err
   212  		}
   213  
   214  		keys = append(keys, av.M)
   215  	}
   216  	ka := &dyn.KeysAndAttributes{
   217  		Keys:           keys,
   218  		ConsistentRead: aws.Bool(c.opts.ConsistentRead),
   219  	}
   220  	if len(gets[start].FieldPaths) != 0 {
   221  		// We need to add the key fields if the user doesn't include them. The
   222  		// BatchGet API doesn't return them otherwise.
   223  		var hasP, hasS bool
   224  		var nbs []expression.NameBuilder
   225  		for _, fp := range gets[start].FieldPaths {
   226  			p := strings.Join(fp, ".")
   227  			nbs = append(nbs, expression.Name(p))
   228  			if p == c.partitionKey {
   229  				hasP = true
   230  			} else if p == c.sortKey {
   231  				hasS = true
   232  			}
   233  		}
   234  		if !hasP {
   235  			nbs = append(nbs, expression.Name(c.partitionKey))
   236  		}
   237  		if c.sortKey != "" && !hasS {
   238  			nbs = append(nbs, expression.Name(c.sortKey))
   239  		}
   240  		expr, err := expression.NewBuilder().
   241  			WithProjection(expression.AddNames(expression.ProjectionBuilder{}, nbs...)).
   242  			Build()
   243  		if err != nil {
   244  			setErr(err)
   245  			return
   246  		}
   247  		ka.ProjectionExpression = expr.Projection()
   248  		ka.ExpressionAttributeNames = expr.Names()
   249  	}
   250  	in := &dyn.BatchGetItemInput{RequestItems: map[string]*dyn.KeysAndAttributes{c.table: ka}}
   251  	if opts.BeforeDo != nil {
   252  		if err := opts.BeforeDo(driver.AsFunc(in)); err != nil {
   253  			setErr(err)
   254  			return
   255  		}
   256  	}
   257  	out, err := c.db.BatchGetItemWithContext(ctx, in)
   258  	if err != nil {
   259  		setErr(err)
   260  		return
   261  	}
   262  	found := make([]bool, end-start+1)
   263  	am := mapActionIndices(gets, start, end)
   264  	for _, item := range out.Responses[c.table] {
   265  		if item != nil {
   266  			key := map[string]interface{}{c.partitionKey: nil}
   267  			if c.sortKey != "" {
   268  				key[c.sortKey] = nil
   269  			}
   270  			keysOnly, err := driver.NewDocument(key)
   271  			if err != nil {
   272  				panic(err)
   273  			}
   274  			err = decodeDoc(&dyn.AttributeValue{M: item}, keysOnly)
   275  			if err != nil {
   276  				continue
   277  			}
   278  			decKey, err := c.Key(keysOnly)
   279  			if err != nil {
   280  				continue
   281  			}
   282  			i := am[decKey]
   283  			errs[gets[i].Index] = decodeDoc(&dyn.AttributeValue{M: item}, gets[i].Doc)
   284  			found[i-start] = true
   285  		}
   286  	}
   287  	for delta, f := range found {
   288  		if !f {
   289  			errs[gets[start+delta].Index] = gcerr.Newf(gcerr.NotFound, nil, "item %v not found", gets[start+delta].Doc)
   290  		}
   291  	}
   292  }
   293  
   294  func mapActionIndices(actions []*driver.Action, start, end int) map[interface{}]int {
   295  	m := make(map[interface{}]int)
   296  	for i := start; i <= end; i++ {
   297  		m[actions[i].Key] = i
   298  	}
   299  	return m
   300  }
   301  
   302  // runWrites executes all the writes as separate RPCs, concurrently.
   303  func (c *collection) runWrites(ctx context.Context, writes []*driver.Action, errs []error, opts *driver.RunActionsOptions) {
   304  	var ops []*writeOp
   305  	for _, w := range writes {
   306  		op, err := c.newWriteOp(w, opts)
   307  		if err != nil {
   308  			errs[w.Index] = err
   309  		} else {
   310  			ops = append(ops, op)
   311  		}
   312  	}
   313  
   314  	t := driver.NewThrottle(c.opts.MaxOutstandingActionRPCs)
   315  	for _, op := range ops {
   316  		op := op
   317  		t.Acquire()
   318  		go func() {
   319  			defer t.Release()
   320  			err := op.run(ctx)
   321  			a := op.action
   322  			if err != nil {
   323  				errs[a.Index] = err
   324  			} else {
   325  				errs[a.Index] = c.onSuccess(op)
   326  			}
   327  		}()
   328  	}
   329  	t.Wait()
   330  }
   331  
   332  // A writeOp describes a single write to DynamoDB. The write can be executed
   333  // on its own, or included as part of a transaction.
   334  type writeOp struct {
   335  	action          *driver.Action
   336  	writeItem       *dyn.TransactWriteItem // for inclusion in a transaction
   337  	newPartitionKey string                 // for a Create on a document without a partition key
   338  	newRevision     string
   339  	run             func(context.Context) error // run as a single RPC
   340  }
   341  
   342  func (c *collection) newWriteOp(a *driver.Action, opts *driver.RunActionsOptions) (*writeOp, error) {
   343  	switch a.Kind {
   344  	case driver.Create, driver.Replace, driver.Put:
   345  		return c.newPut(a, opts)
   346  	case driver.Update:
   347  		return c.newUpdate(a, opts)
   348  	case driver.Delete:
   349  		return c.newDelete(a, opts)
   350  	default:
   351  		panic("bad write kind")
   352  	}
   353  }
   354  
   355  func (c *collection) newPut(a *driver.Action, opts *driver.RunActionsOptions) (*writeOp, error) {
   356  	av, err := encodeDoc(a.Doc)
   357  	if err != nil {
   358  		return nil, err
   359  	}
   360  	mf := c.missingKeyField(av.M)
   361  	if a.Kind != driver.Create && mf != "" {
   362  		return nil, fmt.Errorf("missing key field %q", mf)
   363  	}
   364  	var newPartitionKey string
   365  	if mf == c.partitionKey {
   366  		newPartitionKey = driver.UniqueString()
   367  		av.M[c.partitionKey] = new(dyn.AttributeValue).SetS(newPartitionKey)
   368  	}
   369  	if c.sortKey != "" && mf == c.sortKey {
   370  		// It doesn't make sense to generate a random sort key.
   371  		return nil, fmt.Errorf("missing sort key %q", c.sortKey)
   372  	}
   373  	var rev string
   374  	if a.Doc.HasField(c.opts.RevisionField) {
   375  		rev = driver.UniqueString()
   376  		if av.M[c.opts.RevisionField], err = encodeValue(rev); err != nil {
   377  			return nil, err
   378  		}
   379  	}
   380  	dput := &dyn.Put{
   381  		TableName: &c.table,
   382  		Item:      av.M,
   383  	}
   384  	cb, err := c.precondition(a)
   385  	if err != nil {
   386  		return nil, err
   387  	}
   388  	if cb != nil {
   389  		ce, err := expression.NewBuilder().WithCondition(*cb).Build()
   390  		if err != nil {
   391  			return nil, err
   392  		}
   393  		dput.ExpressionAttributeNames = ce.Names()
   394  		dput.ExpressionAttributeValues = ce.Values()
   395  		dput.ConditionExpression = ce.Condition()
   396  	}
   397  	return &writeOp{
   398  		action:          a,
   399  		writeItem:       &dyn.TransactWriteItem{Put: dput},
   400  		newPartitionKey: newPartitionKey,
   401  		newRevision:     rev,
   402  		run: func(ctx context.Context) error {
   403  			return c.runPut(ctx, dput, a, opts)
   404  		},
   405  	}, nil
   406  }
   407  
   408  func (c *collection) runPut(ctx context.Context, dput *dyn.Put, a *driver.Action, opts *driver.RunActionsOptions) error {
   409  	in := &dyn.PutItemInput{
   410  		TableName:                 dput.TableName,
   411  		Item:                      dput.Item,
   412  		ConditionExpression:       dput.ConditionExpression,
   413  		ExpressionAttributeNames:  dput.ExpressionAttributeNames,
   414  		ExpressionAttributeValues: dput.ExpressionAttributeValues,
   415  	}
   416  	if opts.BeforeDo != nil {
   417  		if err := opts.BeforeDo(driver.AsFunc(in)); err != nil {
   418  			return err
   419  		}
   420  	}
   421  	_, err := c.db.PutItemWithContext(ctx, in)
   422  	if ae, ok := err.(awserr.Error); ok && ae.Code() == dyn.ErrCodeConditionalCheckFailedException {
   423  		if a.Kind == driver.Create {
   424  			err = gcerr.Newf(gcerr.AlreadyExists, err, "document already exists")
   425  		}
   426  		if rev, _ := a.Doc.GetField(c.opts.RevisionField); rev == nil && a.Kind == driver.Replace {
   427  			err = gcerr.Newf(gcerr.NotFound, nil, "document not found")
   428  		}
   429  	}
   430  	return err
   431  }
   432  
   433  func (c *collection) newDelete(a *driver.Action, opts *driver.RunActionsOptions) (*writeOp, error) {
   434  	av, err := encodeDocKeyFields(a.Doc, c.partitionKey, c.sortKey)
   435  	if err != nil {
   436  		return nil, err
   437  	}
   438  	del := &dyn.Delete{
   439  		TableName: &c.table,
   440  		Key:       av.M,
   441  	}
   442  	cb, err := c.precondition(a)
   443  	if err != nil {
   444  		return nil, err
   445  	}
   446  	if cb != nil {
   447  		ce, err := expression.NewBuilder().WithCondition(*cb).Build()
   448  		if err != nil {
   449  			return nil, err
   450  		}
   451  		del.ExpressionAttributeNames = ce.Names()
   452  		del.ExpressionAttributeValues = ce.Values()
   453  		del.ConditionExpression = ce.Condition()
   454  	}
   455  	return &writeOp{
   456  		action:    a,
   457  		writeItem: &dyn.TransactWriteItem{Delete: del},
   458  		run: func(ctx context.Context) error {
   459  			in := &dyn.DeleteItemInput{
   460  				TableName:                 del.TableName,
   461  				Key:                       del.Key,
   462  				ConditionExpression:       del.ConditionExpression,
   463  				ExpressionAttributeNames:  del.ExpressionAttributeNames,
   464  				ExpressionAttributeValues: del.ExpressionAttributeValues,
   465  			}
   466  			if opts.BeforeDo != nil {
   467  				if err := opts.BeforeDo(driver.AsFunc(in)); err != nil {
   468  					return err
   469  				}
   470  			}
   471  			_, err := c.db.DeleteItemWithContext(ctx, in)
   472  			return err
   473  		},
   474  	}, nil
   475  }
   476  
   477  func (c *collection) newUpdate(a *driver.Action, opts *driver.RunActionsOptions) (*writeOp, error) {
   478  	av, err := encodeDocKeyFields(a.Doc, c.partitionKey, c.sortKey)
   479  	if err != nil {
   480  		return nil, err
   481  	}
   482  	var ub expression.UpdateBuilder
   483  	for _, m := range a.Mods {
   484  		// TODO(shantuo): check for invalid field paths
   485  		fp := expression.Name(strings.Join(m.FieldPath, "."))
   486  		if inc, ok := m.Value.(driver.IncOp); ok {
   487  			ub = ub.Add(fp, expression.Value(inc.Amount))
   488  		} else if m.Value == nil {
   489  			ub = ub.Remove(fp)
   490  		} else {
   491  			ub = ub.Set(fp, expression.Value(m.Value))
   492  		}
   493  	}
   494  	var rev string
   495  	if a.Doc.HasField(c.opts.RevisionField) {
   496  		rev = driver.UniqueString()
   497  		ub = ub.Set(expression.Name(c.opts.RevisionField), expression.Value(rev))
   498  	}
   499  	cb, err := c.precondition(a)
   500  	if err != nil {
   501  		return nil, err
   502  	}
   503  	ce, err := expression.NewBuilder().WithCondition(*cb).WithUpdate(ub).Build()
   504  	if err != nil {
   505  		return nil, err
   506  	}
   507  	up := &dyn.Update{
   508  		TableName:                 &c.table,
   509  		Key:                       av.M,
   510  		ConditionExpression:       ce.Condition(),
   511  		UpdateExpression:          ce.Update(),
   512  		ExpressionAttributeNames:  ce.Names(),
   513  		ExpressionAttributeValues: ce.Values(),
   514  	}
   515  	return &writeOp{
   516  		action:      a,
   517  		writeItem:   &dyn.TransactWriteItem{Update: up},
   518  		newRevision: rev,
   519  		run: func(ctx context.Context) error {
   520  			in := &dyn.UpdateItemInput{
   521  				TableName:                 up.TableName,
   522  				Key:                       up.Key,
   523  				ConditionExpression:       up.ConditionExpression,
   524  				UpdateExpression:          up.UpdateExpression,
   525  				ExpressionAttributeNames:  up.ExpressionAttributeNames,
   526  				ExpressionAttributeValues: up.ExpressionAttributeValues,
   527  			}
   528  			if opts.BeforeDo != nil {
   529  				if err := opts.BeforeDo(driver.AsFunc(in)); err != nil {
   530  					return err
   531  				}
   532  			}
   533  			_, err := c.db.UpdateItemWithContext(ctx, in)
   534  			return err
   535  		},
   536  	}, nil
   537  }
   538  
   539  // Handle the effects of successful execution.
   540  func (c *collection) onSuccess(op *writeOp) error {
   541  	// Set the new partition key (if any) and the new revision into the user's document.
   542  	if op.newPartitionKey != "" {
   543  		_ = op.action.Doc.SetField(c.partitionKey, op.newPartitionKey) // cannot fail
   544  	}
   545  	if op.newRevision != "" {
   546  		return op.action.Doc.SetField(c.opts.RevisionField, op.newRevision)
   547  	}
   548  	return nil
   549  }
   550  
   551  func (c *collection) missingKeyField(m map[string]*dyn.AttributeValue) string {
   552  	if v, ok := m[c.partitionKey]; !ok || v.NULL != nil {
   553  		return c.partitionKey
   554  	}
   555  	if v, ok := m[c.sortKey]; (!ok || v.NULL != nil) && c.sortKey != "" {
   556  		return c.sortKey
   557  	}
   558  	return ""
   559  }
   560  
   561  // Construct the precondition for the action.
   562  func (c *collection) precondition(a *driver.Action) (*expression.ConditionBuilder, error) {
   563  	switch a.Kind {
   564  	case driver.Create:
   565  		// Precondition: the document doesn't already exist. (Precisely: the partitionKey
   566  		// field is not on the document.)
   567  		c := expression.AttributeNotExists(expression.Name(c.partitionKey))
   568  		return &c, nil
   569  	case driver.Replace, driver.Update:
   570  		// Precondition: the revision matches, or if there is no revision, then
   571  		// the document exists.
   572  		cb, err := revisionPrecondition(a.Doc, c.opts.RevisionField)
   573  		if err != nil {
   574  			return nil, err
   575  		}
   576  		if cb == nil {
   577  			c := expression.AttributeExists(expression.Name(c.partitionKey))
   578  			cb = &c
   579  		}
   580  		return cb, nil
   581  	case driver.Put, driver.Delete:
   582  		// Precondition: the revision matches, if any.
   583  		return revisionPrecondition(a.Doc, c.opts.RevisionField)
   584  	case driver.Get:
   585  		// No preconditions on a Get.
   586  		return nil, nil
   587  	default:
   588  		panic("bad action kind")
   589  	}
   590  }
   591  
   592  // revisionPrecondition returns a DynamoDB expression that asserts that the
   593  // stored document's revision matches the revision of doc.
   594  func revisionPrecondition(doc driver.Document, revField string) (*expression.ConditionBuilder, error) {
   595  	v, err := doc.GetField(revField)
   596  	if err != nil { // field not present
   597  		return nil, nil
   598  	}
   599  	if v == nil { // field is present, but nil
   600  		return nil, nil
   601  	}
   602  	rev, ok := v.(string)
   603  	if !ok {
   604  		return nil, gcerr.Newf(gcerr.InvalidArgument, nil,
   605  			"%s field contains wrong type: got %T, want string",
   606  			revField, v)
   607  	}
   608  	if rev == "" {
   609  		return nil, nil
   610  	}
   611  	// Value encodes rev to an attribute value.
   612  	cb := expression.Name(revField).Equal(expression.Value(rev))
   613  	return &cb, nil
   614  }
   615  
   616  // TODO(jba): use this if/when we support atomic writes.
   617  func (c *collection) transactWrite(ctx context.Context, actions []*driver.Action, errs []error, opts *driver.RunActionsOptions, start, end int) {
   618  	setErr := func(err error) {
   619  		for i := start; i <= end; i++ {
   620  			errs[actions[i].Index] = err
   621  		}
   622  	}
   623  
   624  	var ops []*writeOp
   625  	tws := make([]*dyn.TransactWriteItem, 0, end-start+1)
   626  	for i := start; i <= end; i++ {
   627  		a := actions[i]
   628  		op, err := c.newWriteOp(a, opts)
   629  		if err != nil {
   630  			setErr(err)
   631  			return
   632  		}
   633  		ops = append(ops, op)
   634  		tws = append(tws, op.writeItem)
   635  	}
   636  
   637  	in := &dyn.TransactWriteItemsInput{
   638  		ClientRequestToken: aws.String(driver.UniqueString()),
   639  		TransactItems:      tws,
   640  	}
   641  
   642  	if opts.BeforeDo != nil {
   643  		asFunc := func(i interface{}) bool {
   644  			p, ok := i.(**dyn.TransactWriteItemsInput)
   645  			if !ok {
   646  				return false
   647  			}
   648  			*p = in
   649  			return true
   650  		}
   651  		if err := opts.BeforeDo(asFunc); err != nil {
   652  			setErr(err)
   653  			return
   654  		}
   655  	}
   656  	if _, err := c.db.TransactWriteItemsWithContext(ctx, in); err != nil {
   657  		setErr(err)
   658  		return
   659  	}
   660  	for _, op := range ops {
   661  		errs[op.action.Index] = c.onSuccess(op)
   662  	}
   663  }
   664  
   665  // RevisionToBytes implements driver.RevisionToBytes.
   666  func (c *collection) RevisionToBytes(rev interface{}) ([]byte, error) {
   667  	s, ok := rev.(string)
   668  	if !ok {
   669  		return nil, gcerr.Newf(gcerr.InvalidArgument, nil, "revision %v of type %[1]T is not a string", rev)
   670  	}
   671  	return []byte(s), nil
   672  }
   673  
   674  // BytesToRevision implements driver.BytesToRevision.
   675  func (c *collection) BytesToRevision(b []byte) (interface{}, error) {
   676  	return string(b), nil
   677  }
   678  
   679  func (c *collection) As(i interface{}) bool {
   680  	p, ok := i.(**dyn.DynamoDB)
   681  	if !ok {
   682  		return false
   683  	}
   684  	*p = c.db
   685  	return true
   686  }
   687  
   688  // ErrorAs implements driver.Collection.ErrorAs.
   689  func (c *collection) ErrorAs(err error, i interface{}) bool {
   690  	e, ok := err.(awserr.Error)
   691  	if !ok {
   692  		return false
   693  	}
   694  	p, ok := i.(*awserr.Error)
   695  	if !ok {
   696  		return false
   697  	}
   698  	*p = e
   699  	return true
   700  }
   701  
   702  func (c *collection) ErrorCode(err error) gcerrors.ErrorCode {
   703  	ae, ok := err.(awserr.Error)
   704  	if !ok {
   705  		return gcerrors.Unknown
   706  	}
   707  	ec, ok := errorCodeMap[ae.Code()]
   708  	if !ok {
   709  		return gcerrors.Unknown
   710  	}
   711  	return ec
   712  }
   713  
   714  var errorCodeMap = map[string]gcerrors.ErrorCode{
   715  	dyn.ErrCodeConditionalCheckFailedException:          gcerrors.FailedPrecondition,
   716  	dyn.ErrCodeProvisionedThroughputExceededException:   gcerrors.ResourceExhausted,
   717  	dyn.ErrCodeResourceNotFoundException:                gcerrors.NotFound,
   718  	dyn.ErrCodeItemCollectionSizeLimitExceededException: gcerrors.ResourceExhausted,
   719  	dyn.ErrCodeTransactionConflictException:             gcerrors.Internal,
   720  	dyn.ErrCodeRequestLimitExceeded:                     gcerrors.ResourceExhausted,
   721  	dyn.ErrCodeInternalServerError:                      gcerrors.Internal,
   722  	dyn.ErrCodeTransactionCanceledException:             gcerrors.FailedPrecondition,
   723  	dyn.ErrCodeTransactionInProgressException:           gcerrors.InvalidArgument,
   724  	dyn.ErrCodeIdempotentParameterMismatchException:     gcerrors.InvalidArgument,
   725  	"ValidationException":                               gcerrors.InvalidArgument,
   726  }
   727  
   728  // Close implements driver.Collection.Close.
   729  func (c *collection) Close() error { return nil }