go.mercari.io/datastore@v1.8.2/aedatastore/transaction.go (about)

     1  package aedatastore
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"sync"
     7  
     8  	w "go.mercari.io/datastore"
     9  	"go.mercari.io/datastore/internal/shared"
    10  	netcontext "golang.org/x/net/context"
    11  	"google.golang.org/appengine/datastore"
    12  )
    13  
    14  var _ w.Transaction = (*transactionImpl)(nil)
    15  var _ w.Commit = (*commitImpl)(nil)
    16  
    17  type contextTransaction struct{}
    18  
    19  type txExtractor struct {
    20  	sync.Mutex
    21  	txCtx   context.Context
    22  	finishC chan txResult
    23  	resultC chan error
    24  }
    25  
    26  type txResult struct {
    27  	commit   bool
    28  	rollback bool
    29  }
    30  
    31  // TransactionContext returns context that is under the AppEngine Datastore's transaction.
    32  func TransactionContext(tx w.Transaction) context.Context {
    33  	txImpl := tx.(*transactionImpl)
    34  	return txImpl.client.ctx
    35  }
    36  
// newTxExtractor starts datastore.RunInTransaction on a new goroutine and
// parks that goroutine inside the transaction callback, so that the
// transaction can later be finished from outside via ext.commit or
// ext.rollback.
//
// On success it returns an extractor whose txCtx is the transactional context
// passed to the callback. If RunInTransaction returns before the callback
// hands over its context (presumably because the transaction could not be
// started), that error is returned instead.
//
// The goroutine stays blocked on finishC until commit or rollback is called;
// nothing here ever closes the channel or times out.
func newTxExtractor(ctx context.Context) (*txExtractor, error) {
	// Carries the transactional context out of the callback.
	ctxC := make(chan context.Context)

	ext := &txExtractor{
		finishC: make(chan txResult),
		resultC: make(chan error),
	}

	// Sentinel returned by the callback to force a rollback; it is mapped
	// back to a nil error after RunInTransaction returns.
	rollbackErr := errors.New("rollback requested")

	go func() {
		// NOTE: Automatic retries inside RunInTransaction are a trap for
		// newcomers, so Attempts is 1; applications that want retries are
		// expected to loop on their own side.
		err := datastore.RunInTransaction(ctx, func(ctx netcontext.Context) error {
			// ext.finishC is guaranteed non-nil as long as it is read before
			// ctxC <- ctx: ext is only returned to callers (who may then clear
			// it via commit/rollback) after that send has been received.
			finishC := ext.finishC

			ctxC <- ctx

			// Block until commit or rollback sends the decision.
			result, ok := <-finishC
			if !ok {
				return errors.New("channel closed")
			}

			if result.commit {
				return nil
			} else if result.rollback {
				return rollbackErr
			}

			panic("unexpected tx state")

		}, &datastore.TransactionOptions{XG: true, Attempts: 1})
		if err == rollbackErr {
			// A requested rollback is not an error for the caller.
			err = nil
		}
		// Unbuffered send: received either by the select below (early
		// failure) or by ext.commit/ext.rollback.
		ext.resultC <- toWrapperError(err)
	}()

	select {
	case txCtx := <-ctxC:
		ext.txCtx = txCtx
	case err := <-ext.resultC:
		if err == nil {
			panic("unexpected state")
		}
		// NOTE(review): err already went through toWrapperError in the
		// goroutine; this second call is presumably a no-op — confirm.
		return nil, toWrapperError(err)
	}

	return ext, nil
}
    88  
    89  func getTxExtractor(ctx context.Context) *txExtractor {
    90  	tx := ctx.Value(contextTransaction{})
    91  	if tx != nil {
    92  		return tx.(*txExtractor)
    93  	}
    94  
    95  	return nil
    96  }
    97  
    98  type transactionImpl struct {
    99  	client    *datastoreImpl
   100  	cacheInfo *w.MiddlewareInfo
   101  }
   102  
   103  type commitImpl struct {
   104  }
   105  
   106  func (tx *transactionImpl) Get(key w.Key, dst interface{}) error {
   107  	err := tx.GetMulti([]w.Key{key}, []interface{}{dst})
   108  	if merr, ok := err.(w.MultiError); ok {
   109  		return merr[0]
   110  	} else if err != nil {
   111  		return err
   112  	}
   113  
   114  	return nil
   115  }
   116  
   117  func (tx *transactionImpl) GetMulti(keys []w.Key, dst interface{}) error {
   118  	cb := shared.NewCacheBridge(tx.cacheInfo, &originalClientBridgeImpl{tx.client}, &originalTransactionBridgeImpl{tx: tx}, nil, tx.client.middlewares)
   119  
   120  	err := shared.GetMultiOps(tx.client.ctx, keys, dst, func(keys []w.Key, dst []w.PropertyList) error {
   121  		return cb.GetMultiWithTx(tx.cacheInfo, keys, dst)
   122  	})
   123  
   124  	return err
   125  }
   126  
   127  func (tx *transactionImpl) Put(key w.Key, src interface{}) (w.PendingKey, error) {
   128  	pKeys, err := tx.PutMulti([]w.Key{key}, []interface{}{src})
   129  	if merr, ok := err.(w.MultiError); ok {
   130  		return nil, merr[0]
   131  	} else if err != nil {
   132  		return nil, err
   133  	}
   134  
   135  	return pKeys[0], nil
   136  }
   137  
   138  func (tx *transactionImpl) PutMulti(keys []w.Key, src interface{}) ([]w.PendingKey, error) {
   139  	cb := shared.NewCacheBridge(tx.cacheInfo, &originalClientBridgeImpl{tx.client}, &originalTransactionBridgeImpl{tx: tx}, nil, tx.client.middlewares)
   140  
   141  	_, pKeys, err := shared.PutMultiOps(tx.client.ctx, keys, src, func(keys []w.Key, src []w.PropertyList) ([]w.Key, []w.PendingKey, error) {
   142  		pKeys, err := cb.PutMultiWithTx(tx.cacheInfo, keys, src)
   143  		return nil, pKeys, err
   144  	})
   145  
   146  	if err != nil {
   147  		return nil, err
   148  	}
   149  
   150  	return pKeys, nil
   151  }
   152  
   153  func (tx *transactionImpl) Delete(key w.Key) error {
   154  	err := tx.DeleteMulti([]w.Key{key})
   155  	if merr, ok := err.(w.MultiError); ok {
   156  		return merr[0]
   157  	} else if err != nil {
   158  		return err
   159  	}
   160  
   161  	return nil
   162  }
   163  
   164  func (tx *transactionImpl) DeleteMulti(keys []w.Key) error {
   165  	cb := shared.NewCacheBridge(tx.cacheInfo, &originalClientBridgeImpl{tx.client}, &originalTransactionBridgeImpl{tx: tx}, nil, tx.client.middlewares)
   166  
   167  	err := shared.DeleteMultiOps(tx.client.ctx, keys, func(keys []w.Key) error {
   168  		return cb.DeleteMultiWithTx(tx.cacheInfo, keys)
   169  	})
   170  
   171  	return err
   172  }
   173  
   174  func (tx *transactionImpl) Commit() (w.Commit, error) {
   175  	ext := getTxExtractor(tx.client.ctx)
   176  	if ext == nil {
   177  		return nil, errors.New("unexpected context")
   178  	}
   179  
   180  	err := ext.commit()
   181  	if err != nil {
   182  		return nil, err
   183  	}
   184  
   185  	cb := shared.NewCacheBridge(tx.cacheInfo, &originalClientBridgeImpl{tx.client}, &originalTransactionBridgeImpl{tx: tx}, nil, tx.client.middlewares)
   186  	commitImpl := &commitImpl{}
   187  	err = cb.PostCommit(tx.cacheInfo, tx, commitImpl)
   188  
   189  	if err != nil {
   190  		return nil, err
   191  	}
   192  
   193  	return commitImpl, nil
   194  }
   195  
   196  func (tx *transactionImpl) Rollback() error {
   197  	ext := getTxExtractor(tx.client.ctx)
   198  	if ext == nil {
   199  		return errors.New("unexpected context")
   200  	}
   201  
   202  	err := ext.rollback()
   203  	if err != nil {
   204  		return err
   205  	}
   206  
   207  	cb := shared.NewCacheBridge(tx.cacheInfo, &originalClientBridgeImpl{tx.client}, &originalTransactionBridgeImpl{tx: tx}, nil, tx.client.middlewares)
   208  	return cb.PostRollback(tx.cacheInfo, tx)
   209  }
   210  
   211  func (tx *transactionImpl) Batch() *w.TransactionBatch {
   212  	return &w.TransactionBatch{Transaction: tx}
   213  }
   214  
   215  func (c *commitImpl) Key(p w.PendingKey) w.Key {
   216  	pk := toOriginalPendingKey(p)
   217  	return toWrapperKey(p.StoredContext(), pk)
   218  }
   219  
   220  func (ext *txExtractor) commit() error {
   221  	ext.Lock()
   222  	finishC := ext.finishC
   223  	if finishC != nil {
   224  		ext.finishC = nil
   225  	}
   226  	ext.Unlock()
   227  	if finishC == nil {
   228  		return errors.New("datastore: transaction expired")
   229  	}
   230  	finishC <- txResult{commit: true}
   231  	return <-ext.resultC
   232  }
   233  
   234  func (ext *txExtractor) rollback() error {
   235  	ext.Lock()
   236  	finishC := ext.finishC
   237  	if finishC != nil {
   238  		ext.finishC = nil
   239  	}
   240  	ext.Unlock()
   241  	if finishC == nil {
   242  		return errors.New("datastore: transaction expired")
   243  	}
   244  	finishC <- txResult{rollback: true}
   245  	return <-ext.resultC
   246  }