github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/pkg/p2p/mock_grpc_client.go (about)

     1  // Copyright 2021 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package p2p
    15  
    16  import (
    17  	"context"
    18  	"sync"
    19  	"sync/atomic"
    20  
    21  	"github.com/pingcap/tiflow/proto/p2p"
    22  	"github.com/stretchr/testify/mock"
    23  	"google.golang.org/grpc"
    24  )
    25  
    26  type mockSendMessageClient struct {
    27  	mu sync.Mutex
    28  	mock.Mock
    29  	// embeds an empty interface
    30  	p2p.CDCPeerToPeer_SendMessageClient
    31  	ctx context.Context
    32  
    33  	msgCount int32
    34  	replyCh  chan *p2p.SendMessageResponse
    35  }
    36  
    37  func newMockSendMessageClient(ctx context.Context) *mockSendMessageClient {
    38  	return &mockSendMessageClient{
    39  		ctx:     ctx,
    40  		replyCh: make(chan *p2p.SendMessageResponse), // unbuffered channel
    41  	}
    42  }
    43  
    44  func (s *mockSendMessageClient) Send(packet *p2p.MessagePacket) error {
    45  	s.mu.Lock()
    46  	defer s.mu.Unlock()
    47  
    48  	args := s.Called(packet)
    49  	atomic.AddInt32(&s.msgCount, 1)
    50  	return args.Error(0)
    51  }
    52  
    53  func (s *mockSendMessageClient) Recv() (*p2p.SendMessageResponse, error) {
    54  	var args mock.Arguments
    55  	func() {
    56  		// We use a deferred Unlock in case `s.Called()` panics.
    57  		s.mu.Lock()
    58  		defer s.mu.Unlock()
    59  
    60  		args = s.MethodCalled("Recv")
    61  	}()
    62  
    63  	if err := args.Error(1); err != nil {
    64  		return nil, err
    65  	}
    66  	if args.Get(0) != nil {
    67  		return args.Get(0).(*p2p.SendMessageResponse), nil
    68  	}
    69  	select {
    70  	case <-s.ctx.Done():
    71  		return nil, s.ctx.Err()
    72  	case resp := <-s.replyCh:
    73  		return resp, nil
    74  	}
    75  }
    76  
    77  func (s *mockSendMessageClient) Context() context.Context {
    78  	return s.ctx
    79  }
    80  
    81  func (s *mockSendMessageClient) ResetMock() {
    82  	s.mu.Lock()
    83  	defer s.mu.Unlock()
    84  
    85  	s.ExpectedCalls = nil
    86  	s.Calls = nil
    87  }
    88  
    89  type mockCDCPeerToPeerClient struct {
    90  	mock.Mock
    91  }
    92  
    93  func (c *mockCDCPeerToPeerClient) SendMessage(
    94  	ctx context.Context, opts ...grpc.CallOption,
    95  ) (p2p.CDCPeerToPeer_SendMessageClient, error) {
    96  	args := c.Called(ctx, opts)
    97  	return args.Get(0).(p2p.CDCPeerToPeer_SendMessageClient), args.Error(1)
    98  }