github.com/choria-io/go-choria@v0.28.1-0.20240416190746-b3bf9c7d5a45/providers/agent/mcorpc/mcorpc_test.go (about)

     1  // Copyright (c) 2020-2022, R.I. Pienaar and the Choria Project contributors
     2  //
     3  // SPDX-License-Identifier: Apache-2.0
     4  
     5  package mcorpc
     6  
     7  import (
     8  	"context"
     9  	"encoding/json"
    10  	"testing"
    11  
    12  	"github.com/choria-io/go-choria/build"
    13  	"github.com/choria-io/go-choria/config"
    14  	"github.com/choria-io/go-choria/inter"
    15  	imock "github.com/choria-io/go-choria/inter/imocks"
    16  	"github.com/choria-io/go-choria/message"
    17  	"github.com/choria-io/go-choria/protocol"
    18  	v1 "github.com/choria-io/go-choria/protocol/v1"
    19  	"github.com/choria-io/go-choria/server/agents"
    20  	"github.com/golang/mock/gomock"
    21  	. "github.com/onsi/ginkgo/v2"
    22  	. "github.com/onsi/gomega"
    23  	"github.com/tidwall/gjson"
    24  )
    25  
    26  func TestMcoRPC(t *testing.T) {
    27  	RegisterFailHandler(Fail)
    28  	RunSpecs(t, "Providers/Agent/McoRPC")
    29  }
    30  
    31  var _ = Describe("McoRPC", func() {
    32  	var (
    33  		agent   *Agent
    34  		mockctl *gomock.Controller
    35  		fw      *imock.MockFramework
    36  		cfg     *config.Config
    37  		msg     inter.Message
    38  		req     protocol.Request
    39  		outbox  = make(chan *agents.AgentReply, 1)
    40  		err     error
    41  		ctx     context.Context
    42  	)
    43  
    44  	BeforeEach(func() {
    45  		mockctl = gomock.NewController(GinkgoT())
    46  
    47  		protocol.Secure = "false"
    48  		build.TLS = "false"
    49  
    50  		fw, cfg = imock.NewFrameworkForTests(mockctl, GinkgoWriter, imock.WithCallerID())
    51  		cfg.LogLevel = "fatal"
    52  
    53  		metadata := &agents.Metadata{Name: "test"}
    54  		agent = New("testing", metadata, fw, fw.Logger("test"))
    55  		ctx = context.Background()
    56  	})
    57  
    58  	AfterEach(func() {
    59  		mockctl.Finish()
    60  	})
    61  
    62  	It("Should have correct constants", func() {
    63  		Expect(OK).To(Equal(StatusCode(0)))
    64  		Expect(Aborted).To(Equal(StatusCode(1)))
    65  		Expect(UnknownAction).To(Equal(StatusCode(2)))
    66  		Expect(MissingData).To(Equal(StatusCode(3)))
    67  		Expect(InvalidData).To(Equal(StatusCode(4)))
    68  		Expect(UnknownError).To(Equal(StatusCode(5)))
    69  	})
    70  
    71  	Describe("RegisterAction", func() {
    72  		It("Should fail if the action already exist", func() {
    73  			action := func(ctx context.Context, req *Request, reply *Reply, agent *Agent, conn inter.ConnectorInfo) {}
    74  			err := agent.RegisterAction("test", action)
    75  			Expect(err).ToNot(HaveOccurred())
    76  			err = agent.RegisterAction("test", action)
    77  			Expect(err).To(MatchError("cannot register action test, it already exist"))
    78  		})
    79  	})
    80  
    81  	Describe("HandleMessage", func() {
    82  		BeforeEach(func() {
    83  			req, err = v1.NewRequest("test", "test.example.net", "choria=rip.mcollective", 60, "testrequest", "mcollective")
    84  			Expect(err).ToNot(HaveOccurred())
    85  			msg, err = message.NewMessageFromRequest(req, "dev.null", fw)
    86  			Expect(err).ToNot(HaveOccurred())
    87  		})
    88  
    89  		It("Should handle bad incoming data", func() {
    90  			agent.HandleMessage(ctx, msg, req, nil, outbox)
    91  
    92  			reply := <-outbox
    93  			Expect(gjson.GetBytes(reply.Body, "statusmsg").String()).To(Equal("Could not process request: could not parse incoming message as a MCollective SimpleRPC Request: unexpected end of JSON input"))
    94  			Expect(gjson.GetBytes(reply.Body, "statuscode").Int()).To(Equal(int64(4)))
    95  		})
    96  
    97  		It("Should handle unknown actions", func() {
    98  			msg.SetPayload([]byte(`{"agent":"test", "action":"nonexisting"}`))
    99  			agent.HandleMessage(ctx, msg, req, nil, outbox)
   100  
   101  			reply := <-outbox
   102  			Expect(gjson.GetBytes(reply.Body, "statusmsg").String()).To(Equal("Unknown action nonexisting for agent test"))
   103  			Expect(gjson.GetBytes(reply.Body, "statuscode").Int()).To(Equal(int64(2)))
   104  		})
   105  
   106  		It("Should call the action", func() {
   107  			action := func(ctx context.Context, req *Request, reply *Reply, agent *Agent, conn inter.ConnectorInfo) {
   108  				d := make(map[string]string)
   109  				d["test"] = "hello world"
   110  				reply.Data = &d
   111  			}
   112  
   113  			agent.RegisterAction("test", action)
   114  			msg.SetPayload([]byte(`{"agent":"test", "action":"test"}`))
   115  			agent.HandleMessage(ctx, msg, req, nil, outbox)
   116  
   117  			reply := <-outbox
   118  			Expect(gjson.GetBytes(reply.Body, "statusmsg").String()).To(Equal("OK"))
   119  			Expect(gjson.GetBytes(reply.Body, "statuscode").Int()).To(Equal(int64(0)))
   120  			Expect(gjson.GetBytes(reply.Body, "data.test").String()).To(Equal("hello world"))
   121  		})
   122  
   123  		It("Should detect unsupported authorization systems", func() {
   124  			cfg.RPCAuthorization = true
   125  			msg.SetPayload([]byte(`{"agent":"test", "action":"test"}`))
   126  			action := func(ctx context.Context, req *Request, reply *Reply, agent *Agent, conn inter.ConnectorInfo) {
   127  				d := map[string]string{"test": "hello world"}
   128  				reply.Data = &d
   129  			}
   130  
   131  			agent.RegisterAction("test", action)
   132  			agent.HandleMessage(ctx, msg, req, nil, outbox)
   133  			reply := <-outbox
   134  
   135  			Expect(gjson.GetBytes(reply.Body, "statusmsg").String()).To(Equal("You are not authorized to call this agent or action"))
   136  			Expect(gjson.GetBytes(reply.Body, "statuscode").Int()).To(Equal(int64(1)))
   137  		})
   138  
   139  		It("Should support action_policy authorization", func() {
   140  			cfg.ConfigFile = "testdata/config.cfg"
   141  			cfg.RPCAuthorization = true
   142  			msg.SetPayload([]byte(`{"agent":"test", "action":"test"}`))
   143  
   144  			action := func(ctx context.Context, req *Request, reply *Reply, agent *Agent, conn inter.ConnectorInfo) {
   145  				d := map[string]string{"test": "hello world"}
   146  				reply.Data = &d
   147  			}
   148  
   149  			agent.RegisterAction("test", action)
   150  			agent.HandleMessage(ctx, msg, req, nil, outbox)
   151  			reply := <-outbox
   152  
   153  			Expect(gjson.GetBytes(reply.Body, "statusmsg").String()).To(Equal("You are not authorized to call this agent or action"))
   154  			Expect(gjson.GetBytes(reply.Body, "statuscode").Int()).To(Equal(int64(1)))
   155  		})
   156  
   157  		It("Should support rego_policy authorization", func() {
   158  			cfg.ConfigFile = "testdata/config.cfg"
   159  			cfg.RPCAuthorization = true
   160  			msg.SetPayload([]byte(`{"agent":"test", "action":"test"}`))
   161  
   162  			action := func(ctx context.Context, req *Request, reply *Reply, agent *Agent, conn inter.ConnectorInfo) {
   163  				d := map[string]string{"test": "hello world"}
   164  				reply.Data = &d
   165  			}
   166  
   167  			agent.RegisterAction("test", action)
   168  			agent.HandleMessage(ctx, msg, req, nil, outbox)
   169  			reply := <-outbox
   170  
   171  			Expect(gjson.GetBytes(reply.Body, "statusmsg").String()).To(Equal("You are not authorized to call this agent or action"))
   172  			Expect(gjson.GetBytes(reply.Body, "statuscode").Int()).To(Equal(int64(1)))
   173  
   174  		})
   175  	})
   176  
   177  	Describe("publish", func() {
   178  		It("Should handle bad data", func() {
   179  			reply := &Reply{
   180  				Data: outbox,
   181  			}
   182  
   183  			agent.publish(reply, msg, req, outbox)
   184  			out := <-outbox
   185  			Expect(out.Error).To(MatchError("json: unsupported type: chan *agents.AgentReply"))
   186  		})
   187  
   188  		PIt("Should publish good messages")
   189  	})
   190  
   191  	Describe("ParseRequestData", func() {
   192  		It("Should handle valid data correctly", func() {
   193  			req := &Request{
   194  				Data: json.RawMessage(`{"hello":"world"}`),
   195  			}
   196  
   197  			reply := &Reply{}
   198  
   199  			var params struct {
   200  				Hello string `json:"hello"`
   201  			}
   202  
   203  			ok := ParseRequestData(&params, req, reply)
   204  
   205  			Expect(ok).To(BeTrue())
   206  			Expect(params.Hello).To(Equal("world"))
   207  		})
   208  
   209  		It("Should handle invalid data correctly", func() {
   210  			req := &Request{
   211  				Agent:  "test",
   212  				Action: "will_fail",
   213  				Data:   json.RawMessage(`fail`),
   214  			}
   215  
   216  			reply := &Reply{}
   217  
   218  			var params struct {
   219  				Hello string `json:"hello"`
   220  			}
   221  
   222  			ok := ParseRequestData(&params, req, reply)
   223  
   224  			Expect(ok).To(BeFalse())
   225  			Expect(reply.Statuscode).To(Equal(InvalidData))
   226  			Expect(reply.Statusmsg).To(Equal("Could not parse request data for test#will_fail: invalid character 'i' in literal false (expecting 'l')"))
   227  		})
   228  
   229  		It("Should use the validator to validate structs", func() {
   230  			req := &Request{
   231  				Agent:  "test",
   232  				Action: "will_fail",
   233  				Data:   json.RawMessage(`{"hello":"foo > bar"}`),
   234  			}
   235  
   236  			reply := &Reply{}
   237  
   238  			var params struct {
   239  				Hello string `json:"hello" validate:"shellsafe"`
   240  			}
   241  
   242  			ok := ParseRequestData(&params, req, reply)
   243  
   244  			Expect(ok).To(BeFalse())
   245  			Expect(reply.Statuscode).To(Equal(InvalidData))
   246  			Expect(reply.Statusmsg).To(Equal("Validation failed: Hello shellsafe validation failed: may not contain '>'"))
   247  		})
   248  	})
   249  
   250  	Describe("newReply", func() {
   251  		It("Should set the correct starting code and message", func() {
   252  			r := agent.newReply()
   253  			Expect(r.Statuscode).To(Equal(OK))
   254  			Expect(r.Statusmsg).To(Equal("OK"))
   255  		})
   256  	})
   257  })