github.com/niedbalski/juju@v0.0.0-20190215020005-8ff100488e47/state/statetest/persistence_stubs.go (about)

     1  // Copyright 2015 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package statetest
     5  
     6  import (
     7  	"reflect"
     8  
     9  	"github.com/juju/errors"
    10  	"github.com/juju/testing"
    11  	jujutxn "github.com/juju/txn"
    12  	"gopkg.in/mgo.v2/txn"
    13  )
    14  
    15  type StubPersistence struct {
    16  	*testing.Stub
    17  
    18  	RunFunc func(jujutxn.TransactionSource) error
    19  
    20  	ReturnAll interface{} // homegenous(?) list of doc struct (not pointers)
    21  	ReturnOne interface{} // a doc struct (not a pointer)
    22  
    23  	ReturnApplicationExistsOps       []txn.Op
    24  	ReturnIncCharmModifiedVersionOps []txn.Op
    25  }
    26  
    27  func NewStubPersistence(stub *testing.Stub) *StubPersistence {
    28  	s := &StubPersistence{
    29  		Stub: stub,
    30  	}
    31  	s.RunFunc = s.run
    32  	return s
    33  }
    34  
    35  func (s *StubPersistence) One(collName, id string, doc interface{}) error {
    36  	s.AddCall("One", collName, id, doc)
    37  	if err := s.NextErr(); err != nil {
    38  		return errors.Trace(err)
    39  	}
    40  
    41  	if reflect.TypeOf(s.ReturnOne) == nil {
    42  		return errors.NotFoundf("resource")
    43  	}
    44  	ptr := reflect.ValueOf(doc)
    45  	newVal := reflect.ValueOf(s.ReturnOne)
    46  	ptr.Elem().Set(newVal)
    47  	return nil
    48  }
    49  
    50  func (s *StubPersistence) All(collName string, query, docs interface{}) error {
    51  	s.AddCall("All", collName, query, docs)
    52  	if err := s.NextErr(); err != nil {
    53  		return errors.Trace(err)
    54  	}
    55  
    56  	ptr := reflect.ValueOf(docs)
    57  	if reflect.TypeOf(s.ReturnAll) == nil {
    58  		ptr.Elem().SetLen(0)
    59  	} else {
    60  		newVal := reflect.ValueOf(s.ReturnAll)
    61  		ptr.Elem().Set(newVal)
    62  	}
    63  	return nil
    64  }
    65  
    66  func (s *StubPersistence) Run(buildTxn jujutxn.TransactionSource) error {
    67  	s.AddCall("Run", buildTxn)
    68  	if err := s.NextErr(); err != nil {
    69  		return errors.Trace(err)
    70  	}
    71  
    72  	if err := s.run(buildTxn); err != nil {
    73  		return errors.Trace(err)
    74  	}
    75  
    76  	return nil
    77  }
    78  
    79  // See github.com/juju/txn.transactionRunner.Run.
    80  func (s *StubPersistence) run(buildTxn jujutxn.TransactionSource) error {
    81  	for i := 0; ; i++ {
    82  		ops, err := buildTxn(i)
    83  		if errors.Cause(err) == jujutxn.ErrTransientFailure {
    84  			continue
    85  		}
    86  		if errors.Cause(err) == jujutxn.ErrNoOperations {
    87  			return nil
    88  		}
    89  		if err != nil {
    90  			return err
    91  		}
    92  
    93  		err = s.RunTransaction(ops)
    94  		if errors.Cause(err) == txn.ErrAborted {
    95  			continue
    96  		}
    97  		return err
    98  	}
    99  }
   100  
   101  func (s *StubPersistence) RunTransaction(ops []txn.Op) error {
   102  	s.AddCall("RunTransaction", ops)
   103  	if err := s.NextErr(); err != nil {
   104  		return errors.Trace(err)
   105  	}
   106  
   107  	return nil
   108  }
   109  
   110  func (s *StubPersistence) ApplicationExistsOps(applicationID string) []txn.Op {
   111  	s.AddCall("ApplicationExistsOps", applicationID)
   112  	// pop off an error so num errors == num calls, even though this call
   113  	// doesn't actually use the error.
   114  	s.NextErr()
   115  
   116  	return s.ReturnApplicationExistsOps
   117  }
   118  
   119  func (s *StubPersistence) IncCharmModifiedVersionOps(applicationID string) []txn.Op {
   120  	s.AddCall("IncCharmModifiedVersionOps", applicationID)
   121  	// pop off an error so num errors == num calls, even though this call
   122  	// doesn't actually use the error.
   123  	s.NextErr()
   124  
   125  	return s.ReturnIncCharmModifiedVersionOps
   126  }