github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/utils/rpcwrapper/rpc_handlemock.go (about)

     1  package rpcwrapper
     2  
     3  import (
     4  	"context"
     5  	"net/rpc"
     6  	"sync"
     7  	"testing"
     8  )
     9  
// MockRPCHdl is a mock of the RPC handle (RPCHdl). It pairs an RPC client
// with the channel name it was created for.
type MockRPCHdl struct {
	// Client is the underlying net/rpc client.
	Client *rpc.Client
	// Channel is the channel name associated with Client.
	Channel string
}
    15  
// mockedMethods holds the mock implementations registered for one test.
// Each field is set by the testRPC Mock* method of the matching name and is
// consulted by the testRPC method of the same name without the "Mock"
// suffix; a nil field means no mock was registered for that method.
type mockedMethods struct {
	NewRPCClientMock     func(contextID string, channel string, secret string) error
	GetRPCClientMock     func(contextID string) (*RPCHdl, error)
	RemoteCallMock       func(contextID string, methodName string, req *Request, resp *Response) error
	DestroyRPCClientMock func(contextID string)
	StartServerMock      func(ctx context.Context, protocol string, path string, handler interface{}) error
	ProcessMessageMock   func(req *Request, secret string) bool
	ContextListMock      func() []string
	CheckValidityMock    func(req *Request, secret string) bool
}
    26  
// TestRPCClient is an RPCClient for use in tests. In addition to the
// RPCClient methods, it lets a test register replacement implementations
// for those methods; NewTestRPCClient returns an implementation.
type TestRPCClient interface {
	RPCClient
	// MockNewRPCClient registers impl as the NewRPCClient implementation for test t.
	MockNewRPCClient(t *testing.T, impl func(contextID string, channel string, secret string) error)
	// MockGetRPCClient registers impl as the GetRPCClient implementation for test t.
	MockGetRPCClient(t *testing.T, impl func(contextID string) (*RPCHdl, error))
	// MockRemoteCall registers impl as the RemoteCall implementation for test t.
	MockRemoteCall(t *testing.T, impl func(contextID string, methodName string, req *Request, resp *Response) error)
	// MockDestroyRPCClient registers impl as the DestroyRPCClient implementation for test t.
	MockDestroyRPCClient(t *testing.T, impl func(contextID string))
	// MockContextList registers impl as the ContextList implementation for test t.
	MockContextList(t *testing.T, impl func() []string)
	// MockCheckValidity registers impl as the CheckValidity implementation for test t.
	MockCheckValidity(t *testing.T, impl func(req *Request, secret string) bool)
}
    37  
// TestRPCServer is an RPCServer for use in tests. In addition to the
// RPCServer methods, it lets a test register replacement implementations
// for those methods; NewTestRPCServer returns an implementation.
type TestRPCServer interface {
	RPCServer
	// MockStartServer registers impl as the StartServer implementation for test t.
	MockStartServer(t *testing.T, impl func(ctx context.Context, protocol string, path string, handler interface{}) error)
	// MockProcessMessage registers impl as the ProcessMessage implementation for test t.
	MockProcessMessage(t *testing.T, impl func(req *Request, secret string) bool)
	// MockCheckValidity registers impl as the CheckValidity implementation for test t.
	MockCheckValidity(t *testing.T, impl func(req *Request, secret string) bool)
}
    45  
// testRPC implements both TestRPCClient and TestRPCServer. Mocks are stored
// per *testing.T; the RPC methods dispatch to the mocks of currentTest.
type testRPC struct {
	// mocks maps each test to the mock implementations registered for it.
	mocks map[*testing.T]*mockedMethods
	// lock is held by currentMocks while it reads/updates mocks and currentTest.
	lock *sync.Mutex
	// currentTest is the last non-nil test passed to currentMocks (i.e. via a
	// Mock* method); RPC methods resolve their mocks through it.
	currentTest *testing.T
}
    51  
    52  // NewTestRPCServer is a Test RPC Server
    53  func NewTestRPCServer() TestRPCServer {
    54  	return &testRPC{
    55  		lock:  &sync.Mutex{},
    56  		mocks: map[*testing.T]*mockedMethods{},
    57  	}
    58  }
    59  
    60  // NewTestRPCClient is a Test RPC Client
    61  func NewTestRPCClient() TestRPCClient {
    62  	return &testRPC{
    63  		lock:  &sync.Mutex{},
    64  		mocks: map[*testing.T]*mockedMethods{},
    65  	}
    66  }
    67  
    68  // MockNewRPCClient mocks the NewRPCClient function
    69  func (m *testRPC) MockNewRPCClient(t *testing.T, impl func(contextID string, channel string, secret string) error) {
    70  	m.currentMocks(t).NewRPCClientMock = impl
    71  }
    72  
    73  // MockGetRPCClient mocks the GetRPCClient function
    74  func (m *testRPC) MockGetRPCClient(t *testing.T, impl func(contextID string) (*RPCHdl, error)) {
    75  	m.currentMocks(t).GetRPCClientMock = impl
    76  }
    77  
    78  // MockRemoteCall mocks the RemoteCall function
    79  func (m *testRPC) MockRemoteCall(t *testing.T, impl func(contextID string, methodName string, req *Request, resp *Response) error) {
    80  	m.currentMocks(t).RemoteCallMock = impl
    81  }
    82  
    83  // MockDestroyRPCClient mocks the DestroyRPCClient function
    84  func (m *testRPC) MockDestroyRPCClient(t *testing.T, impl func(contextID string)) {
    85  	m.currentMocks(t).DestroyRPCClientMock = impl
    86  }
    87  
    88  // MockStartServer mocks the StartServer function
    89  func (m *testRPC) MockStartServer(t *testing.T, impl func(ctx context.Context, protocol string, path string, handler interface{}) error) {
    90  	m.currentMocks(t).StartServerMock = impl
    91  
    92  }
    93  
    94  // MockProcessMessage mocks the ProcessMessage function
    95  func (m *testRPC) MockProcessMessage(t *testing.T, impl func(req *Request, secret string) bool) {
    96  	m.currentMocks(t).ProcessMessageMock = impl
    97  }
    98  
    99  // MockContextList mocks the ContextList function
   100  func (m *testRPC) MockContextList(t *testing.T, impl func() []string) {
   101  	m.currentMocks(t).ContextListMock = impl
   102  }
   103  
   104  // MockCheckValidity mocks the CheckValidity function
   105  func (m *testRPC) MockCheckValidity(t *testing.T, impl func(req *Request, secret string) bool) {
   106  	m.currentMocks(t).CheckValidityMock = impl
   107  }
   108  
   109  // NewRPCClient implements the new interface
   110  func (m *testRPC) NewRPCClient(contextID string, channel string, secret string) error {
   111  	if mock := m.currentMocks(nil); mock != nil && mock.NewRPCClientMock != nil {
   112  		return mock.NewRPCClientMock(contextID, channel, secret)
   113  	}
   114  	return nil
   115  }
   116  
   117  // GetRPCClient implements the interface with a mock
   118  func (m *testRPC) GetRPCClient(contextID string) (*RPCHdl, error) {
   119  	if mock := m.currentMocks(nil); mock != nil && mock.GetRPCClientMock != nil {
   120  		return mock.GetRPCClientMock(contextID)
   121  	}
   122  	return nil, nil
   123  }
   124  
   125  // RemoteCall implements the interface with a mock
   126  func (m *testRPC) RemoteCall(contextID string, methodName string, req *Request, resp *Response) error {
   127  	if mock := m.currentMocks(nil); mock != nil && mock.RemoteCallMock != nil {
   128  		return mock.RemoteCallMock(contextID, methodName, req, resp)
   129  	}
   130  	return nil
   131  }
   132  
   133  // DestroyRPCClient implements the interface with a Mock
   134  func (m *testRPC) DestroyRPCClient(contextID string) {
   135  	if mock := m.currentMocks(nil); mock != nil && mock.DestroyRPCClientMock != nil {
   136  		mock.DestroyRPCClientMock(contextID)
   137  		return
   138  	}
   139  }
   140  
   141  // CheckValidity implements the interface with a mock
   142  func (m *testRPC) CheckValidity(req *Request, secret string) bool {
   143  	if mock := m.currentMocks(nil); mock != nil && mock.DestroyRPCClientMock != nil {
   144  		return mock.CheckValidityMock(req, secret)
   145  	}
   146  	return false
   147  }
   148  
   149  // StartServer implements the interface with a mock
   150  func (m *testRPC) StartServer(ctx context.Context, protocol string, path string, handler interface{}) error {
   151  	if mock := m.currentMocks(nil); mock != nil && mock.StartServerMock != nil {
   152  		return mock.StartServerMock(ctx, protocol, path, handler)
   153  	}
   154  	return nil
   155  }
   156  
   157  // ProcessMessage implements the interface with a mock
   158  func (m *testRPC) ProcessMessage(req *Request, secret string) bool {
   159  	if mock := m.currentMocks(nil); mock != nil && mock.ProcessMessageMock != nil {
   160  		return mock.ProcessMessageMock(req, secret)
   161  	}
   162  	return true
   163  }
   164  
   165  // ContextList implements the interface with a mock
   166  func (m *testRPC) ContextList() []string {
   167  	if mock := m.currentMocks(m.currentTest); mock != nil && mock.ContextListMock != nil {
   168  		return mock.ContextListMock()
   169  	}
   170  	return []string{}
   171  }
   172  
   173  // currentMocks returns the list of current mocks
   174  func (m *testRPC) currentMocks(t *testing.T) *mockedMethods {
   175  	m.lock.Lock()
   176  	defer m.lock.Unlock()
   177  
   178  	if t == nil {
   179  		t = m.currentTest
   180  	} else {
   181  		m.currentTest = t
   182  	}
   183  
   184  	mocks := m.mocks[t]
   185  	if mocks == nil {
   186  		mocks = &mockedMethods{}
   187  		m.mocks[t] = mocks
   188  	}
   189  
   190  	return mocks
   191  }