go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/gae/filter/count/rds.go (about)

     1  // Copyright 2015 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package count
    16  
    17  import (
    18  	"context"
    19  
    20  	ds "go.chromium.org/luci/gae/service/datastore"
    21  )
    22  
    23  // DSCounter is the counter object for the datastore service.
    24  type DSCounter struct {
    25  	AllocateIDs      Entry
    26  	DecodeCursor     Entry
    27  	RunInTransaction Entry
    28  	Run              Entry
    29  	Count            Entry
    30  	DeleteMulti      Entry
    31  	GetMulti         Entry
    32  	PutMulti         Entry
    33  }
    34  
    35  type dsCounter struct {
    36  	c *DSCounter
    37  
    38  	ds ds.RawInterface
    39  }
    40  
    41  var _ ds.RawInterface = (*dsCounter)(nil)
    42  
    43  func (r *dsCounter) AllocateIDs(keys []*ds.Key, cb ds.NewKeyCB) error {
    44  	return r.c.AllocateIDs.up(r.ds.AllocateIDs(keys, cb))
    45  }
    46  
    47  func (r *dsCounter) DecodeCursor(s string) (ds.Cursor, error) {
    48  	cursor, err := r.ds.DecodeCursor(s)
    49  	return cursor, r.c.DecodeCursor.up(err)
    50  }
    51  
    52  func (r *dsCounter) Run(q *ds.FinalizedQuery, cb ds.RawRunCB) error {
    53  	return r.c.Run.upFilterStop(r.ds.Run(q, cb))
    54  }
    55  
    56  func (r *dsCounter) Count(q *ds.FinalizedQuery) (int64, error) {
    57  	count, err := r.ds.Count(q)
    58  	return count, r.c.Count.up(err)
    59  }
    60  
    61  func (r *dsCounter) RunInTransaction(f func(context.Context) error, opts *ds.TransactionOptions) error {
    62  	return r.c.RunInTransaction.up(r.ds.RunInTransaction(f, opts))
    63  }
    64  
    65  func (r *dsCounter) DeleteMulti(keys []*ds.Key, cb ds.DeleteMultiCB) error {
    66  	return r.c.DeleteMulti.upFilterStop(r.ds.DeleteMulti(keys, cb))
    67  }
    68  
    69  func (r *dsCounter) GetMulti(keys []*ds.Key, meta ds.MultiMetaGetter, cb ds.GetMultiCB) error {
    70  	return r.c.GetMulti.upFilterStop(r.ds.GetMulti(keys, meta, cb))
    71  }
    72  
    73  func (r *dsCounter) PutMulti(keys []*ds.Key, vals []ds.PropertyMap, cb ds.NewKeyCB) error {
    74  	return r.c.PutMulti.upFilterStop(r.ds.PutMulti(keys, vals, cb))
    75  }
    76  
    77  func (r *dsCounter) CurrentTransaction() ds.Transaction {
    78  	return r.ds.CurrentTransaction()
    79  }
    80  func (r *dsCounter) WithoutTransaction() context.Context {
    81  	return r.ds.WithoutTransaction()
    82  }
    83  
    84  func (r *dsCounter) Constraints() ds.Constraints {
    85  	return r.ds.Constraints()
    86  }
    87  
    88  func (r *dsCounter) GetTestable() ds.Testable {
    89  	return r.ds.GetTestable()
    90  }
    91  
    92  // FilterRDS installs a counter datastore filter in the context.
    93  func FilterRDS(c context.Context) (context.Context, *DSCounter) {
    94  	state := &DSCounter{}
    95  	return ds.AddRawFilters(c, func(ic context.Context, ds ds.RawInterface) ds.RawInterface {
    96  		return &dsCounter{state, ds}
    97  	}), state
    98  }
    99  
   100  // upFilterStop wraps up, handling the special case datastore.Stop error.
   101  // datastore.Stop will pass through this function, but, unlike other error
   102  // codes, will be counted as a success.
   103  func (e *Entry) upFilterStop(err error) error {
   104  	upErr := err
   105  	if upErr == ds.Stop {
   106  		upErr = nil
   107  	}
   108  	e.up(upErr)
   109  	return err
   110  }