github.com/iqoqo/nomad@v0.11.3-0.20200911112621-d7021c74d101/client/testutil/rpc.go (about)

     1  package testutil
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"net"
     7  	"strings"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/hashicorp/go-msgpack/codec"
    12  	cstructs "github.com/hashicorp/nomad/client/structs"
    13  	"github.com/hashicorp/nomad/nomad/structs"
    14  	"github.com/stretchr/testify/require"
    15  )
    16  
// StreamingRPC is the minimal endpoint surface AssertStreamingRPCError needs:
// looking up a streaming RPC handler by method name. It may be satisfied by
// client.Client or server.Server.
type StreamingRPC interface {
	// StreamingRpcHandler returns the handler that serves the streaming RPC
	// named by method, or an error if no handler can be provided.
	StreamingRpcHandler(method string) (structs.StreamingRpcHandler, error)
}
    21  
// StreamingRPCErrorTestCase is a test case to be passed to the
// AssertStreamingRPCError func. It describes one streaming RPC call and the
// error the first response message is expected to carry.
type StreamingRPCErrorTestCase struct {
	// Name identifies the test case. AssertStreamingRPCError does not read
	// it; presumably callers use it as a subtest name (confirm at call sites).
	Name   string
	// RPC is the streaming RPC method name whose handler is looked up via
	// StreamingRpcHandler and exercised.
	RPC    string
	// Req is the request value that is msgpack-encoded and written to the
	// stream before the response is read.
	Req    interface{}
	// Assert receives the error carried by the first message decoded from the
	// stream (nil when that message carries no error) and must return true
	// for the assertion to pass.
	Assert func(error) bool
}
    30  
    31  // AssertStreamingRPCError asserts a streaming RPC's error matches the given
    32  // assertion in the test case.
    33  func AssertStreamingRPCError(t *testing.T, s StreamingRPC, tc StreamingRPCErrorTestCase) {
    34  	handler, err := s.StreamingRpcHandler(tc.RPC)
    35  	require.NoError(t, err)
    36  
    37  	// Create a pipe
    38  	p1, p2 := net.Pipe()
    39  	defer p1.Close()
    40  	defer p2.Close()
    41  
    42  	errCh := make(chan error, 1)
    43  	streamMsg := make(chan *cstructs.StreamErrWrapper, 1)
    44  
    45  	// Start the handler
    46  	go handler(p2)
    47  
    48  	// Start the decoder
    49  	go func() {
    50  		decoder := codec.NewDecoder(p1, structs.MsgpackHandle)
    51  		for {
    52  			var msg cstructs.StreamErrWrapper
    53  			if err := decoder.Decode(&msg); err != nil {
    54  				if err == io.EOF || strings.Contains(err.Error(), "closed") {
    55  					return
    56  				}
    57  				errCh <- fmt.Errorf("error decoding: %v", err)
    58  			}
    59  
    60  			streamMsg <- &msg
    61  		}
    62  	}()
    63  
    64  	// Send the request
    65  	encoder := codec.NewEncoder(p1, structs.MsgpackHandle)
    66  	require.NoError(t, encoder.Encode(tc.Req))
    67  
    68  	timeout := time.After(5 * time.Second)
    69  
    70  	for {
    71  		select {
    72  		case <-timeout:
    73  			t.Fatal("timeout")
    74  		case err := <-errCh:
    75  			require.NoError(t, err)
    76  		case msg := <-streamMsg:
    77  			// Convert RpcError to error
    78  			var err error
    79  			if msg.Error != nil {
    80  				err = msg.Error
    81  			}
    82  			require.True(t, tc.Assert(err), "(%T) %s", msg.Error, msg.Error)
    83  			return
    84  		}
    85  	}
    86  }