github.com/1aal/kubeblocks@v0.0.0-20231107070852-e1c03e598921/pkg/testutil/k8s/k8sclient_util.go (about)

     1  /*
     2  Copyright (C) 2022-2023 ApeCloud Co., Ltd
     3  
     4  This file is part of KubeBlocks project
     5  
     6  This program is free software: you can redistribute it and/or modify
     7  it under the terms of the GNU Affero General Public License as published by
     8  the Free Software Foundation, either version 3 of the License, or
     9  (at your option) any later version.
    10  
    11  This program is distributed in the hope that it will be useful
    12  but WITHOUT ANY WARRANTY; without even the implied warranty of
    13  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    14  GNU Affero General Public License for more details.
    15  
    16  You should have received a copy of the GNU Affero General Public License
    17  along with this program.  If not, see <http://www.gnu.org/licenses/>.
    18  */
    19  
    20  package testutil
    21  
    22  import (
    23  	"context"
    24  	"fmt"
    25  	"reflect"
    26  	"sync"
    27  
    28  	"github.com/golang/mock/gomock"
    29  	ginkgov2 "github.com/onsi/ginkgo/v2"
    30  	apierrors "k8s.io/apimachinery/pkg/api/errors"
    31  	apimeta "k8s.io/apimachinery/pkg/api/meta"
    32  	"k8s.io/apimachinery/pkg/runtime"
    33  	"k8s.io/apimachinery/pkg/runtime/schema"
    34  	"sigs.k8s.io/controller-runtime/pkg/client"
    35  
    36  	mock_client "github.com/1aal/kubeblocks/pkg/testutil/k8s/mocks"
    37  )
    38  
    39  type CallMockOptions = func(call *gomock.Call)
    40  type CallerFunction = func() *gomock.Call
    41  type DoReturnedFunction = any
    42  
    43  type HandleGetReturnedObject = func(key client.ObjectKey, obj client.Object) error
    44  type HandlePatchReturnedObject = func(obj client.Object, patch client.Patch) error
    45  type HandleListReturnedObject = func(list client.ObjectList) error
    46  type HandleCreateReturnedObject = func(obj client.Object) error
    47  
    48  type CallMockReturnedOptions = func(callHelper *callHelper, call *gomock.Call)
    49  type CallMockGetReturnedOptions = func(callHelper *callHelper, call *gomock.Call, _ HandleGetReturnedObject) error
    50  type CallMockPatchReturnedOptions = func(callHelper *callHelper, call *gomock.Call, _ HandlePatchReturnedObject) error
    51  type CallMockListReturnedOptions = func(callHelper *callHelper, call *gomock.Call, _ HandleListReturnedObject) error
    52  
    53  type callHelper struct {
    54  	callerOnce   sync.Once
    55  	callerFn     CallerFunction
    56  	doReturnedFn DoReturnedFunction
    57  }
    58  
    59  type K8sClientMockHelper struct {
    60  	ctrl         *gomock.Controller
    61  	k8sClient    *mock_client.MockClient
    62  	statusWriter *mock_client.MockStatusWriter
    63  
    64  	getCaller    callHelper
    65  	createCaller callHelper
    66  	updateCaller callHelper
    67  	listCaller   callHelper
    68  	patchCaller  callHelper
    69  	deleteCaller callHelper
    70  }
    71  
    72  func (h *callHelper) Caller(newCaller func() (CallerFunction, DoReturnedFunction)) CallerFunction {
    73  	h.callerOnce.Do(func() {
    74  		h.callerFn, h.doReturnedFn = newCaller()
    75  	})
    76  	return h.callerFn
    77  }
    78  
    79  func (helper *K8sClientMockHelper) Client() client.Client {
    80  	return helper.k8sClient
    81  }
    82  
    83  func (helper *K8sClientMockHelper) StatusWriter() *mock_client.MockStatusWriter {
    84  	return helper.statusWriter
    85  }
    86  
    87  func (helper *K8sClientMockHelper) Controller() *gomock.Controller {
    88  	return helper.ctrl
    89  }
    90  
    91  func (helper *K8sClientMockHelper) Finish() {
    92  	helper.ctrl.Finish()
    93  }
    94  
    95  func (helper *K8sClientMockHelper) mockMethod(callHelper *callHelper, options ...any) {
    96  	for _, option := range options {
    97  		call := callHelper.callerFn()
    98  		switch f := option.(type) {
    99  		case CallMockOptions:
   100  			f(call)
   101  		case CallMockReturnedOptions:
   102  			f(callHelper, call)
   103  		}
   104  	}
   105  }
   106  
   107  func (helper *K8sClientMockHelper) MockStatusMethod() *mock_client.MockStatusWriter {
   108  	if helper.statusWriter == nil {
   109  		helper.statusWriter = mock_client.NewMockStatusWriter(helper.ctrl)
   110  	}
   111  	helper.k8sClient.EXPECT().Status().Return(helper.statusWriter).AnyTimes()
   112  	return helper.statusWriter
   113  }
   114  
   115  func (helper *K8sClientMockHelper) MockGetMethod(options ...any) {
   116  	helper.getCaller.Caller(func() (CallerFunction, DoReturnedFunction) {
   117  		caller := func() *gomock.Call {
   118  			return helper.k8sClient.EXPECT().Get(gomock.Any(), gomock.Any(), gomock.Any())
   119  		}
   120  		doAndReturn := func(caller *gomock.Call, fnWrap HandleGetReturnedObject) {
   121  			caller.DoAndReturn(func(ctx context.Context, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error {
   122  				return fnWrap(key, obj)
   123  			})
   124  		}
   125  		return caller, doAndReturn
   126  	})
   127  	helper.mockMethod(&helper.getCaller, options...)
   128  }
   129  
   130  func (helper *K8sClientMockHelper) MockCreateMethod(options ...any) {
   131  	helper.createCaller.Caller(func() (CallerFunction, DoReturnedFunction) {
   132  		caller := func() *gomock.Call {
   133  			return helper.k8sClient.EXPECT().Create(gomock.Any(), gomock.Any())
   134  		}
   135  		doAndReturn := func(caller *gomock.Call, fnWrap func(obj client.Object) error) {
   136  			caller.DoAndReturn(func(ctx context.Context, obj client.Object, opts ...client.CreateOption) error {
   137  				return fnWrap(obj)
   138  			})
   139  		}
   140  		return caller, doAndReturn
   141  	})
   142  	helper.mockMethod(&helper.createCaller, options...)
   143  }
   144  
   145  func (helper *K8sClientMockHelper) MockUpdateMethod(options ...any) {
   146  	helper.updateCaller.Caller(func() (CallerFunction, DoReturnedFunction) {
   147  		caller := func() *gomock.Call {
   148  			return helper.k8sClient.EXPECT().Update(gomock.Any(), gomock.Any())
   149  		}
   150  		doAndReturn := func(caller *gomock.Call, fnWrap func(obj client.Object) error) {
   151  			caller.DoAndReturn(func(ctx context.Context, obj client.Object, opts ...client.UpdateOption) error {
   152  				return fnWrap(obj)
   153  			})
   154  		}
   155  		return caller, doAndReturn
   156  	})
   157  	helper.mockMethod(&helper.updateCaller, options...)
   158  }
   159  
   160  func (helper *K8sClientMockHelper) MockDeleteMethod(options ...any) {
   161  	helper.deleteCaller.Caller(func() (CallerFunction, DoReturnedFunction) {
   162  		caller := func() *gomock.Call {
   163  			return helper.k8sClient.EXPECT().Delete(gomock.Any(), gomock.Any())
   164  		}
   165  		doAndReturn := func(caller *gomock.Call, fnWrap func(obj client.Object) error) {
   166  			caller.DoAndReturn(func(ctx context.Context, obj client.Object, opts ...client.DeleteOption) error {
   167  				return fnWrap(obj)
   168  			})
   169  		}
   170  		return caller, doAndReturn
   171  	})
   172  	helper.mockMethod(&helper.updateCaller, options...)
   173  }
   174  
   175  func (helper *K8sClientMockHelper) MockListMethod(options ...any) {
   176  	helper.listCaller.Caller(func() (CallerFunction, DoReturnedFunction) {
   177  		caller := func() *gomock.Call {
   178  			return helper.k8sClient.EXPECT().List(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
   179  		}
   180  		doAndReturn := func(caller *gomock.Call, fnWrap HandleListReturnedObject) {
   181  			caller.DoAndReturn(func(ctx context.Context, list client.ObjectList, opts ...client.ListOption) error {
   182  				return fnWrap(list)
   183  			})
   184  		}
   185  		return caller, doAndReturn
   186  	})
   187  	helper.mockMethod(&helper.listCaller, options...)
   188  }
   189  
   190  func (helper *K8sClientMockHelper) MockPatchMethod(options ...any) {
   191  	helper.patchCaller.Caller(func() (CallerFunction, DoReturnedFunction) {
   192  		caller := func() *gomock.Call {
   193  			return helper.k8sClient.EXPECT().Patch(gomock.Any(), gomock.Any(), gomock.Any())
   194  		}
   195  		doAndReturn := func(caller *gomock.Call, fnWrap HandlePatchReturnedObject) {
   196  			caller.DoAndReturn(func(ctx context.Context, obj client.Object, patch client.Patch, opts ...client.PatchOption) error {
   197  				return fnWrap(obj, patch)
   198  			})
   199  		}
   200  		return caller, doAndReturn
   201  	})
   202  	helper.mockMethod(&helper.patchCaller, options...)
   203  }
   204  
   205  func SetupK8sMock() (*gomock.Controller, *mock_client.MockClient) {
   206  	ctrl := gomock.NewController(ginkgov2.GinkgoT())
   207  	client := mock_client.NewMockClient(ctrl)
   208  	return ctrl, client
   209  }
   210  
   211  func SetGetReturnedObject(out client.Object, expectedObj client.Object) {
   212  	outVal := reflect.ValueOf(out)
   213  	objVal := reflect.ValueOf(expectedObj)
   214  	reflect.Indirect(outVal).Set(reflect.Indirect(objVal))
   215  }
   216  
   217  func SetListReturnedObjects(list client.ObjectList, objects []runtime.Object) error {
   218  	return apimeta.SetList(list, objects)
   219  }
   220  
   221  func NewK8sMockClient() *K8sClientMockHelper {
   222  	ctrl, client := SetupK8sMock()
   223  	clientHelper := K8sClientMockHelper{
   224  		ctrl:      ctrl,
   225  		k8sClient: client,
   226  	}
   227  	return &clientHelper
   228  }
   229  
   230  func WithTimes(n int) CallMockOptions {
   231  	return func(call *gomock.Call) {
   232  		call.Times(n)
   233  	}
   234  }
   235  
   236  func WithMinTimes(n int) CallMockOptions {
   237  	return func(call *gomock.Call) {
   238  		call.MinTimes(n)
   239  	}
   240  }
   241  
   242  func WithMaxTimes(n int) CallMockOptions {
   243  	return func(call *gomock.Call) {
   244  		call.MaxTimes(n)
   245  	}
   246  }
   247  
   248  func WithAnyTimes() CallMockOptions {
   249  	return func(call *gomock.Call) {
   250  		call.AnyTimes()
   251  	}
   252  }
   253  
   254  func WithFailed(err error, times ...CallMockOptions) CallMockOptions {
   255  	return func(call *gomock.Call) {
   256  		call.Return(err).AnyTimes()
   257  		handleTimes(call, times...)
   258  	}
   259  }
   260  
   261  func WithSucceed(times ...CallMockOptions) CallMockOptions {
   262  	return func(call *gomock.Call) {
   263  		call.Return(nil).AnyTimes()
   264  		handleTimes(call, times...)
   265  	}
   266  }
   267  
   268  func WithConstructListReturnedResult(r []runtime.Object) HandleListReturnedObject {
   269  	return func(list client.ObjectList) error {
   270  		return SetListReturnedObjects(list, r)
   271  	}
   272  }
   273  
   274  type CallbackFn = func(sequence int, r []runtime.Object)
   275  
   276  func WithConstructListSequenceResult(mockObjsList [][]runtime.Object, fns ...CallbackFn) HandleListReturnedObject {
   277  	sequenceAccessCounter := 0
   278  	return func(list client.ObjectList) error {
   279  		for _, fn := range fns {
   280  			fn(sequenceAccessCounter, mockObjsList[sequenceAccessCounter])
   281  		}
   282  		if err := SetListReturnedObjects(list, mockObjsList[sequenceAccessCounter]); err != nil {
   283  			return err
   284  		}
   285  		if sequenceAccessCounter < len(mockObjsList)-1 {
   286  			sequenceAccessCounter++
   287  		}
   288  		return nil
   289  	}
   290  }
   291  
   292  type MockGetReturned struct {
   293  	Object client.Object
   294  	Err    error
   295  }
   296  
   297  func WithConstructSequenceResult(mockObjs map[client.ObjectKey][]MockGetReturned) HandleGetReturnedObject {
   298  	sequenceAccessCounter := make(map[client.ObjectKey]int, len(mockObjs))
   299  	return func(key client.ObjectKey, obj client.Object) error {
   300  		accessibleSequence, ok := mockObjs[key]
   301  		if !ok {
   302  			return fmt.Errorf("not existed key: %v", key)
   303  		}
   304  
   305  		index := sequenceAccessCounter[key]
   306  		mockReturned := accessibleSequence[index]
   307  		if index < len(accessibleSequence)-1 {
   308  			sequenceAccessCounter[key]++
   309  		}
   310  
   311  		switch {
   312  		case mockReturned.Err != nil:
   313  			return mockReturned.Err
   314  		case mockReturned.Object != nil:
   315  			SetGetReturnedObject(obj, mockReturned.Object)
   316  			return nil
   317  		default:
   318  			return apierrors.NewNotFound(schema.GroupResource{Group: "unknown", Resource: "unknown"}, key.Name)
   319  		}
   320  	}
   321  }
   322  
   323  func WithConstructGetResult(mockObj client.Object) HandleGetReturnedObject {
   324  	return func(key client.ObjectKey, obj client.Object) error {
   325  		SetGetReturnedObject(obj, mockObj)
   326  		return nil
   327  	}
   328  }
   329  
   330  func WithCreatedSucceedResult() HandleCreateReturnedObject {
   331  	return func(obj client.Object) error {
   332  		_ = obj
   333  		return nil
   334  	}
   335  }
   336  
   337  func WithCreatedFailedResult() HandleCreateReturnedObject {
   338  	return func(obj client.Object) error {
   339  		_ = obj
   340  		return fmt.Errorf("create failed")
   341  	}
   342  }
   343  
   344  type Getter = func(key client.ObjectKey, obj client.Object) (bool, error)
   345  
   346  func WithConstructSimpleGetResult(mockObjs []client.Object, get ...Getter) HandleGetReturnedObject {
   347  	mockMap := make(map[client.ObjectKey]client.Object, len(mockObjs))
   348  	for _, obj := range mockObjs {
   349  		mockMap[client.ObjectKeyFromObject(obj)] = obj
   350  	}
   351  	return func(key client.ObjectKey, obj client.Object) error {
   352  		if mockObj, ok := mockMap[key]; ok {
   353  			SetGetReturnedObject(obj, mockObj)
   354  			return nil
   355  		}
   356  		if len(get) > 0 {
   357  			processed, err := get[0](key, obj)
   358  			if processed {
   359  				return err
   360  			}
   361  		}
   362  		return apierrors.NewNotFound(schema.GroupResource{Group: "unknown", Resource: "unknown"}, key.Name)
   363  	}
   364  }
   365  
   366  func WithListReturned(action HandleListReturnedObject, times ...CallMockOptions) CallMockReturnedOptions {
   367  	return func(helper *callHelper, call *gomock.Call) {
   368  		switch fn := helper.doReturnedFn.(type) {
   369  		case func(_ *gomock.Call, _ HandleListReturnedObject):
   370  			fn(call, func(list client.ObjectList) error {
   371  				return action(list)
   372  			})
   373  			handleTimes(call, times...)
   374  		default:
   375  			panic("not walk here!")
   376  		}
   377  	}
   378  }
   379  
   380  func handleTimes(call *gomock.Call, times ...CallMockOptions) {
   381  	for _, time := range times {
   382  		time(call)
   383  	}
   384  }
   385  
   386  func WithGetReturned(action HandleGetReturnedObject, times ...CallMockOptions) CallMockReturnedOptions {
   387  	return func(helper *callHelper, call *gomock.Call) {
   388  		switch fn := helper.doReturnedFn.(type) {
   389  		case func(_ *gomock.Call, _ HandleGetReturnedObject):
   390  			fn(call, func(key client.ObjectKey, obj client.Object) error {
   391  				return action(key, obj)
   392  			})
   393  			handleTimes(call, times...)
   394  		default:
   395  			panic("impossible dead end!")
   396  		}
   397  	}
   398  }
   399  
   400  func WithCreateReturned(action HandleCreateReturnedObject, times ...CallMockOptions) CallMockReturnedOptions {
   401  	return func(helper *callHelper, call *gomock.Call) {
   402  		switch fn := helper.doReturnedFn.(type) {
   403  		case func(_ *gomock.Call, _ HandleCreateReturnedObject):
   404  			fn(call, func(obj client.Object) error {
   405  				return action(obj)
   406  			})
   407  			handleTimes(call, times...)
   408  		default:
   409  			panic("impossible dead end!")
   410  		}
   411  	}
   412  }
   413  
   414  func WithPatchReturned(action HandlePatchReturnedObject, times ...CallMockOptions) CallMockReturnedOptions {
   415  	return func(helper *callHelper, call *gomock.Call) {
   416  		switch fn := helper.doReturnedFn.(type) {
   417  		case func(_ *gomock.Call, _ HandlePatchReturnedObject):
   418  			fn(call, func(obj client.Object, patch client.Patch) error {
   419  				return action(obj, patch)
   420  			})
   421  			handleTimes(call, times...)
   422  		default:
   423  			panic("impossible dead end!")
   424  		}
   425  	}
   426  }