github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/kv/kvserver/protectedts/ptstorage/storage_with_database.go (about)

     1  // Copyright 2019 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package ptstorage
    12  
    13  import (
    14  	"context"
    15  
    16  	"github.com/cockroachdb/cockroach/pkg/kv"
    17  	"github.com/cockroachdb/cockroach/pkg/kv/kvserver/protectedts"
    18  	"github.com/cockroachdb/cockroach/pkg/kv/kvserver/protectedts/ptpb"
    19  	"github.com/cockroachdb/cockroach/pkg/util/uuid"
    20  )
    21  
    22  // WithDatabase wraps s such that any calls made with a nil *Txn will be wrapped
    23  // in a call to db.Txn. This is often convenient in testing.
    24  func WithDatabase(s protectedts.Storage, db *kv.DB) protectedts.Storage {
    25  	return &storageWithDatabase{s: s, db: db}
    26  }
    27  
    28  type storageWithDatabase struct {
    29  	db *kv.DB
    30  	s  protectedts.Storage
    31  }
    32  
    33  func (s *storageWithDatabase) Protect(ctx context.Context, txn *kv.Txn, r *ptpb.Record) error {
    34  	if txn == nil {
    35  		return s.db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) error {
    36  			return s.s.Protect(ctx, txn, r)
    37  		})
    38  	}
    39  	return s.s.Protect(ctx, txn, r)
    40  }
    41  
    42  func (s *storageWithDatabase) GetRecord(
    43  	ctx context.Context, txn *kv.Txn, id uuid.UUID,
    44  ) (r *ptpb.Record, err error) {
    45  	if txn == nil {
    46  		err = s.db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) error {
    47  			r, err = s.s.GetRecord(ctx, txn, id)
    48  			return err
    49  		})
    50  		return r, err
    51  	}
    52  	return s.s.GetRecord(ctx, txn, id)
    53  }
    54  
    55  func (s *storageWithDatabase) MarkVerified(ctx context.Context, txn *kv.Txn, id uuid.UUID) error {
    56  	if txn == nil {
    57  		return s.db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) error {
    58  			return s.s.Release(ctx, txn, id)
    59  		})
    60  	}
    61  	return s.s.Release(ctx, txn, id)
    62  }
    63  
    64  func (s *storageWithDatabase) Release(ctx context.Context, txn *kv.Txn, id uuid.UUID) error {
    65  	if txn == nil {
    66  		return s.db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) error {
    67  			return s.s.Release(ctx, txn, id)
    68  		})
    69  	}
    70  	return s.s.Release(ctx, txn, id)
    71  }
    72  
    73  func (s *storageWithDatabase) GetMetadata(
    74  	ctx context.Context, txn *kv.Txn,
    75  ) (md ptpb.Metadata, err error) {
    76  	if txn == nil {
    77  		err = s.db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) error {
    78  			md, err = s.s.GetMetadata(ctx, txn)
    79  			return err
    80  		})
    81  		return md, err
    82  	}
    83  	return s.s.GetMetadata(ctx, txn)
    84  }
    85  
    86  func (s *storageWithDatabase) GetState(
    87  	ctx context.Context, txn *kv.Txn,
    88  ) (state ptpb.State, err error) {
    89  	if txn == nil {
    90  		err = s.db.Txn(ctx, func(ctx context.Context, txn *kv.Txn) error {
    91  			state, err = s.s.GetState(ctx, txn)
    92  			return err
    93  		})
    94  		return state, err
    95  	}
    96  	return s.s.GetState(ctx, txn)
    97  }