github.com/SaurabhDubey-Groww/go-cloud@v0.0.0-20221124105541-b26c29285fd8/docstore/memdocstore/mem.go (about)

     1  // Copyright 2019 The Go Cloud Development Kit Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package memdocstore provides an in-process in-memory implementation of the docstore
    16  // API. It is suitable for local development and testing.
    17  //
    18  // Every document in a memdocstore collection has a unique primary key. The primary
    19  // key values need not be strings; they may be any comparable Go value.
    20  //
    21  // # Action Lists
    22  //
    23  // Action lists are executed concurrently. Each action in an action list is executed
    24  // in a separate goroutine.
    25  //
    26  // memdocstore calls the BeforeDo function of an ActionList once before executing the
    27  // actions. Its As function never returns true.
    28  //
    29  // # URLs
    30  //
    31  // For docstore.OpenCollection, memdocstore registers for the scheme
    32  // "mem".
    33  // To customize the URL opener, or for more details on the URL format,
    34  // see URLOpener.
    35  // See https://gocloud.dev/concepts/urls/ for background information.
    36  package memdocstore // import "gocloud.dev/docstore/memdocstore"
    37  
    38  import (
    39  	"context"
    40  	"encoding/gob"
    41  	"fmt"
    42  	"os"
    43  	"reflect"
    44  	"sort"
    45  	"strconv"
    46  	"strings"
    47  	"sync"
    48  
    49  	"gocloud.dev/docstore"
    50  	"gocloud.dev/docstore/driver"
    51  	"gocloud.dev/gcerrors"
    52  	"gocloud.dev/internal/gcerr"
    53  )
    54  
// Options are optional arguments to the OpenCollection functions.
type Options struct {
	// The name of the field holding the document revision.
	// Defaults to docstore.DefaultRevisionField.
	RevisionField string

	// The maximum number of concurrent goroutines started for a single call to
	// ActionList.Do. If less than 1, there is no limit.
	MaxOutstandingActions int

	// The filename associated with this collection.
	// When a collection is opened with a non-empty filename, the collection
	// is loaded from the file if it exists. Otherwise, an empty collection is created.
	// When the collection is closed, its contents are saved to the file.
	// An empty filename means the collection is never loaded from or saved to disk.
	Filename string

	// Call this function when the collection is closed. It is called before
	// the collection's documents are saved to Filename.
	// For internal use only.
	onClose func()
}
    75  
    76  // TODO(jba): make this package thread-safe.
    77  
    78  // OpenCollection creates a *docstore.Collection backed by memory. keyField is the
    79  // document field holding the primary key of the collection.
    80  func OpenCollection(keyField string, opts *Options) (*docstore.Collection, error) {
    81  	c, err := newCollection(keyField, nil, opts)
    82  	if err != nil {
    83  		return nil, err
    84  	}
    85  	return docstore.NewCollection(c), nil
    86  }
    87  
    88  // OpenCollectionWithKeyFunc creates a *docstore.Collection backed by memory. keyFunc takes
    89  // a document and returns the document's primary key. It should return nil if the
    90  // document is missing the information to construct a key. This will cause all
    91  // actions, even Create, to fail.
    92  //
    93  // For the collection to be usable with Query.Delete and Query.Update,
    94  // keyFunc must work with map[string]interface{} as well as whatever
    95  // struct type the collection normally uses (if any).
    96  func OpenCollectionWithKeyFunc(keyFunc func(docstore.Document) interface{}, opts *Options) (*docstore.Collection, error) {
    97  	c, err := newCollection("", keyFunc, opts)
    98  	if err != nil {
    99  		return nil, err
   100  	}
   101  	return docstore.NewCollection(c), nil
   102  }
   103  
   104  func newCollection(keyField string, keyFunc func(docstore.Document) interface{}, opts *Options) (driver.Collection, error) {
   105  	if keyField == "" && keyFunc == nil {
   106  		return nil, gcerr.Newf(gcerr.InvalidArgument, nil, "must provide either keyField or keyFunc")
   107  	}
   108  	if opts == nil {
   109  		opts = &Options{}
   110  	}
   111  	if opts.RevisionField == "" {
   112  		opts.RevisionField = docstore.DefaultRevisionField
   113  	}
   114  	docs, err := loadDocs(opts.Filename)
   115  	if err != nil {
   116  		return nil, err
   117  	}
   118  	return &collection{
   119  		keyField:    keyField,
   120  		keyFunc:     keyFunc,
   121  		docs:        docs,
   122  		opts:        opts,
   123  		curRevision: 0,
   124  	}, nil
   125  }
   126  
// A storedDoc is a document that is stored in a collection.
//
// We store documents as maps from keys to values. Even if the user is using
// map[string]interface{}, we make our own copy.
//
// Using a separate type helps distinguish documents coming from a user (those "on
// the client," in a more typical driver that acts as a network client) from
// those stored in a collection (those "on the server").
type storedDoc map[string]interface{}
   136  
// collection implements driver.Collection by keeping documents in a Go map.
type collection struct {
	// Name of the primary-key field; empty when keyFunc is used instead.
	keyField    string
	// Computes a document's primary key; consulted only when keyField is empty.
	keyFunc     func(docstore.Document) interface{}
	opts        *Options
	// Held by runAction for the duration of each action while docs and
	// curRevision are read or written (update and changeRevision require it).
	mu          sync.Mutex
	// Stored documents, indexed by primary key.
	docs        map[interface{}]storedDoc
	curRevision int64 // last revision handed out; incremented by changeRevision
}
   145  
   146  func (c *collection) Key(doc driver.Document) (interface{}, error) {
   147  	if c.keyField != "" {
   148  		key, _ := doc.GetField(c.keyField) // no error on missing key, and it will be nil
   149  		return key, nil
   150  	}
   151  	key := c.keyFunc(doc.Origin)
   152  	if key == nil || driver.IsEmptyValue(reflect.ValueOf(key)) {
   153  		return nil, gcerr.Newf(gcerr.InvalidArgument, nil, "missing document key")
   154  	}
   155  	return key, nil
   156  }
   157  
   158  func (c *collection) RevisionField() string {
   159  	return c.opts.RevisionField
   160  }
   161  
   162  // ErrorCode implements driver.ErrorCode.
   163  func (c *collection) ErrorCode(err error) gcerrors.ErrorCode {
   164  	return gcerrors.Code(err)
   165  }
   166  
   167  // RunActions implements driver.RunActions.
   168  func (c *collection) RunActions(ctx context.Context, actions []*driver.Action, opts *driver.RunActionsOptions) driver.ActionListError {
   169  	errs := make([]error, len(actions))
   170  
   171  	// Run the actions concurrently with each other.
   172  	run := func(as []*driver.Action) {
   173  		t := driver.NewThrottle(c.opts.MaxOutstandingActions)
   174  		for _, a := range as {
   175  			a := a
   176  			t.Acquire()
   177  			go func() {
   178  				defer t.Release()
   179  				errs[a.Index] = c.runAction(ctx, a)
   180  			}()
   181  		}
   182  		t.Wait()
   183  	}
   184  
   185  	if opts.BeforeDo != nil {
   186  		if err := opts.BeforeDo(func(interface{}) bool { return false }); err != nil {
   187  			for i := range errs {
   188  				errs[i] = err
   189  			}
   190  			return driver.NewActionListError(errs)
   191  		}
   192  	}
   193  
   194  	beforeGets, gets, writes, afterGets := driver.GroupActions(actions)
   195  	run(beforeGets)
   196  	run(gets)
   197  	run(writes)
   198  	run(afterGets)
   199  	return driver.NewActionListError(errs)
   200  }
   201  
// runAction executes a single action against the in-memory document map and
// returns its error, if any. It holds c.mu for the whole action, so actions
// started in separate goroutines by RunActions are applied one at a time.
func (c *collection) runAction(ctx context.Context, a *driver.Action) error {
	// Stop if the context is done.
	if ctx.Err() != nil {
		return ctx.Err()
	}
	// Serialize access to c.docs and c.curRevision with concurrently running actions.
	c.mu.Lock()
	defer c.mu.Unlock()
	// If there is a key, get the current document with that key.
	// current stays nil and exists stays false when the key is nil or absent.
	var (
		current storedDoc
		exists  bool
	)
	if a.Key != nil {
		current, exists = c.docs[a.Key]
	}
	// Replace, Update and Get require the document to exist already.
	if !exists && (a.Kind == driver.Replace || a.Kind == driver.Update || a.Kind == driver.Get) {
		return gcerr.Newf(gcerr.NotFound, nil, "document with key %v does not exist", a.Key)
	}
	switch a.Kind {
	case driver.Create:
		// It is an error to attempt to create an existing document.
		if exists {
			return gcerr.Newf(gcerr.AlreadyExists, nil, "Create: document with key %v exists", a.Key)
		}
		// If the user didn't supply a value for the key field, create a new one.
		if a.Key == nil {
			a.Key = driver.UniqueString()
			// Set the new key in the document.
			// NOTE(review): the SetField error itself is discarded; only its occurrence is reported.
			if err := a.Doc.SetField(c.keyField, a.Key); err != nil {
				return gcerr.Newf(gcerr.InvalidArgument, nil, "cannot set key field %q", c.keyField)
			}
		}
		// From here on Create behaves like Replace/Put: store the encoded document.
		fallthrough

	case driver.Replace, driver.Put:
		if err := c.checkRevision(a.Doc, current); err != nil {
			return err
		}
		doc, err := encodeDoc(a.Doc)
		if err != nil {
			return err
		}
		// Only documents that carry a revision field get a new revision; the new
		// value is also written back into the caller's document.
		if a.Doc.HasField(c.opts.RevisionField) {
			c.changeRevision(doc)
			if err := a.Doc.SetField(c.opts.RevisionField, doc[c.opts.RevisionField]); err != nil {
				return err
			}
		}
		c.docs[a.Key] = doc

	case driver.Delete:
		// Deleting a document that does not exist is not an error: checkRevision
		// accepts a nil current, and delete on a missing map key is a no-op.
		if err := c.checkRevision(a.Doc, current); err != nil {
			return err
		}
		delete(c.docs, a.Key)

	case driver.Update:
		if err := c.checkRevision(a.Doc, current); err != nil {
			return err
		}
		// update modifies the stored document (current) in place.
		if err := c.update(current, a.Mods); err != nil {
			return err
		}
		if a.Doc.HasField(c.opts.RevisionField) {
			c.changeRevision(current)
			if err := a.Doc.SetField(c.opts.RevisionField, current[c.opts.RevisionField]); err != nil {
				return err
			}
		}

	case driver.Get:
		// We've already retrieved the document into current, above.
		// Now we copy its fields into the user-provided document.
		if err := decodeDoc(current, a.Doc, a.FieldPaths); err != nil {
			return err
		}
	default:
		return gcerr.Newf(gcerr.Internal, nil, "unknown kind %v", a.Kind)
	}
	return nil
}
   286  
   287  // Must be called with the lock held.
   288  // Does not change the stored doc's revision field; that is up to the caller.
   289  func (c *collection) update(doc storedDoc, mods []driver.Mod) error {
   290  	// Sort mods by first field path element so tests are deterministic.
   291  	sort.Slice(mods, func(i, j int) bool { return mods[i].FieldPath[0] < mods[j].FieldPath[0] })
   292  
   293  	// To make update atomic, we first convert the actions into a form that can't
   294  	// fail.
   295  	type guaranteedMod struct {
   296  		parentMap    map[string]interface{} // the map holding the key to be modified
   297  		key          string
   298  		encodedValue interface{} // the value after encoding
   299  	}
   300  
   301  	gmods := make([]guaranteedMod, len(mods))
   302  	var err error
   303  	for i, mod := range mods {
   304  		gmod := &gmods[i]
   305  		// Check that the field path is valid. That is, every component of the path
   306  		// but the last refers to a map, and no component along the way is nil.
   307  		if gmod.parentMap, err = getParentMap(doc, mod.FieldPath, false); err != nil {
   308  			return err
   309  		}
   310  		gmod.key = mod.FieldPath[len(mod.FieldPath)-1]
   311  		if inc, ok := mod.Value.(driver.IncOp); ok {
   312  			amt, err := encodeValue(inc.Amount)
   313  			if err != nil {
   314  				return err
   315  			}
   316  			if gmod.encodedValue, err = add(gmod.parentMap[gmod.key], amt); err != nil {
   317  				return err
   318  			}
   319  		} else if mod.Value != nil {
   320  			// Make sure the value encodes successfully.
   321  			if gmod.encodedValue, err = encodeValue(mod.Value); err != nil {
   322  				return err
   323  			}
   324  		}
   325  	}
   326  	// Now execute the guaranteed mods.
   327  	for _, m := range gmods {
   328  		if m.encodedValue == nil {
   329  			delete(m.parentMap, m.key)
   330  		} else {
   331  			m.parentMap[m.key] = m.encodedValue
   332  		}
   333  	}
   334  	return nil
   335  }
   336  
   337  // Add two encoded numbers.
   338  // Since they're encoded, they are either int64 or float64.
   339  // Allow adding a float to an int, producing a float.
   340  // TODO(jba): see how other drivers handle that.
   341  func add(x, y interface{}) (interface{}, error) {
   342  	if x == nil {
   343  		return y, nil
   344  	}
   345  	switch x := x.(type) {
   346  	case int64:
   347  		switch y := y.(type) {
   348  		case int64:
   349  			return x + y, nil
   350  		case float64:
   351  			return float64(x) + y, nil
   352  		default:
   353  			// This shouldn't happen because it should be checked by docstore.
   354  			return nil, gcerr.Newf(gcerr.Internal, nil, "bad increment aount type %T", y)
   355  		}
   356  	case float64:
   357  		switch y := y.(type) {
   358  		case int64:
   359  			return x + float64(y), nil
   360  		case float64:
   361  			return x + y, nil
   362  		default:
   363  			// This shouldn't happen because it should be checked by docstore.
   364  			return nil, gcerr.Newf(gcerr.Internal, nil, "bad increment aount type %T", y)
   365  		}
   366  	default:
   367  		return nil, gcerr.Newf(gcerr.InvalidArgument, nil, "value %v being incremented not int64 or float64", x)
   368  	}
   369  }
   370  
   371  // Must be called with the lock held.
   372  func (c *collection) changeRevision(doc storedDoc) {
   373  	c.curRevision++
   374  	doc[c.opts.RevisionField] = c.curRevision
   375  }
   376  
   377  func (c *collection) checkRevision(arg driver.Document, current storedDoc) error {
   378  	if current == nil {
   379  		return nil // no existing document or the incoming doc has no revision
   380  	}
   381  	curRev, ok := current[c.opts.RevisionField]
   382  	if !ok {
   383  		return nil // there is no revision to check
   384  	}
   385  	curRev = curRev.(int64)
   386  	r, err := arg.GetField(c.opts.RevisionField)
   387  	if err != nil || r == nil {
   388  		return nil // no incoming revision information: nothing to check
   389  	}
   390  	wantRev, ok := r.(int64)
   391  	if !ok {
   392  		return gcerr.Newf(gcerr.InvalidArgument, nil, "revision field %s is not an int64", c.opts.RevisionField)
   393  	}
   394  	if wantRev != curRev {
   395  		return gcerr.Newf(gcerr.FailedPrecondition, nil, "mismatched revisions: want %d, current %d", wantRev, curRev)
   396  	}
   397  	return nil
   398  }
   399  
   400  // getAtFieldPath gets the value of m at fp. It returns an error if fp is invalid
   401  // (see getParentMap).
   402  func getAtFieldPath(m map[string]interface{}, fp []string) (interface{}, error) {
   403  	m2, err := getParentMap(m, fp, false)
   404  	if err != nil {
   405  		return nil, err
   406  	}
   407  	v, ok := m2[fp[len(fp)-1]]
   408  	if ok {
   409  		return v, nil
   410  	}
   411  	return nil, gcerr.Newf(gcerr.NotFound, nil, "field %s not found", fp)
   412  }
   413  
   414  // setAtFieldPath sets m's value at fp to val. It creates intermediate maps as
   415  // needed. It returns an error if a non-final component of fp does not denote a map.
   416  func setAtFieldPath(m map[string]interface{}, fp []string, val interface{}) error {
   417  	m2, err := getParentMap(m, fp, true)
   418  	if err != nil {
   419  		return err
   420  	}
   421  	m2[fp[len(fp)-1]] = val
   422  	return nil
   423  }
   424  
   425  // Delete the value from m at the given field path, if it exists.
   426  func deleteAtFieldPath(m map[string]interface{}, fp []string) {
   427  	m2, _ := getParentMap(m, fp, false) // ignore error
   428  	if m2 != nil {
   429  		delete(m2, fp[len(fp)-1])
   430  	}
   431  }
   432  
   433  // getParentMap returns the map that directly contains the given field path;
   434  // that is, the value of m at the field path that excludes the last component
   435  // of fp. If a non-map is encountered along the way, an InvalidArgument error is
   436  // returned. If nil is encountered, nil is returned unless create is true, in
   437  // which case a map is added at that point.
   438  func getParentMap(m map[string]interface{}, fp []string, create bool) (map[string]interface{}, error) {
   439  	var ok bool
   440  	for _, k := range fp[:len(fp)-1] {
   441  		if m[k] == nil {
   442  			if !create {
   443  				return nil, nil
   444  			}
   445  			m[k] = map[string]interface{}{}
   446  		}
   447  		m, ok = m[k].(map[string]interface{})
   448  		if !ok {
   449  			return nil, gcerr.Newf(gcerr.InvalidArgument, nil, "invalid field path %q at %q", strings.Join(fp, "."), k)
   450  		}
   451  	}
   452  	return m, nil
   453  }
   454  
   455  // RevisionToBytes implements driver.RevisionToBytes.
   456  func (c *collection) RevisionToBytes(rev interface{}) ([]byte, error) {
   457  	r, ok := rev.(int64)
   458  	if !ok {
   459  		return nil, gcerr.Newf(gcerr.InvalidArgument, nil, "revision %v of type %[1]T is not an int64", rev)
   460  	}
   461  	return strconv.AppendInt(nil, r, 10), nil
   462  }
   463  
   464  // BytesToRevision implements driver.BytesToRevision.
   465  func (c *collection) BytesToRevision(b []byte) (interface{}, error) {
   466  	return strconv.ParseInt(string(b), 10, 64)
   467  }
   468  
   469  // As implements driver.As.
   470  func (c *collection) As(i interface{}) bool { return false }
   471  
   472  // As implements driver.Collection.ErrorAs.
   473  func (c *collection) ErrorAs(err error, i interface{}) bool { return false }
   474  
   475  // Close implements driver.Collection.Close.
   476  // If the collection was created with a Filename option, Close writes the
   477  // collection's documents to the file.
   478  func (c *collection) Close() error {
   479  	if c.opts.onClose != nil {
   480  		c.opts.onClose()
   481  	}
   482  	return saveDocs(c.opts.Filename, c.docs)
   483  }
   484  
// mapOfDocs is the primary-key-to-document map that loadDocs gob-decodes from,
// and saveDocs gob-encodes to, a collection's file.
type mapOfDocs = map[interface{}]storedDoc
   486  
   487  // Read a map from the filename if is is not empty and the file exists.
   488  // Otherwise return an empty (not nil) map.
   489  func loadDocs(filename string) (mapOfDocs, error) {
   490  	if filename == "" {
   491  		return mapOfDocs{}, nil
   492  	}
   493  	f, err := os.Open(filename)
   494  	if err != nil {
   495  		if !os.IsNotExist(err) {
   496  			return nil, err
   497  		}
   498  		// If the file doesn't exist, return an empty map without error.
   499  		return mapOfDocs{}, nil
   500  	}
   501  	defer f.Close()
   502  	var m mapOfDocs
   503  	if err := gob.NewDecoder(f).Decode(&m); err != nil {
   504  		return nil, fmt.Errorf("failed to decode from %q: %v", filename, err)
   505  	}
   506  	return m, nil
   507  }
   508  
   509  // saveDocs saves m to filename if filename is not empty.
   510  func saveDocs(filename string, m mapOfDocs) error {
   511  	if filename == "" {
   512  		return nil
   513  	}
   514  	f, err := os.Create(filename)
   515  	if err != nil {
   516  		return err
   517  	}
   518  	if err := gob.NewEncoder(f).Encode(m); err != nil {
   519  		_ = f.Close()
   520  		return fmt.Errorf("failed to encode to %q: %v", filename, err)
   521  	}
   522  	return f.Close()
   523  }