github.com/choria-io/go-choria@v0.28.1-0.20240416190746-b3bf9c7d5a45/providers/agent/mcorpc/client/client_test.go (about)

     1  // Copyright (c) 2020-2022, R.I. Pienaar and the Choria Project contributors
     2  //
     3  // SPDX-License-Identifier: Apache-2.0
     4  
     5  package client
     6  
     7  import (
     8  	"context"
     9  	"encoding/json"
    10  	"fmt"
    11  	"strings"
    12  	"testing"
    13  
    14  	"github.com/choria-io/go-choria/build"
    15  	"github.com/choria-io/go-choria/config"
    16  	"github.com/choria-io/go-choria/inter"
    17  	imock "github.com/choria-io/go-choria/inter/imocks"
    18  	"github.com/choria-io/go-choria/message"
    19  	v1 "github.com/choria-io/go-choria/protocol/v1"
    20  	"github.com/choria-io/go-choria/providers/security/filesec"
    21  
    22  	"github.com/choria-io/go-choria/providers/agent/mcorpc"
    23  	"github.com/choria-io/go-choria/providers/agent/mcorpc/ddl/agent"
    24  	"github.com/choria-io/go-choria/server/agents"
    25  
    26  	"github.com/choria-io/go-choria/client/client"
    27  	"github.com/choria-io/go-choria/protocol"
    28  	"github.com/golang/mock/gomock"
    29  
    30  	. "github.com/onsi/ginkgo/v2"
    31  	. "github.com/onsi/gomega"
    32  )
    33  
    34  func TestMcoRPC(t *testing.T) {
    35  	RegisterFailHandler(Fail)
    36  	RunSpecs(t, "Providers/Agent/McoRPC/Client")
    37  }
    38  
    39  var _ = Describe("Providers/Agent/McoRPC/Client", func() {
	// Fixtures shared by every spec in this suite; BeforeEach assigns them
	// afresh before each spec runs.
	var (
		// fw is the mocked framework passed to New.
		fw      *imock.MockFramework
		// cfg is the configuration returned together with fw.
		cfg     *config.Config
		// rpc is the client under test.
		rpc     *RPC
		// mockctl is the gomock controller the mocks are created from.
		mockctl *gomock.Controller
		// cl is the mocked Choria client that BeforeEach installs as rpc.cl.
		cl      *MockChoriaClient
		// ctx and cancel form the per-spec context; AfterEach calls cancel.
		ctx     context.Context
		cancel  func()
		// err holds errors from setup calls such as New.
		err     error
	)
    50  
	// request is the payload the specs pass to rpc.Do; it is serialised with
	// the JSON key "testing".
	type request struct {
		Testing bool `json:"testing"`
	}
    54  
	// reply is the structure the reply handler decodes RPCReply.Data into,
	// using the JSON key "received".
	type reply struct {
		Received bool `json:"received"`
	}
    58  
    59  	BeforeEach(func() {
    60  		mockctl = gomock.NewController(GinkgoT())
    61  		cl = NewMockChoriaClient(mockctl)
    62  
    63  		fw, cfg = imock.NewFrameworkForTests(mockctl, GinkgoWriter, imock.WithCallerID(), imock.WithDDLFiles("agent", "package", "testdata/mcollective/agent/package.json"))
    64  		fw.EXPECT().NewMessage(gomock.Any(), gomock.Eq("package"), gomock.Eq("ginkgo"), gomock.Eq(inter.RequestMessageType), gomock.Eq(nil)).DoAndReturn(func(payload []byte, agent string, collective string, msgType string, request inter.Message) (msg inter.Message, err error) {
    65  			return message.NewMessage(payload, agent, collective, msgType, request, fw)
    66  		}).AnyTimes()
    67  
    68  		fw.Configuration().LibDir = []string{"testdata"}
    69  
    70  		protocol.Secure = "false"
    71  		rpc, err = New(fw, "package")
    72  		Expect(err).ToNot(HaveOccurred())
    73  
    74  		rpc.cl = cl
    75  		ctx, cancel = context.WithCancel(context.Background())
    76  	})
    77  
	// Cancels the per-spec context created in BeforeEach, then calls Finish on
	// the gomock controller.
	AfterEach(func() {
		cancel()
		mockctl.Finish()
	})
    82  
    83  	Describe("SetOptions", func() {
    84  		It("Should set the options", func() {
    85  			Expect(rpc.ResolveDDL(context.Background())).ToNot(HaveOccurred())
    86  			rpc.setOptions()
    87  			Expect(rpc.opts.BatchSize).To(Equal(0))
    88  			rpc.setOptions(InBatches(10, 1))
    89  			Expect(rpc.opts.BatchSize).To(Equal(10))
    90  		})
    91  	})
    92  
    93  	Describe("RPCReply", func() {
    94  		It("Should match against replies", func() {
    95  			r := RPCReply{
    96  				Statuscode: 0,
    97  				Statusmsg:  "OK",
    98  				Data:       json.RawMessage(`{"hello":"world", "ints": [1,2,3], "strings": ["1","2","3"], "bool":true, "fbool":false}`),
    99  			}
   100  
   101  			check := func(f string) (bool, error) {
   102  				res, _, err := r.MatchExpr(f, nil)
   103  				return res, err
   104  			}
   105  
   106  			Expect(check("ok() && code == 0 && msg == 'OK' && data('hello') in ['world', 'bob']")).To(BeTrue())
   107  			Expect(check("!ok() && data('hello') == 'world'")).To(BeFalse())
   108  			Expect(check("ok() && data('hello') == 'other'")).To(BeFalse())
   109  			Expect(check("ok() && include(data('strings'), '1')")).To(BeTrue())
   110  			Expect(check("ok() && include(data('strings'), '5')")).To(BeFalse())
   111  			Expect(check("ok() && include(data('ints'), 1)")).To(BeTrue())
   112  			Expect(check("include(data('ints'), 1)")).To(BeTrue())
   113  			Expect(check("include(data('ints'), 5)")).To(BeFalse())
   114  			Expect(check("data('bool')")).To(BeTrue())
   115  			Expect(check("!data('bool')")).To(BeFalse())
   116  			Expect(check("data('fbool')")).To(BeFalse())
   117  
   118  			res, _, err := r.MatchExpr("ok() && data('hello')", nil)
   119  			Expect(err).To(MatchError("match expressions should return boolean"))
   120  			Expect(res).To(BeFalse())
   121  		})
   122  	})
   123  
   124  	Describe("New", func() {
   125  		It("Should accept DDLs as an argument", func() {
   126  			ddl := &agent.DDL{
   127  				Metadata: &agents.Metadata{
   128  					Name:        "backplane",
   129  					Description: "Choria Management Backplane",
   130  					Author:      "R.I.Pienaar <rip@devco.net>",
   131  					Version:     "1.0.0",
   132  					License:     "Apache-2.0",
   133  					URL:         "https://choria.io",
   134  					Timeout:     10,
   135  				},
   136  				Actions: []*agent.Action{},
   137  				Schema:  "https://choria.io/schemas/mcorpc/ddl/v1/agent.json",
   138  			}
   139  
   140  			rpc, err = New(fw, "backplane", DDL(ddl))
   141  			Expect(err).ToNot(HaveOccurred())
   142  			Expect(rpc).ToNot(BeNil())
   143  		})
   144  	})
   145  
   146  	Describe("Do", func() {
   147  		It("Should only accept DDLs for the requested agent", func() {
   148  			ddl := &agent.DDL{
   149  				Metadata: &agents.Metadata{
   150  					Name:        "backplane",
   151  					Description: "Choria Management Backplane",
   152  					Author:      "R.I.Pienaar <rip@devco.net>",
   153  					Version:     "1.0.0",
   154  					License:     "Apache-2.0",
   155  					URL:         "https://choria.io",
   156  					Timeout:     10,
   157  				},
   158  				Actions: []*agent.Action{},
   159  				Schema:  "https://choria.io/schemas/mcorpc/ddl/v1/agent.json",
   160  			}
   161  
   162  			rpc, err = New(fw, "package", DDL(ddl))
   163  			_, err := rpc.Do(
   164  				ctx,
   165  				"test_action",
   166  				request{Testing: true},
   167  				Targets(strings.Fields("host1 host2")),
   168  				ReplyTo("custom.reply.to"),
   169  				InBatches(1, -1),
   170  			)
   171  			Expect(err).To(MatchError("the DDL does not describe the package agent"))
   172  		})
   173  
		// Exercises a full request/reply cycle. The framework mock forwards its
		// protocol factory calls to the v1 implementations, the mocked client
		// checks the outgoing message and pushes two replies through the
		// handler produced by rpc.handlerFactory, and the resulting statistics
		// are verified afterwards.
		It("Should perform the request", func() {
			reqid := ""
			handled := 0

			// security provider used by the v1 protocol helpers below
			sec, err := filesec.New(filesec.WithChoriaConfig(&build.Info{}, cfg), filesec.WithLog(fw.Logger("")))
			Expect(err).ToNot(HaveOccurred())

			// Forward the framework's protocol factories to the v1 package.
			fw.EXPECT().NewSecureRequestFromTransport(gomock.Any(), gomock.Any()).DoAndReturn(func(message protocol.TransportMessage, skipvalidate bool) (secure protocol.SecureRequest, err error) {
				return v1.NewSecureRequestFromTransport(message, sec, skipvalidate)
			}).AnyTimes()
			fw.EXPECT().NewRequestFromSecureRequest(gomock.Any()).DoAndReturn(func(sr protocol.SecureRequest) (request protocol.Request, err error) {
				return v1.NewRequestFromSecureRequest(sr)
			}).AnyTimes()
			fw.EXPECT().NewSecureReply(gomock.Any()).DoAndReturn(func(reply protocol.Reply) (secure protocol.SecureReply, err error) {
				return v1.NewSecureReply(reply, sec)
			}).AnyTimes()
			fw.EXPECT().NewTransportForSecureReply(gomock.Any()).DoAndReturn(func(reply protocol.SecureReply) (message protocol.TransportMessage, err error) {
				t, err := v1.NewTransportMessage(cfg.Identity)
				Expect(err).ToNot(HaveOccurred())
				t.SetReplyData(reply)
				return t, nil
			}).AnyTimes()
			fw.EXPECT().NewReplyFromTransportJSON(gomock.Any(), gomock.Any()).DoAndReturn(func(payload []byte, skipvalidate bool) (msg protocol.Reply, err error) {
				t, err := v1.NewTransportFromJSON(payload)
				Expect(err).ToNot(HaveOccurred())
				sreply, err := v1.NewSecureReplyFromTransport(t, sec, skipvalidate)
				Expect(err).ToNot(HaveOccurred())
				return v1.NewReplyFromSecureReply(sreply)
			}).AnyTimes()
			// Build request -> secure request -> transport from the message fields.
			fw.EXPECT().NewRequestTransportForMessage(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(_ context.Context, msg inter.Message, version protocol.ProtocolVersion) (protocol.TransportMessage, error) {
				req, err := v1.NewRequest(msg.Agent(), msg.SenderID(), msg.CallerID(), msg.TTL(), msg.RequestID(), msg.Collective())
				Expect(err).ToNot(HaveOccurred())
				req.SetMessage(msg.Payload())

				sreq, err := v1.NewSecureRequest(req, sec)
				Expect(err).ToNot(HaveOccurred())

				sm, err := v1.NewTransportMessage(fw.Configuration().Identity)
				Expect(err).ToNot(HaveOccurred())
				err = sm.SetRequestData(sreq)
				Expect(err).ToNot(HaveOccurred())

				return sm, nil
			}).AnyTimes()

			// Reply handler passed to Do: every reply must decode with
			// received=true, and each one is counted.
			handler := func(r protocol.Reply, rpcr *RPCReply) {
				res := reply{}
				err := json.Unmarshal(rpcr.Data, &res)
				Expect(err).ToNot(HaveOccurred())
				Expect(res.Received).To(BeTrue())
				handled++
			}

			// Mocked client request: validates the outgoing message, then
			// fabricates one reply per sender and feeds it to the rpc handler.
			cl.EXPECT().Request(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).Do(func(ctx context.Context, msg inter.Message, handler client.Handler) {
				Expect(msg.Collective()).To(Equal("ginkgo"))
				Expect(msg.Payload()).To(Equal([]byte("{\"agent\":\"package\",\"action\":\"test_action\",\"data\":{\"testing\":true}}")))

				// remembered so the stats can be compared against it later
				reqid = msg.RequestID()

				rpcreply := RPCReply{
					Statusmsg:  "OK",
					Statuscode: mcorpc.OK,
					Data:       json.RawMessage("{\"received\":true}"),
				}

				j, err := json.Marshal(rpcreply)
				Expect(err).ToNot(HaveOccurred())

				// recover the protocol request from the message's transport so
				// replies can be created for it
				mt, err := msg.Transport(context.Background())
				Expect(err).ToNot(HaveOccurred())

				sreq, err := fw.NewSecureRequestFromTransport(mt, true)
				Expect(err).ToNot(HaveOccurred())

				req, err := fw.NewRequestFromSecureRequest(sreq)
				Expect(err).ToNot(HaveOccurred())

				rpchandler := rpc.handlerFactory(ctx, cancel, rpc.opts.totalStats)

				// one reply each from test.sender.0 and test.sender.1
				for i := 0; i < 2; i++ {
					reply, err := v1.NewReply(req, fmt.Sprintf("test.sender.%d", i))
					Expect(err).ToNot(HaveOccurred())
					reply.SetMessage(j)

					srep, err := fw.NewSecureReply(reply)
					Expect(err).ToNot(HaveOccurred())

					transport, err := fw.NewTransportForSecureReply(srep)
					Expect(err).ToNot(HaveOccurred())

					tj, err := transport.JSON()
					Expect(err).ToNot(HaveOccurred())

					// hand the serialised reply to the handler as a connector message
					cm := imock.NewMockConnectorMessage(mockctl)
					cm.EXPECT().Data().Return(tj)
					rpchandler(ctx, cm)
				}
			})

			// values captured by the discovery-end callback
			decbcalled := false
			dediscovered := 0
			delimited := 0

			result, err := rpc.Do(
				ctx,
				"test_action",
				request{Testing: true},
				ReplyHandler(handler),
				Targets(strings.Fields("test.sender.0 test.sender.1")),
				DiscoveryEndCB(func(d, l int) error {
					dediscovered = d
					delimited = l
					decbcalled = true
					return nil
				}),
			)
			Expect(err).ToNot(HaveOccurred())

			// the callback ran and saw both targets, before and after limiting
			Expect(decbcalled).To(BeTrue())
			Expect(dediscovered).To(Equal(2))
			Expect(delimited).To(Equal(2))

			// both replies reached the handler and are reflected in the stats
			Expect(handled).To(Equal(2))
			stats := result.Stats()
			Expect(stats.RequestID).To(Equal(reqid))
			Expect(stats.discoveredNodes).To(Equal(strings.Fields("test.sender.0 test.sender.1")))
			Expect(*stats.DiscoveredNodes()).To(Equal(strings.Fields("test.sender.0 test.sender.1")))
			Expect(stats.unexpectedRespones.Hosts()).To(Equal([]string{}))
			Expect(stats.OKCount()).To(Equal(2))
			Expect(stats.All()).To(BeTrue())

			d, err := stats.RequestDuration()
			Expect(err).ToNot(HaveOccurred())
			Expect(d).ToNot(BeZero())
			Expect(stats.Action()).To(Equal("test_action"))
			Expect(stats.Agent()).To(Equal("package"))
		})
   311  
   312  		It("Should support discovery callbacks and limits", func() {
   313  			cl.EXPECT().Request(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).Do(func(ctx context.Context, msg inter.Message, handler client.Handler) {
   314  				Expect(msg.DiscoveredHosts()).To(Equal([]string{"host1"}))
   315  			})
   316  
   317  			discoveredCnt := 0
   318  			limitedCnt := 0
   319  
   320  			_, err := rpc.Do(ctx, "test_action", request{Testing: true},
   321  				Targets([]string{"host1", "host2", "host3", "host4"}),
   322  				LimitSize("1"),
   323  				LimitMethod("first"),
   324  				DiscoveryEndCB(func(d, l int) error {
   325  					discoveredCnt = d
   326  					limitedCnt = l
   327  					return nil
   328  				}),
   329  			)
   330  
   331  			Expect(err).ToNot(HaveOccurred())
   332  			Expect(discoveredCnt).To(Equal(4))
   333  			Expect(limitedCnt).To(Equal(1))
   334  		})
   335  
   336  		It("Should interruptable by the discovery callback", func() {
   337  			_, err := rpc.Do(ctx, "test_action", request{Testing: true},
   338  				Targets([]string{"host1", "host2", "host3", "host4"}),
   339  				LimitSize("1"),
   340  				LimitMethod("first"),
   341  				DiscoveryEndCB(func(d, l int) error {
   342  					return fmt.Errorf("simulated")
   343  				}),
   344  			)
   345  
   346  			Expect(err).To(MatchError("simulated"))
   347  		})
   348  
   349  		It("Should support batched mode", func() {
   350  			batch1 := cl.EXPECT().Request(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).Do(func(ctx context.Context, msg inter.Message, handler client.Handler) {
   351  				Expect(msg.DiscoveredHosts()).To(Equal([]string{"host1", "host2"}))
   352  			})
   353  
   354  			cl.EXPECT().Request(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).After(batch1).Do(func(ctx context.Context, msg inter.Message, handler client.Handler) {
   355  				Expect(msg.DiscoveredHosts()).To(Equal([]string{"host3", "host4"}))
   356  			})
   357  
   358  			rpc.Do(ctx, "test_action", request{Testing: true}, Targets([]string{"host1", "host2", "host3", "host4"}), InBatches(2, -1))
   359  		})
   360  
   361  		It("Should support making requests without processing replies unbatched", func() {
   362  			cl.EXPECT().Request(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).Do(func(ctx context.Context, msg inter.Message, handler client.Handler) {
   363  				Expect(msg.DiscoveredHosts()).To(Equal([]string{"host1", "host2"}))
   364  				Expect(msg.ReplyTo()).To(Equal("custom.reply.to"))
   365  				Expect(handler).To(BeNil())
   366  			})
   367  
   368  			_, err := rpc.Do(ctx, "test_action", request{Testing: true}, Targets(strings.Fields("host1 host2")), ReplyTo("custom.reply.to"))
   369  			Expect(err).ToNot(HaveOccurred())
   370  		})
   371  
   372  		It("Should support making requests without processing replies batched", func() {
   373  			batch1 := cl.EXPECT().Request(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).Do(func(ctx context.Context, msg inter.Message, handler client.Handler) {
   374  				Expect(msg.DiscoveredHosts()).To(Equal([]string{"host1"}))
   375  				Expect(msg.ReplyTo()).To(Equal("custom.reply.to"))
   376  				Expect(handler).To(BeNil())
   377  			})
   378  
   379  			cl.EXPECT().Request(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).After(batch1).Do(func(ctx context.Context, msg inter.Message, handler client.Handler) {
   380  				Expect(msg.DiscoveredHosts()).To(Equal([]string{"host2"}))
   381  				Expect(msg.ReplyTo()).To(Equal("custom.reply.to"))
   382  				Expect(handler).To(BeNil())
   383  			})
   384  
   385  			_, err := rpc.Do(ctx, "test_action", request{Testing: true}, Targets(strings.Fields("host1 host2")), ReplyTo("custom.reply.to"), InBatches(1, -1))
   386  			Expect(err).ToNot(HaveOccurred())
   387  		})
   388  	})
   389  })