github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/datastore/proxy/proxy_test/mock.go (about)

     1  package proxy_test
     2  
     3  import (
     4  	"context"
     5  
     6  	v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
     7  	"github.com/stretchr/testify/mock"
     8  
     9  	"github.com/authzed/spicedb/pkg/datastore"
    10  	"github.com/authzed/spicedb/pkg/datastore/options"
    11  	core "github.com/authzed/spicedb/pkg/proto/core/v1"
    12  )
    13  
    14  type MockDatastore struct {
    15  	mock.Mock
    16  }
    17  
    18  func (dm *MockDatastore) SnapshotReader(rev datastore.Revision) datastore.Reader {
    19  	args := dm.Called(rev)
    20  	return args.Get(0).(datastore.Reader)
    21  }
    22  
    23  func (dm *MockDatastore) ReadWriteTx(
    24  	ctx context.Context,
    25  	f datastore.TxUserFunc,
    26  	opts ...options.RWTOptionsOption,
    27  ) (datastore.Revision, error) {
    28  	args := dm.Called(opts)
    29  	mockRWT := args.Get(0).(datastore.ReadWriteTransaction)
    30  
    31  	if err := f(ctx, mockRWT); err != nil {
    32  		return datastore.NoRevision, err
    33  	}
    34  
    35  	return args.Get(1).(datastore.Revision), args.Error(2)
    36  }
    37  
    38  func (dm *MockDatastore) OptimizedRevision(_ context.Context) (datastore.Revision, error) {
    39  	args := dm.Called()
    40  	return args.Get(0).(datastore.Revision), args.Error(1)
    41  }
    42  
    43  func (dm *MockDatastore) HeadRevision(_ context.Context) (datastore.Revision, error) {
    44  	args := dm.Called()
    45  	return args.Get(0).(datastore.Revision), args.Error(1)
    46  }
    47  
    48  func (dm *MockDatastore) CheckRevision(_ context.Context, revision datastore.Revision) error {
    49  	args := dm.Called(revision)
    50  	return args.Error(0)
    51  }
    52  
    53  func (dm *MockDatastore) RevisionFromString(s string) (datastore.Revision, error) {
    54  	args := dm.Called(s)
    55  	return args.Get(0).(datastore.Revision), args.Error(1)
    56  }
    57  
    58  func (dm *MockDatastore) Watch(_ context.Context, afterRevision datastore.Revision, _ datastore.WatchOptions) (<-chan *datastore.RevisionChanges, <-chan error) {
    59  	args := dm.Called(afterRevision)
    60  	return args.Get(0).(<-chan *datastore.RevisionChanges), args.Get(1).(<-chan error)
    61  }
    62  
    63  func (dm *MockDatastore) ReadyState(_ context.Context) (datastore.ReadyState, error) {
    64  	args := dm.Called()
    65  	return args.Get(0).(datastore.ReadyState), args.Error(1)
    66  }
    67  
    68  func (dm *MockDatastore) Features(_ context.Context) (*datastore.Features, error) {
    69  	args := dm.Called()
    70  	return args.Get(0).(*datastore.Features), args.Error(1)
    71  }
    72  
    73  func (dm *MockDatastore) Statistics(_ context.Context) (datastore.Stats, error) {
    74  	args := dm.Called()
    75  	return args.Get(0).(datastore.Stats), args.Error(1)
    76  }
    77  
    78  func (dm *MockDatastore) Close() error {
    79  	args := dm.Called()
    80  	return args.Error(0)
    81  }
    82  
    83  type MockReader struct {
    84  	mock.Mock
    85  }
    86  
    87  func (dm *MockReader) ReadNamespaceByName(
    88  	_ context.Context,
    89  	nsName string,
    90  ) (*core.NamespaceDefinition, datastore.Revision, error) {
    91  	args := dm.Called(nsName)
    92  
    93  	var def *core.NamespaceDefinition
    94  	if args.Get(0) != nil {
    95  		def = args.Get(0).(*core.NamespaceDefinition)
    96  	}
    97  
    98  	return def, args.Get(1).(datastore.Revision), args.Error(2)
    99  }
   100  
   101  func (dm *MockReader) QueryRelationships(
   102  	_ context.Context,
   103  	filter datastore.RelationshipsFilter,
   104  	options ...options.QueryOptionsOption,
   105  ) (datastore.RelationshipIterator, error) {
   106  	callArgs := make([]interface{}, 0, len(options)+1)
   107  	callArgs = append(callArgs, filter)
   108  	for _, option := range options {
   109  		callArgs = append(callArgs, option)
   110  	}
   111  
   112  	args := dm.Called(callArgs...)
   113  	var results datastore.RelationshipIterator
   114  	if args.Get(0) != nil {
   115  		results = args.Get(0).(datastore.RelationshipIterator)
   116  	}
   117  
   118  	return results, args.Error(1)
   119  }
   120  
   121  func (dm *MockReader) ReverseQueryRelationships(
   122  	_ context.Context,
   123  	subjectsFilter datastore.SubjectsFilter,
   124  	options ...options.ReverseQueryOptionsOption,
   125  ) (datastore.RelationshipIterator, error) {
   126  	callArgs := make([]interface{}, 0, len(options)+1)
   127  	callArgs = append(callArgs, subjectsFilter)
   128  	for _, option := range options {
   129  		callArgs = append(callArgs, option)
   130  	}
   131  
   132  	args := dm.Called(callArgs...)
   133  	var results datastore.RelationshipIterator
   134  	if args.Get(0) != nil {
   135  		results = args.Get(0).(datastore.RelationshipIterator)
   136  	}
   137  
   138  	return results, args.Error(1)
   139  }
   140  
   141  func (dm *MockReader) ListAllNamespaces(_ context.Context) ([]datastore.RevisionedNamespace, error) {
   142  	args := dm.Called()
   143  	return args.Get(0).([]datastore.RevisionedNamespace), args.Error(1)
   144  }
   145  
   146  func (dm *MockReader) LookupNamespacesWithNames(_ context.Context, nsNames []string) ([]datastore.RevisionedNamespace, error) {
   147  	args := dm.Called(nsNames)
   148  	return args.Get(0).([]datastore.RevisionedNamespace), args.Error(1)
   149  }
   150  
   151  func (dm *MockReader) ReadCaveatByName(_ context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) {
   152  	args := dm.Called(name)
   153  
   154  	var def *core.CaveatDefinition
   155  	if args.Get(0) != nil {
   156  		def = args.Get(0).(*core.CaveatDefinition)
   157  	}
   158  
   159  	return def, args.Get(1).(datastore.Revision), args.Error(2)
   160  }
   161  
   162  func (dm *MockReader) LookupCaveatsWithNames(_ context.Context, caveatNames []string) ([]datastore.RevisionedCaveat, error) {
   163  	args := dm.Called(caveatNames)
   164  	return args.Get(0).([]datastore.RevisionedCaveat), args.Error(1)
   165  }
   166  
   167  func (dm *MockReader) ListAllCaveats(_ context.Context) ([]datastore.RevisionedCaveat, error) {
   168  	args := dm.Called()
   169  	return args.Get(0).([]datastore.RevisionedCaveat), args.Error(1)
   170  }
   171  
   172  type MockReadWriteTransaction struct {
   173  	mock.Mock
   174  }
   175  
   176  func (dm *MockReadWriteTransaction) ReadNamespaceByName(
   177  	_ context.Context,
   178  	nsName string,
   179  ) (*core.NamespaceDefinition, datastore.Revision, error) {
   180  	args := dm.Called(nsName)
   181  
   182  	var def *core.NamespaceDefinition
   183  	if args.Get(0) != nil {
   184  		def = args.Get(0).(*core.NamespaceDefinition)
   185  	}
   186  
   187  	return def, args.Get(1).(datastore.Revision), args.Error(2)
   188  }
   189  
   190  func (dm *MockReadWriteTransaction) QueryRelationships(
   191  	_ context.Context,
   192  	filter datastore.RelationshipsFilter,
   193  	options ...options.QueryOptionsOption,
   194  ) (datastore.RelationshipIterator, error) {
   195  	callArgs := make([]interface{}, 0, len(options)+1)
   196  	callArgs = append(callArgs, filter)
   197  	for _, option := range options {
   198  		callArgs = append(callArgs, option)
   199  	}
   200  
   201  	args := dm.Called(callArgs...)
   202  	var results datastore.RelationshipIterator
   203  	if args.Get(0) != nil {
   204  		results = args.Get(0).(datastore.RelationshipIterator)
   205  	}
   206  
   207  	return results, args.Error(1)
   208  }
   209  
   210  func (dm *MockReadWriteTransaction) ReverseQueryRelationships(
   211  	_ context.Context,
   212  	subjectsFilter datastore.SubjectsFilter,
   213  	options ...options.ReverseQueryOptionsOption,
   214  ) (datastore.RelationshipIterator, error) {
   215  	callArgs := make([]interface{}, 0, len(options)+1)
   216  	callArgs = append(callArgs, subjectsFilter)
   217  	for _, option := range options {
   218  		callArgs = append(callArgs, option)
   219  	}
   220  
   221  	args := dm.Called(callArgs...)
   222  	var results datastore.RelationshipIterator
   223  	if args.Get(0) != nil {
   224  		results = args.Get(0).(datastore.RelationshipIterator)
   225  	}
   226  
   227  	return results, args.Error(1)
   228  }
   229  
   230  func (dm *MockReadWriteTransaction) ListAllNamespaces(_ context.Context) ([]datastore.RevisionedNamespace, error) {
   231  	args := dm.Called()
   232  	return args.Get(0).([]datastore.RevisionedNamespace), args.Error(1)
   233  }
   234  
   235  func (dm *MockReadWriteTransaction) LookupNamespacesWithNames(_ context.Context, nsNames []string) ([]datastore.RevisionedNamespace, error) {
   236  	args := dm.Called(nsNames)
   237  	return args.Get(0).([]datastore.RevisionedNamespace), args.Error(1)
   238  }
   239  
   240  func (dm *MockReadWriteTransaction) WriteRelationships(_ context.Context, mutations []*core.RelationTupleUpdate) error {
   241  	args := dm.Called(mutations)
   242  	return args.Error(0)
   243  }
   244  
   245  func (dm *MockReadWriteTransaction) DeleteRelationships(_ context.Context, filter *v1.RelationshipFilter, options ...options.DeleteOptionsOption) (bool, error) {
   246  	args := dm.Called(filter)
   247  	return false, args.Error(0)
   248  }
   249  
   250  func (dm *MockReadWriteTransaction) WriteNamespaces(_ context.Context, newConfigs ...*core.NamespaceDefinition) error {
   251  	args := dm.Called(newConfigs)
   252  	return args.Error(0)
   253  }
   254  
   255  func (dm *MockReadWriteTransaction) DeleteNamespaces(_ context.Context, nsNames ...string) error {
   256  	xs := make([]any, 0, len(nsNames))
   257  	for _, nsName := range nsNames {
   258  		xs = append(xs, nsName)
   259  	}
   260  
   261  	args := dm.Called(xs...)
   262  	return args.Error(0)
   263  }
   264  
   265  func (dm *MockReadWriteTransaction) BulkLoad(_ context.Context, iter datastore.BulkWriteRelationshipSource) (uint64, error) {
   266  	args := dm.Called(iter)
   267  	return uint64(args.Int(0)), args.Error(1)
   268  }
   269  
   270  func (dm *MockReadWriteTransaction) ReadCaveatByName(_ context.Context, name string) (*core.CaveatDefinition, datastore.Revision, error) {
   271  	args := dm.Called(name)
   272  
   273  	var def *core.CaveatDefinition
   274  	if args.Get(0) != nil {
   275  		def = args.Get(0).(*core.CaveatDefinition)
   276  	}
   277  
   278  	return def, args.Get(1).(datastore.Revision), args.Error(2)
   279  }
   280  
   281  func (dm *MockReadWriteTransaction) LookupCaveatsWithNames(_ context.Context, caveatNames []string) ([]datastore.RevisionedCaveat, error) {
   282  	args := dm.Called(caveatNames)
   283  	return args.Get(0).([]datastore.RevisionedCaveat), args.Error(1)
   284  }
   285  
   286  func (dm *MockReadWriteTransaction) ListAllCaveats(_ context.Context) ([]datastore.RevisionedCaveat, error) {
   287  	args := dm.Called()
   288  	return args.Get(0).([]datastore.RevisionedCaveat), args.Error(1)
   289  }
   290  
   291  func (dm *MockReadWriteTransaction) WriteCaveats(_ context.Context, caveats []*core.CaveatDefinition) error {
   292  	args := dm.Called(caveats)
   293  	return args.Error(0)
   294  }
   295  
   296  func (dm *MockReadWriteTransaction) DeleteCaveats(_ context.Context, _ []string) error {
   297  	panic("not used")
   298  }
   299  
   300  var (
   301  	_ datastore.Datastore            = &MockDatastore{}
   302  	_ datastore.Reader               = &MockReader{}
   303  	_ datastore.ReadWriteTransaction = &MockReadWriteTransaction{}
   304  )