github.com/makyo/juju@v0.0.0-20160425123129-2608902037e9/payload/persistence/fakes_test.go (about)

     1  // Copyright 2015 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package persistence
     5  
import (
	"sort"

	"github.com/juju/errors"
	gitjujutesting "github.com/juju/testing"
	jc "github.com/juju/testing/checkers"
	jujutxn "github.com/juju/txn"
	gc "gopkg.in/check.v1"
	"gopkg.in/mgo.v2/bson"
	"gopkg.in/mgo.v2/txn"
)
    15  
// fakeStatePersistence is an in-memory fake providing One, All and Run for
// tests. Every call is recorded on the embedded Stub, and errors to inject
// are taken from the Stub via NextErr.
type fakeStatePersistence struct {
	*gitjujutesting.Stub

	// docs holds the stored payload documents keyed by DocID (populated
	// by SetDocs and read by One and All).
	docs map[string]*payloadDoc
	// ops records the ops of every transaction attempt made through Run,
	// in order; inspected by CheckOps and CheckNoOps.
	ops [][]txn.Op
}
    22  
    23  func (sp *fakeStatePersistence) SetDocs(docs ...*payloadDoc) {
    24  	if sp.docs == nil {
    25  		sp.docs = make(map[string]*payloadDoc)
    26  	}
    27  	for _, doc := range docs {
    28  		sp.docs[doc.DocID] = doc
    29  	}
    30  }
    31  
    32  func (sp fakeStatePersistence) CheckOps(c *gc.C, expected [][]txn.Op) {
    33  	if len(sp.ops) != len(expected) {
    34  		c.Check(sp.ops, jc.DeepEquals, expected)
    35  		return
    36  	}
    37  
    38  	for i, ops := range sp.ops {
    39  		c.Logf(" -- txn attempt %d --\n", i)
    40  		expectedRun := expected[i]
    41  		if len(ops) != len(expectedRun) {
    42  			c.Check(ops, jc.DeepEquals, expectedRun)
    43  			continue
    44  		}
    45  		for j, op := range ops {
    46  			c.Logf(" <op %d>\n", j)
    47  			expectedOp := expectedRun[j]
    48  			if expectedOp.Insert != nil {
    49  				if doc, ok := expectedOp.Insert.(*PayloadDoc); ok {
    50  					expectedOp.Insert = doc.convert()
    51  				}
    52  			} else if expectedOp.Update != nil {
    53  				if doc, ok := expectedOp.Update.(*PayloadDoc); ok {
    54  					expectedOp.Update = doc.convert()
    55  				}
    56  			}
    57  			c.Check(op, jc.DeepEquals, expectedOp)
    58  		}
    59  	}
    60  }
    61  
    62  func (sp fakeStatePersistence) CheckNoOps(c *gc.C) {
    63  	c.Check(sp.ops, gc.HasLen, 0)
    64  }
    65  
    66  func (sp fakeStatePersistence) One(collName, id string, doc interface{}) error {
    67  	sp.AddCall("One", collName, id, doc)
    68  	if err := sp.NextErr(); err != nil {
    69  		return errors.Trace(err)
    70  	}
    71  
    72  	if len(sp.docs) == 0 {
    73  		return errors.NotFoundf(id)
    74  	}
    75  	found, ok := sp.docs[id]
    76  	if !ok {
    77  		return errors.NotFoundf(id)
    78  	}
    79  	actual := doc.(*payloadDoc)
    80  	*actual = *found
    81  	return nil
    82  }
    83  
    84  func (sp fakeStatePersistence) All(collName string, query, docs interface{}) error {
    85  	sp.AddCall("All", collName, query, docs)
    86  	if err := sp.NextErr(); err != nil {
    87  		return errors.Trace(err)
    88  	}
    89  
    90  	var ids []string
    91  	elems := query.(bson.D)
    92  	if len(elems) < 1 {
    93  		err := errors.Errorf("bad query %v", query)
    94  		panic(err)
    95  	}
    96  	switch elems[0].Name {
    97  	case "_id":
    98  		if len(elems) != 1 {
    99  			err := errors.Errorf("bad query %v", query)
   100  			panic(err)
   101  		}
   102  		elems = elems[0].Value.(bson.D)
   103  		if len(elems) != 1 || elems[0].Name != "$in" {
   104  			err := errors.Errorf("bad query %v", query)
   105  			panic(err)
   106  		}
   107  		ids = elems[0].Value.([]string)
   108  	case "unitid":
   109  		for id := range sp.docs {
   110  			ids = append(ids, id)
   111  		}
   112  	default:
   113  		panic(query)
   114  	}
   115  
   116  	var found []payloadDoc
   117  	for _, id := range ids {
   118  		doc, ok := sp.docs[id]
   119  		if !ok {
   120  			continue
   121  		}
   122  		found = append(found, *doc)
   123  	}
   124  	actual := docs.(*[]payloadDoc)
   125  	*actual = found
   126  	return nil
   127  }
   128  
   129  func (sp *fakeStatePersistence) Run(transactions jujutxn.TransactionSource) error {
   130  	sp.AddCall("Run", transactions)
   131  
   132  	// See transactionRunner.Run in github.com/juju/txn.
   133  	for i := 0; ; i++ {
   134  		const nrRetries = 3
   135  		if i >= nrRetries {
   136  			return jujutxn.ErrExcessiveContention
   137  		}
   138  
   139  		// Get the ops.
   140  		ops, err := transactions(i)
   141  		if err == jujutxn.ErrTransientFailure {
   142  			continue
   143  		}
   144  		if err == jujutxn.ErrNoOperations {
   145  			break
   146  		}
   147  		if err != nil {
   148  			return err
   149  		}
   150  
   151  		// "run" the ops.
   152  		sp.ops = append(sp.ops, ops)
   153  		if err := sp.NextErr(); err == nil {
   154  			return nil
   155  		} else if err != txn.ErrAborted {
   156  			return err
   157  		}
   158  	}
   159  	return nil
   160  }