github.com/choria-io/go-choria@v0.28.1-0.20240416190746-b3bf9c7d5a45/providers/agent/mcorpc/agent.go (about)

     1  // Copyright (c) 2020-2022, R.I. Pienaar and the Choria Project contributors
     2  //
     3  // SPDX-License-Identifier: Apache-2.0
     4  
     5  package mcorpc
     6  
     7  import (
     8  	"context"
     9  	"encoding/json"
    10  	"fmt"
    11  	"sort"
    12  	"strings"
    13  
    14  	"github.com/choria-io/go-choria/inter"
    15  	"github.com/sirupsen/logrus"
    16  
    17  	"github.com/choria-io/go-choria/config"
    18  	"github.com/choria-io/go-choria/protocol"
    19  	"github.com/choria-io/go-choria/providers/agent/mcorpc/audit"
    20  	"github.com/choria-io/go-choria/server/agents"
    21  )
    22  
// Action is the signature every RPC action implementation must satisfy.
//
// HandleMessage calls the action with the context it was given, the parsed
// request, the reply that is published once HandleMessage returns, the agent
// the action is registered on and the connector information it received.
type Action func(context.Context, *Request, *Reply, *Agent, inter.ConnectorInfo)
    25  
// ActivationChecker is a function that can determine if an agent should be activated.
// ShouldActivate returns whatever the checker returns.
type ActivationChecker func() bool
    28  
// Agent is an instance of the MCollective compatible RPC agents.
//
// Exported fields:
//   - Log: logger used by the agent; New adds an "agent" field holding the agent name.
//   - Config: configuration; New takes it from the framework's Configuration().
//   - Choria: the framework the agent was created with.
//   - ServerInfoSource: stored by SetServerInfo and returned by ServerInfo.
//
// Unexported fields:
//   - activationCheck: consulted by ShouldActivate; New installs one that always returns true.
//   - meta: the metadata handed to New; Name reads its Name field.
//   - actions: the registered actions keyed by action name.
type Agent struct {
	Log              *logrus.Entry
	Config           *config.Config
	Choria           ChoriaFramework
	ServerInfoSource agents.ServerInfoSource

	activationCheck ActivationChecker
	meta            *agents.Metadata
	actions         map[string]Action
}
    40  
    41  // New creates a new MCollective SimpleRPC compatible agent
    42  func New(name string, metadata *agents.Metadata, fw ChoriaFramework, log *logrus.Entry) *Agent {
    43  	a := &Agent{
    44  		meta:            metadata,
    45  		Log:             log.WithFields(logrus.Fields{"agent": name}),
    46  		actions:         make(map[string]Action),
    47  		Choria:          fw,
    48  		Config:          fw.Configuration(),
    49  		activationCheck: func() bool { return true },
    50  	}
    51  
    52  	return a
    53  }
    54  
    55  // ShouldActivate checks if the agent should be active using the method set in SetActivationChecker
    56  func (a *Agent) ShouldActivate() bool {
    57  	return a.activationCheck()
    58  }
    59  
// SetActivationChecker sets the function that can determine if the agent should be active.
// It replaces any checker stored before, including the one installed by New.
func (a *Agent) SetActivationChecker(ac ActivationChecker) {
	a.activationCheck = ac
}
    64  
// SetServerInfo stores the server info source that owns this agent.
// The value is kept in the exported ServerInfoSource field.
func (a *Agent) SetServerInfo(si agents.ServerInfoSource) {
	a.ServerInfoSource = si
}
    69  
// ServerInfo returns the stored server info source, that is the current
// value of the ServerInfoSource field.
func (a *Agent) ServerInfo() agents.ServerInfoSource {
	return a.ServerInfoSource
}
    74  
    75  // RegisterAction registers an action into the agent
    76  func (a *Agent) RegisterAction(name string, f Action) error {
    77  	if _, ok := a.actions[name]; ok {
    78  		return fmt.Errorf("cannot register action %s, it already exist", name)
    79  	}
    80  
    81  	a.actions[name] = f
    82  
    83  	return nil
    84  }
    85  
    86  // MustRegisterAction registers an action and panics if it fails
    87  func (a *Agent) MustRegisterAction(name string, f Action) {
    88  	if _, ok := a.actions[name]; ok {
    89  		panic(fmt.Errorf("cannot register action %s, it already exist", name))
    90  	}
    91  
    92  	a.actions[name] = f
    93  }
    94  
    95  // HandleMessage attempts to parse a choria.Message as a MCollective SimpleRPC request and calls
    96  // the agents and actions associated with it
    97  func (a *Agent) HandleMessage(ctx context.Context, msg inter.Message, request protocol.Request, conn inter.ConnectorInfo, outbox chan *agents.AgentReply) {
    98  	var err error
    99  
   100  	reply := a.newReply()
   101  	defer a.publish(reply, msg, request, outbox)
   102  
   103  	rpcrequest, err := a.parseIncomingMessage(msg.Payload(), request)
   104  	if err != nil {
   105  		reply.Statuscode = InvalidData
   106  		reply.Statusmsg = fmt.Sprintf("Could not process request: %s", err)
   107  		return
   108  	}
   109  
   110  	reply.Action = rpcrequest.Action
   111  
   112  	action, found := a.actions[rpcrequest.Action]
   113  	if !found {
   114  		reply.Statuscode = UnknownAction
   115  		reply.Statusmsg = fmt.Sprintf("Unknown action %s for agent %s", rpcrequest.Action, a.Name())
   116  		return
   117  	}
   118  
   119  	if a.Config.RPCAuthorization {
   120  		if !a.authorize(rpcrequest) {
   121  			a.Log.Warnf("Denying %s access to %s#%s based on authorization policy for request %s", request.CallerID(), rpcrequest.Agent, rpcrequest.Action, request.RequestID())
   122  			reply.Statuscode = Aborted
   123  			reply.Statusmsg = "You are not authorized to call this agent or action"
   124  			return
   125  		}
   126  	}
   127  
   128  	if a.Config.RPCAudit {
   129  		audit.Request(request, rpcrequest.Agent, rpcrequest.Action, rpcrequest.Data, a.Config)
   130  	}
   131  
   132  	a.Log.Infof("Handling message %s for %s#%s from %s", msg.RequestID(), a.Name(), rpcrequest.Action, request.CallerID())
   133  
   134  	action(ctx, rpcrequest, reply, a, conn)
   135  }
   136  
// Name retrieves the name of the agent as recorded in its metadata.
func (a *Agent) Name() string {
	return a.meta.Name
}
   141  
   142  // ActionNames returns a list of known actions in the agent
   143  func (a *Agent) ActionNames() []string {
   144  	var actions []string
   145  
   146  	for k := range a.actions {
   147  		actions = append(actions, k)
   148  	}
   149  
   150  	sort.Strings(actions)
   151  
   152  	return actions
   153  }
   154  
// Metadata retrieves the agent metadata that was supplied to New. The stored
// pointer itself is returned, not a copy.
func (a *Agent) Metadata() *agents.Metadata {
	return a.meta
}
   159  
   160  func (a *Agent) publish(rpcreply *Reply, msg inter.Message, request protocol.Request, outbox chan *agents.AgentReply) {
   161  	if rpcreply.DisableResponse {
   162  		return
   163  	}
   164  
   165  	reply := &agents.AgentReply{
   166  		Message: msg,
   167  		Request: request,
   168  	}
   169  
   170  	if rpcreply.Data == nil {
   171  		rpcreply.Data = "{}"
   172  	}
   173  
   174  	j, err := json.Marshal(rpcreply)
   175  	if err != nil {
   176  		a.Log.Errorf("Could not JSON encode reply: %s", err)
   177  		reply.Error = err
   178  	}
   179  
   180  	reply.Body = j
   181  
   182  	outbox <- reply
   183  }
   184  
   185  func (a *Agent) newReply() *Reply {
   186  	reply := &Reply{
   187  		Statuscode: OK,
   188  		Statusmsg:  "OK",
   189  		Data:       json.RawMessage(`{}`),
   190  	}
   191  
   192  	return reply
   193  }
   194  
   195  func (a *Agent) parseIncomingMessage(msg []byte, request protocol.Request) (*Request, error) {
   196  	r := &Request{}
   197  
   198  	err := json.Unmarshal(msg, r)
   199  	if err != nil {
   200  		return nil, fmt.Errorf("could not parse incoming message as a MCollective SimpleRPC Request: %s", err)
   201  	}
   202  
   203  	r.CallerID = request.CallerID()
   204  	r.RequestID = request.RequestID()
   205  	r.SenderID = request.SenderID()
   206  	r.Collective = request.Collective()
   207  	r.CallerPublicData = request.CallerPublicData()
   208  	r.SignerPublicData = request.SignerPublicData()
   209  	r.TTL = request.TTL()
   210  	r.Time = request.Time()
   211  	r.Filter, _ = request.Filter()
   212  
   213  	if r.Data == nil {
   214  		r.Data = json.RawMessage(`{}`)
   215  	}
   216  
   217  	return r, nil
   218  }
   219  
   220  func (a *Agent) authorize(req *Request) bool {
   221  	if !a.Config.RPCAuthorization {
   222  		return true
   223  	}
   224  
   225  	prov := strings.ToLower(a.Config.RPCAuthorizationProvider)
   226  
   227  	switch prov {
   228  	case "action_policy":
   229  		return actionPolicyAuthorize(req, a, a.Log)
   230  
   231  	case "rego_policy":
   232  		auth, err := regoPolicyAuthorize(req, a, a.Log)
   233  		if err != nil {
   234  			a.Log.Errorf("Could not process Open Policy Agent policy: %v", err)
   235  			return false
   236  		}
   237  		return auth
   238  
   239  	case "aaasvc", "aaasvc_policy":
   240  		auth, err := aaasvcPolicyAuthorize(req, a, a.Log)
   241  		if err != nil {
   242  			a.Log.Errorf("Could not process JWT policy: %v", err)
   243  			return false
   244  		}
   245  		return auth
   246  
   247  	default:
   248  		a.Log.Errorf("Unsupported authorization provider: %s", prov)
   249  
   250  	}
   251  
   252  	return false
   253  }