github.com/thiagoyeds/go-cloud@v0.26.0/docstore/awsdynamodb/dynamo.go (about)

     1  // Copyright 2019 The Go Cloud Development Kit Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package awsdynamodb provides a docstore implementation backed by Amazon
    16  // DynamoDB.
    17  // Use OpenCollection to construct a *docstore.Collection.
    18  //
    19  // URLs
    20  //
    21  // For docstore.OpenCollection, awsdynamodb registers for the scheme
    22  // "dynamodb". The default URL opener will use an AWS session with the default
    23  // credentials and configuration; see
    24  // https://docs.aws.amazon.com/sdk-for-go/api/aws/session/ for more details.
    25  // To customize the URL opener, or for more details on the URL format, see
    26  // URLOpener.
    27  // See https://gocloud.dev/concepts/urls/ for background information.
    28  //
    29  // As
    30  //
    31  // awsdynamodb exposes the following types for As:
    32  //  - Collection.As: *dynamodb.DynamoDB
    33  //  - ActionList.BeforeDo: *dynamodb.BatchGetItemInput or *dynamodb.PutItemInput or *dynamodb.DeleteItemInput
    34  //                         or *dynamodb.UpdateItemInput
    35  //  - Query.BeforeQuery: *dynamodb.QueryInput or *dynamodb.ScanInput
    36  //  - DocumentIterator: *dynamodb.QueryOutput or *dynamodb.ScanOutput
    37  //  - ErrorAs: awserr.Error
    38  package awsdynamodb
    39  
import (
	"context"
	"fmt"
	"reflect"
	"strings"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/awserr"
	dyn "github.com/aws/aws-sdk-go/service/dynamodb"
	"github.com/aws/aws-sdk-go/service/dynamodb/expression"
	"github.com/google/wire"
	"gocloud.dev/docstore"
	"gocloud.dev/docstore/driver"
	"gocloud.dev/gcerrors"
	"gocloud.dev/internal/gcerr"
)
    56  
    57  // Set holds Wire providers for this package.
    58  var Set = wire.NewSet(
    59  	wire.Struct(new(URLOpener), "ConfigProvider"),
    60  )
    61  
    62  type collection struct {
    63  	db           *dyn.DynamoDB
    64  	table        string // DynamoDB table name
    65  	partitionKey string
    66  	sortKey      string
    67  	description  *dyn.TableDescription
    68  	opts         *Options
    69  }
    70  
    71  // FallbackFunc is a function for executing queries that cannot be run by the built-in
    72  // awsdynamodb logic. See Options.RunQueryFunc for details.
    73  type FallbackFunc func(context.Context, *driver.Query, RunQueryFunc) (driver.DocumentIterator, error)
    74  
    75  type Options struct {
    76  	// If false, queries that can only be executed by scanning the entire table
    77  	// return an error instead (with the exception of a query with no filters).
    78  	AllowScans bool
    79  
    80  	// The name of the field holding the document revision.
    81  	// Defaults to docstore.DefaultRevisionField.
    82  	RevisionField string
    83  
    84  	// If set, call this function on queries that we cannot execute at all (for
    85  	// example, a query with an OrderBy clause that lacks an equality filter on a
    86  	// partition key). The function should execute the query however it wishes, and
    87  	// return an iterator over the results. It can use the RunQueryFunc passed as its
    88  	// third argument to have the DynamoDB driver run a query, for instance a
    89  	// modified version of the original query.
    90  	//
    91  	// If RunQueryFallback is nil, queries that cannot be executed will fail with a
    92  	// error that has code Unimplemented.
    93  	RunQueryFallback FallbackFunc
    94  
    95  	// The maximum number of concurrent goroutines started for a single call to
    96  	// ActionList.Do. If less than 1, there is no limit.
    97  	MaxOutstandingActionRPCs int
    98  
    99  	// If true, a strongly consistent read is used whenever possible, including
   100  	// get, query, scan, etc.; default to false, where an eventually consistent
   101  	// read is used.
   102  	//
   103  	// Not all read operations support this mode however, such as querying against
   104  	// a global secondary index, the operation will return an InvalidArgument error
   105  	// in such case, please check the official DynamoDB documentation for more
   106  	// details.
   107  	//
   108  	// The native client for DynamoDB uses this option in a per-action basis, if
   109  	// you need the flexibility to run both modes on the same collection, create
   110  	// two collections with different mode.
   111  	ConsistentRead bool
   112  }
   113  
   114  // RunQueryFunc is the type of the function passed to RunQueryFallback.
   115  type RunQueryFunc func(context.Context, *driver.Query) (driver.DocumentIterator, error)
   116  
   117  // OpenCollection creates a *docstore.Collection representing a DynamoDB collection.
   118  func OpenCollection(db *dyn.DynamoDB, tableName, partitionKey, sortKey string, opts *Options) (*docstore.Collection, error) {
   119  	c, err := newCollection(db, tableName, partitionKey, sortKey, opts)
   120  	if err != nil {
   121  		return nil, err
   122  	}
   123  	return docstore.NewCollection(c), nil
   124  }
   125  
   126  func newCollection(db *dyn.DynamoDB, tableName, partitionKey, sortKey string, opts *Options) (*collection, error) {
   127  	out, err := db.DescribeTable(&dyn.DescribeTableInput{TableName: &tableName})
   128  	if err != nil {
   129  		return nil, err
   130  	}
   131  	if opts == nil {
   132  		opts = &Options{}
   133  	}
   134  	if opts.RevisionField == "" {
   135  		opts.RevisionField = docstore.DefaultRevisionField
   136  	}
   137  	return &collection{
   138  		db:           db,
   139  		table:        tableName,
   140  		partitionKey: partitionKey,
   141  		sortKey:      sortKey,
   142  		description:  out.Table,
   143  		opts:         opts,
   144  	}, nil
   145  }
   146  
   147  // Key returns a two-element array with the partition key and sort key, if any.
   148  func (c *collection) Key(doc driver.Document) (interface{}, error) {
   149  	pkey, err := doc.GetField(c.partitionKey)
   150  	if err != nil || pkey == nil || driver.IsEmptyValue(reflect.ValueOf(pkey)) {
   151  		return nil, nil // missing key is not an error
   152  	}
   153  	keys := [2]interface{}{pkey}
   154  	if c.sortKey != "" {
   155  		keys[1], _ = doc.GetField(c.sortKey) // ignore error since keys[1] is nil in that case
   156  	}
   157  	return keys, nil
   158  }
   159  
   160  func (c *collection) RevisionField() string { return c.opts.RevisionField }
   161  
   162  func (c *collection) RunActions(ctx context.Context, actions []*driver.Action, opts *driver.RunActionsOptions) driver.ActionListError {
   163  	errs := make([]error, len(actions))
   164  	beforeGets, gets, writes, afterGets := driver.GroupActions(actions)
   165  	c.runGets(ctx, beforeGets, errs, opts)
   166  	ch := make(chan struct{})
   167  	go func() { defer close(ch); c.runWrites(ctx, writes, errs, opts) }()
   168  	c.runGets(ctx, gets, errs, opts)
   169  	<-ch
   170  	c.runGets(ctx, afterGets, errs, opts)
   171  	return driver.NewActionListError(errs)
   172  }
   173  
   174  func (c *collection) runGets(ctx context.Context, actions []*driver.Action, errs []error, opts *driver.RunActionsOptions) {
   175  	const batchSize = 100
   176  	t := driver.NewThrottle(c.opts.MaxOutstandingActionRPCs)
   177  	for _, group := range driver.GroupByFieldPath(actions) {
   178  		n := len(group) / batchSize
   179  		for i := 0; i < n; i++ {
   180  			i := i
   181  			t.Acquire()
   182  			go func() {
   183  				defer t.Release()
   184  				c.batchGet(ctx, group, errs, opts, batchSize*i, batchSize*(i+1)-1)
   185  			}()
   186  		}
   187  		if n*batchSize < len(group) {
   188  			t.Acquire()
   189  			go func() {
   190  				defer t.Release()
   191  				c.batchGet(ctx, group, errs, opts, batchSize*n, len(group)-1)
   192  			}()
   193  		}
   194  	}
   195  	t.Wait()
   196  }
   197  
   198  func (c *collection) batchGet(ctx context.Context, gets []*driver.Action, errs []error, opts *driver.RunActionsOptions, start, end int) {
   199  	// errors need to be mapped to the actions' indices.
   200  	setErr := func(err error) {
   201  		for i := start; i <= end; i++ {
   202  			errs[gets[i].Index] = err
   203  		}
   204  	}
   205  
   206  	keys := make([]map[string]*dyn.AttributeValue, 0, end-start+1)
   207  	for i := start; i <= end; i++ {
   208  		av, err := encodeDocKeyFields(gets[i].Doc, c.partitionKey, c.sortKey)
   209  		if err != nil {
   210  			errs[gets[i].Index] = err
   211  		}
   212  
   213  		keys = append(keys, av.M)
   214  	}
   215  	ka := &dyn.KeysAndAttributes{
   216  		Keys:           keys,
   217  		ConsistentRead: aws.Bool(c.opts.ConsistentRead),
   218  	}
   219  	if len(gets[start].FieldPaths) != 0 {
   220  		// We need to add the key fields if the user doesn't include them. The
   221  		// BatchGet API doesn't return them otherwise.
   222  		var hasP, hasS bool
   223  		var nbs []expression.NameBuilder
   224  		for _, fp := range gets[start].FieldPaths {
   225  			p := strings.Join(fp, ".")
   226  			nbs = append(nbs, expression.Name(p))
   227  			if p == c.partitionKey {
   228  				hasP = true
   229  			} else if p == c.sortKey {
   230  				hasS = true
   231  			}
   232  		}
   233  		if !hasP {
   234  			nbs = append(nbs, expression.Name(c.partitionKey))
   235  		}
   236  		if c.sortKey != "" && !hasS {
   237  			nbs = append(nbs, expression.Name(c.sortKey))
   238  		}
   239  		expr, err := expression.NewBuilder().
   240  			WithProjection(expression.AddNames(expression.ProjectionBuilder{}, nbs...)).
   241  			Build()
   242  		if err != nil {
   243  			setErr(err)
   244  			return
   245  		}
   246  		ka.ProjectionExpression = expr.Projection()
   247  		ka.ExpressionAttributeNames = expr.Names()
   248  	}
   249  	in := &dyn.BatchGetItemInput{RequestItems: map[string]*dyn.KeysAndAttributes{c.table: ka}}
   250  	if opts.BeforeDo != nil {
   251  		if err := opts.BeforeDo(driver.AsFunc(in)); err != nil {
   252  			setErr(err)
   253  			return
   254  		}
   255  	}
   256  	out, err := c.db.BatchGetItemWithContext(ctx, in)
   257  	if err != nil {
   258  		setErr(err)
   259  		return
   260  	}
   261  	found := make([]bool, end-start+1)
   262  	am := mapActionIndices(gets, start, end)
   263  	for _, item := range out.Responses[c.table] {
   264  		if item != nil {
   265  			key := map[string]interface{}{c.partitionKey: nil}
   266  			if c.sortKey != "" {
   267  				key[c.sortKey] = nil
   268  			}
   269  			keysOnly, err := driver.NewDocument(key)
   270  			if err != nil {
   271  				panic(err)
   272  			}
   273  			err = decodeDoc(&dyn.AttributeValue{M: item}, keysOnly)
   274  			if err != nil {
   275  				continue
   276  			}
   277  			decKey, err := c.Key(keysOnly)
   278  			if err != nil {
   279  				continue
   280  			}
   281  			i := am[decKey]
   282  			errs[gets[i].Index] = decodeDoc(&dyn.AttributeValue{M: item}, gets[i].Doc)
   283  			found[i-start] = true
   284  		}
   285  	}
   286  	for delta, f := range found {
   287  		if !f {
   288  			errs[gets[start+delta].Index] = gcerr.Newf(gcerr.NotFound, nil, "item %v not found", gets[start+delta].Doc)
   289  		}
   290  	}
   291  }
   292  
   293  func mapActionIndices(actions []*driver.Action, start, end int) map[interface{}]int {
   294  	m := make(map[interface{}]int)
   295  	for i := start; i <= end; i++ {
   296  		m[actions[i].Key] = i
   297  	}
   298  	return m
   299  }
   300  
   301  // runWrites executes all the writes as separate RPCs, concurrently.
   302  func (c *collection) runWrites(ctx context.Context, writes []*driver.Action, errs []error, opts *driver.RunActionsOptions) {
   303  	var ops []*writeOp
   304  	for _, w := range writes {
   305  		op, err := c.newWriteOp(w, opts)
   306  		if err != nil {
   307  			errs[w.Index] = err
   308  		} else {
   309  			ops = append(ops, op)
   310  		}
   311  	}
   312  
   313  	t := driver.NewThrottle(c.opts.MaxOutstandingActionRPCs)
   314  	for _, op := range ops {
   315  		op := op
   316  		t.Acquire()
   317  		go func() {
   318  			defer t.Release()
   319  			err := op.run(ctx)
   320  			a := op.action
   321  			if err != nil {
   322  				errs[a.Index] = err
   323  			} else {
   324  				errs[a.Index] = c.onSuccess(op)
   325  			}
   326  		}()
   327  	}
   328  	t.Wait()
   329  }
   330  
   331  // A writeOp describes a single write to DynamoDB. The write can be executed
   332  // on its own, or included as part of a transaction.
   333  type writeOp struct {
   334  	action          *driver.Action
   335  	writeItem       *dyn.TransactWriteItem // for inclusion in a transaction
   336  	newPartitionKey string                 // for a Create on a document without a partition key
   337  	newRevision     string
   338  	run             func(context.Context) error // run as a single RPC
   339  }
   340  
   341  func (c *collection) newWriteOp(a *driver.Action, opts *driver.RunActionsOptions) (*writeOp, error) {
   342  	switch a.Kind {
   343  	case driver.Create, driver.Replace, driver.Put:
   344  		return c.newPut(a, opts)
   345  	case driver.Update:
   346  		return c.newUpdate(a, opts)
   347  	case driver.Delete:
   348  		return c.newDelete(a, opts)
   349  	default:
   350  		panic("bad write kind")
   351  	}
   352  }
   353  
   354  func (c *collection) newPut(a *driver.Action, opts *driver.RunActionsOptions) (*writeOp, error) {
   355  	av, err := encodeDoc(a.Doc)
   356  	if err != nil {
   357  		return nil, err
   358  	}
   359  	mf := c.missingKeyField(av.M)
   360  	if a.Kind != driver.Create && mf != "" {
   361  		return nil, fmt.Errorf("missing key field %q", mf)
   362  	}
   363  	var newPartitionKey string
   364  	if mf == c.partitionKey {
   365  		newPartitionKey = driver.UniqueString()
   366  		av.M[c.partitionKey] = new(dyn.AttributeValue).SetS(newPartitionKey)
   367  	}
   368  	if c.sortKey != "" && mf == c.sortKey {
   369  		// It doesn't make sense to generate a random sort key.
   370  		return nil, fmt.Errorf("missing sort key %q", c.sortKey)
   371  	}
   372  	var rev string
   373  	if a.Doc.HasField(c.opts.RevisionField) {
   374  		rev = driver.UniqueString()
   375  		if av.M[c.opts.RevisionField], err = encodeValue(rev); err != nil {
   376  			return nil, err
   377  		}
   378  	}
   379  	dput := &dyn.Put{
   380  		TableName: &c.table,
   381  		Item:      av.M,
   382  	}
   383  	cb, err := c.precondition(a)
   384  	if err != nil {
   385  		return nil, err
   386  	}
   387  	if cb != nil {
   388  		ce, err := expression.NewBuilder().WithCondition(*cb).Build()
   389  		if err != nil {
   390  			return nil, err
   391  		}
   392  		dput.ExpressionAttributeNames = ce.Names()
   393  		dput.ExpressionAttributeValues = ce.Values()
   394  		dput.ConditionExpression = ce.Condition()
   395  	}
   396  	return &writeOp{
   397  		action:          a,
   398  		writeItem:       &dyn.TransactWriteItem{Put: dput},
   399  		newPartitionKey: newPartitionKey,
   400  		newRevision:     rev,
   401  		run: func(ctx context.Context) error {
   402  			return c.runPut(ctx, dput, a, opts)
   403  		},
   404  	}, nil
   405  }
   406  
   407  func (c *collection) runPut(ctx context.Context, dput *dyn.Put, a *driver.Action, opts *driver.RunActionsOptions) error {
   408  	in := &dyn.PutItemInput{
   409  		TableName:                 dput.TableName,
   410  		Item:                      dput.Item,
   411  		ConditionExpression:       dput.ConditionExpression,
   412  		ExpressionAttributeNames:  dput.ExpressionAttributeNames,
   413  		ExpressionAttributeValues: dput.ExpressionAttributeValues,
   414  	}
   415  	if opts.BeforeDo != nil {
   416  		if err := opts.BeforeDo(driver.AsFunc(in)); err != nil {
   417  			return err
   418  		}
   419  	}
   420  	_, err := c.db.PutItemWithContext(ctx, in)
   421  	if ae, ok := err.(awserr.Error); ok && ae.Code() == dyn.ErrCodeConditionalCheckFailedException {
   422  		if a.Kind == driver.Create {
   423  			err = gcerr.Newf(gcerr.AlreadyExists, err, "document already exists")
   424  		}
   425  		if rev, _ := a.Doc.GetField(c.opts.RevisionField); rev == nil && a.Kind == driver.Replace {
   426  			err = gcerr.Newf(gcerr.NotFound, nil, "document not found")
   427  		}
   428  	}
   429  	return err
   430  }
   431  
   432  func (c *collection) newDelete(a *driver.Action, opts *driver.RunActionsOptions) (*writeOp, error) {
   433  	av, err := encodeDocKeyFields(a.Doc, c.partitionKey, c.sortKey)
   434  	if err != nil {
   435  		return nil, err
   436  	}
   437  	del := &dyn.Delete{
   438  		TableName: &c.table,
   439  		Key:       av.M,
   440  	}
   441  	cb, err := c.precondition(a)
   442  	if err != nil {
   443  		return nil, err
   444  	}
   445  	if cb != nil {
   446  		ce, err := expression.NewBuilder().WithCondition(*cb).Build()
   447  		if err != nil {
   448  			return nil, err
   449  		}
   450  		del.ExpressionAttributeNames = ce.Names()
   451  		del.ExpressionAttributeValues = ce.Values()
   452  		del.ConditionExpression = ce.Condition()
   453  	}
   454  	return &writeOp{
   455  		action:    a,
   456  		writeItem: &dyn.TransactWriteItem{Delete: del},
   457  		run: func(ctx context.Context) error {
   458  			in := &dyn.DeleteItemInput{
   459  				TableName:                 del.TableName,
   460  				Key:                       del.Key,
   461  				ConditionExpression:       del.ConditionExpression,
   462  				ExpressionAttributeNames:  del.ExpressionAttributeNames,
   463  				ExpressionAttributeValues: del.ExpressionAttributeValues,
   464  			}
   465  			if opts.BeforeDo != nil {
   466  				if err := opts.BeforeDo(driver.AsFunc(in)); err != nil {
   467  					return err
   468  				}
   469  			}
   470  			_, err := c.db.DeleteItemWithContext(ctx, in)
   471  			return err
   472  		},
   473  	}, nil
   474  }
   475  
   476  func (c *collection) newUpdate(a *driver.Action, opts *driver.RunActionsOptions) (*writeOp, error) {
   477  	av, err := encodeDocKeyFields(a.Doc, c.partitionKey, c.sortKey)
   478  	if err != nil {
   479  		return nil, err
   480  	}
   481  	var ub expression.UpdateBuilder
   482  	for _, m := range a.Mods {
   483  		// TODO(shantuo): check for invalid field paths
   484  		fp := expression.Name(strings.Join(m.FieldPath, "."))
   485  		if inc, ok := m.Value.(driver.IncOp); ok {
   486  			ub = ub.Add(fp, expression.Value(inc.Amount))
   487  		} else if m.Value == nil {
   488  			ub = ub.Remove(fp)
   489  		} else {
   490  			ub = ub.Set(fp, expression.Value(m.Value))
   491  		}
   492  	}
   493  	var rev string
   494  	if a.Doc.HasField(c.opts.RevisionField) {
   495  		rev = driver.UniqueString()
   496  		ub = ub.Set(expression.Name(c.opts.RevisionField), expression.Value(rev))
   497  	}
   498  	cb, err := c.precondition(a)
   499  	if err != nil {
   500  		return nil, err
   501  	}
   502  	ce, err := expression.NewBuilder().WithCondition(*cb).WithUpdate(ub).Build()
   503  	if err != nil {
   504  		return nil, err
   505  	}
   506  	up := &dyn.Update{
   507  		TableName:                 &c.table,
   508  		Key:                       av.M,
   509  		ConditionExpression:       ce.Condition(),
   510  		UpdateExpression:          ce.Update(),
   511  		ExpressionAttributeNames:  ce.Names(),
   512  		ExpressionAttributeValues: ce.Values(),
   513  	}
   514  	return &writeOp{
   515  		action:      a,
   516  		writeItem:   &dyn.TransactWriteItem{Update: up},
   517  		newRevision: rev,
   518  		run: func(ctx context.Context) error {
   519  			in := &dyn.UpdateItemInput{
   520  				TableName:                 up.TableName,
   521  				Key:                       up.Key,
   522  				ConditionExpression:       up.ConditionExpression,
   523  				UpdateExpression:          up.UpdateExpression,
   524  				ExpressionAttributeNames:  up.ExpressionAttributeNames,
   525  				ExpressionAttributeValues: up.ExpressionAttributeValues,
   526  			}
   527  			if opts.BeforeDo != nil {
   528  				if err := opts.BeforeDo(driver.AsFunc(in)); err != nil {
   529  					return err
   530  				}
   531  			}
   532  			_, err := c.db.UpdateItemWithContext(ctx, in)
   533  			return err
   534  		},
   535  	}, nil
   536  }
   537  
   538  // Handle the effects of successful execution.
   539  func (c *collection) onSuccess(op *writeOp) error {
   540  	// Set the new partition key (if any) and the new revision into the user's document.
   541  	if op.newPartitionKey != "" {
   542  		_ = op.action.Doc.SetField(c.partitionKey, op.newPartitionKey) // cannot fail
   543  	}
   544  	if op.newRevision != "" {
   545  		return op.action.Doc.SetField(c.opts.RevisionField, op.newRevision)
   546  	}
   547  	return nil
   548  }
   549  
   550  func (c *collection) missingKeyField(m map[string]*dyn.AttributeValue) string {
   551  	if v, ok := m[c.partitionKey]; !ok || v.NULL != nil {
   552  		return c.partitionKey
   553  	}
   554  	if v, ok := m[c.sortKey]; (!ok || v.NULL != nil) && c.sortKey != "" {
   555  		return c.sortKey
   556  	}
   557  	return ""
   558  }
   559  
   560  // Construct the precondition for the action.
   561  func (c *collection) precondition(a *driver.Action) (*expression.ConditionBuilder, error) {
   562  	switch a.Kind {
   563  	case driver.Create:
   564  		// Precondition: the document doesn't already exist. (Precisely: the partitionKey
   565  		// field is not on the document.)
   566  		c := expression.AttributeNotExists(expression.Name(c.partitionKey))
   567  		return &c, nil
   568  	case driver.Replace, driver.Update:
   569  		// Precondition: the revision matches, or if there is no revision, then
   570  		// the document exists.
   571  		cb, err := revisionPrecondition(a.Doc, c.opts.RevisionField)
   572  		if err != nil {
   573  			return nil, err
   574  		}
   575  		if cb == nil {
   576  			c := expression.AttributeExists(expression.Name(c.partitionKey))
   577  			cb = &c
   578  		}
   579  		return cb, nil
   580  	case driver.Put, driver.Delete:
   581  		// Precondition: the revision matches, if any.
   582  		return revisionPrecondition(a.Doc, c.opts.RevisionField)
   583  	case driver.Get:
   584  		// No preconditions on a Get.
   585  		return nil, nil
   586  	default:
   587  		panic("bad action kind")
   588  	}
   589  }
   590  
   591  // revisionPrecondition returns a DynamoDB expression that asserts that the
   592  // stored document's revision matches the revision of doc.
   593  func revisionPrecondition(doc driver.Document, revField string) (*expression.ConditionBuilder, error) {
   594  	v, err := doc.GetField(revField)
   595  	if err != nil { // field not present
   596  		return nil, nil
   597  	}
   598  	if v == nil { // field is present, but nil
   599  		return nil, nil
   600  	}
   601  	rev, ok := v.(string)
   602  	if !ok {
   603  		return nil, gcerr.Newf(gcerr.InvalidArgument, nil,
   604  			"%s field contains wrong type: got %T, want string",
   605  			revField, v)
   606  	}
   607  	if rev == "" {
   608  		return nil, nil
   609  	}
   610  	// Value encodes rev to an attribute value.
   611  	cb := expression.Name(revField).Equal(expression.Value(rev))
   612  	return &cb, nil
   613  }
   614  
   615  // TODO(jba): use this if/when we support atomic writes.
   616  func (c *collection) transactWrite(ctx context.Context, actions []*driver.Action, errs []error, opts *driver.RunActionsOptions, start, end int) {
   617  	setErr := func(err error) {
   618  		for i := start; i <= end; i++ {
   619  			errs[actions[i].Index] = err
   620  		}
   621  	}
   622  
   623  	var ops []*writeOp
   624  	tws := make([]*dyn.TransactWriteItem, 0, end-start+1)
   625  	for i := start; i <= end; i++ {
   626  		a := actions[i]
   627  		op, err := c.newWriteOp(a, opts)
   628  		if err != nil {
   629  			setErr(err)
   630  			return
   631  		}
   632  		ops = append(ops, op)
   633  		tws = append(tws, op.writeItem)
   634  	}
   635  
   636  	in := &dyn.TransactWriteItemsInput{
   637  		ClientRequestToken: aws.String(driver.UniqueString()),
   638  		TransactItems:      tws,
   639  	}
   640  
   641  	if opts.BeforeDo != nil {
   642  		asFunc := func(i interface{}) bool {
   643  			p, ok := i.(**dyn.TransactWriteItemsInput)
   644  			if !ok {
   645  				return false
   646  			}
   647  			*p = in
   648  			return true
   649  		}
   650  		if err := opts.BeforeDo(asFunc); err != nil {
   651  			setErr(err)
   652  			return
   653  		}
   654  	}
   655  	if _, err := c.db.TransactWriteItemsWithContext(ctx, in); err != nil {
   656  		setErr(err)
   657  		return
   658  	}
   659  	for _, op := range ops {
   660  		errs[op.action.Index] = c.onSuccess(op)
   661  	}
   662  }
   663  
   664  // RevisionToBytes implements driver.RevisionToBytes.
   665  func (c *collection) RevisionToBytes(rev interface{}) ([]byte, error) {
   666  	s, ok := rev.(string)
   667  	if !ok {
   668  		return nil, gcerr.Newf(gcerr.InvalidArgument, nil, "revision %v of type %[1]T is not a string", rev)
   669  	}
   670  	return []byte(s), nil
   671  }
   672  
   673  // BytesToRevision implements driver.BytesToRevision.
   674  func (c *collection) BytesToRevision(b []byte) (interface{}, error) {
   675  	return string(b), nil
   676  }
   677  
   678  func (c *collection) As(i interface{}) bool {
   679  	p, ok := i.(**dyn.DynamoDB)
   680  	if !ok {
   681  		return false
   682  	}
   683  	*p = c.db
   684  	return true
   685  }
   686  
   687  // ErrorAs implements driver.Collection.ErrorAs.
   688  func (c *collection) ErrorAs(err error, i interface{}) bool {
   689  	e, ok := err.(awserr.Error)
   690  	if !ok {
   691  		return false
   692  	}
   693  	p, ok := i.(*awserr.Error)
   694  	if !ok {
   695  		return false
   696  	}
   697  	*p = e
   698  	return true
   699  }
   700  
   701  func (c *collection) ErrorCode(err error) gcerrors.ErrorCode {
   702  	ae, ok := err.(awserr.Error)
   703  	if !ok {
   704  		return gcerrors.Unknown
   705  	}
   706  	ec, ok := errorCodeMap[ae.Code()]
   707  	if !ok {
   708  		return gcerrors.Unknown
   709  	}
   710  	return ec
   711  }
   712  
   713  var errorCodeMap = map[string]gcerrors.ErrorCode{
   714  	dyn.ErrCodeConditionalCheckFailedException:          gcerrors.FailedPrecondition,
   715  	dyn.ErrCodeProvisionedThroughputExceededException:   gcerrors.ResourceExhausted,
   716  	dyn.ErrCodeResourceNotFoundException:                gcerrors.NotFound,
   717  	dyn.ErrCodeItemCollectionSizeLimitExceededException: gcerrors.ResourceExhausted,
   718  	dyn.ErrCodeTransactionConflictException:             gcerrors.Internal,
   719  	dyn.ErrCodeRequestLimitExceeded:                     gcerrors.ResourceExhausted,
   720  	dyn.ErrCodeInternalServerError:                      gcerrors.Internal,
   721  	dyn.ErrCodeTransactionCanceledException:             gcerrors.FailedPrecondition,
   722  	dyn.ErrCodeTransactionInProgressException:           gcerrors.InvalidArgument,
   723  	dyn.ErrCodeIdempotentParameterMismatchException:     gcerrors.InvalidArgument,
   724  	"ValidationException":                               gcerrors.InvalidArgument,
   725  }
   726  
   727  // Close implements driver.Collection.Close.
   728  func (c *collection) Close() error { return nil }