github.com/choria-io/go-choria@v0.28.1-0.20240416190746-b3bf9c7d5a45/providers/agent/mcorpc/agent.go (about) 1 // Copyright (c) 2020-2022, R.I. Pienaar and the Choria Project contributors 2 // 3 // SPDX-License-Identifier: Apache-2.0 4 5 package mcorpc 6 7 import ( 8 "context" 9 "encoding/json" 10 "fmt" 11 "sort" 12 "strings" 13 14 "github.com/choria-io/go-choria/inter" 15 "github.com/sirupsen/logrus" 16 17 "github.com/choria-io/go-choria/config" 18 "github.com/choria-io/go-choria/protocol" 19 "github.com/choria-io/go-choria/providers/agent/mcorpc/audit" 20 "github.com/choria-io/go-choria/server/agents" 21 ) 22 23 // Action is a function that implements a RPC Action 24 type Action func(context.Context, *Request, *Reply, *Agent, inter.ConnectorInfo) 25 26 // ActivationChecker is a function that can determine if an agent should be activated 27 type ActivationChecker func() bool 28 29 // Agent is an instance of the MCollective compatible RPC agents 30 type Agent struct { 31 Log *logrus.Entry 32 Config *config.Config 33 Choria ChoriaFramework 34 ServerInfoSource agents.ServerInfoSource 35 36 activationCheck ActivationChecker 37 meta *agents.Metadata 38 actions map[string]Action 39 } 40 41 // New creates a new MCollective SimpleRPC compatible agent 42 func New(name string, metadata *agents.Metadata, fw ChoriaFramework, log *logrus.Entry) *Agent { 43 a := &Agent{ 44 meta: metadata, 45 Log: log.WithFields(logrus.Fields{"agent": name}), 46 actions: make(map[string]Action), 47 Choria: fw, 48 Config: fw.Configuration(), 49 activationCheck: func() bool { return true }, 50 } 51 52 return a 53 } 54 55 // ShouldActivate checks if the agent should be active using the method set in SetActivationChecker 56 func (a *Agent) ShouldActivate() bool { 57 return a.activationCheck() 58 } 59 60 // SetActivationChecker sets the function that can determine if the agent should be active 61 func (a *Agent) SetActivationChecker(ac ActivationChecker) { 62 a.activationCheck = ac 63 } 64 65 // SetServerInfo stores the server info source that owns this agent 66 func (a *Agent) SetServerInfo(si agents.ServerInfoSource) { 67 a.ServerInfoSource = si 68 } 69 70 // ServerInfo returns the stored server info source 71 func (a *Agent) ServerInfo() agents.ServerInfoSource { 72 return a.ServerInfoSource 73 } 74 75 // RegisterAction registers an action into the agent 76 func (a *Agent) RegisterAction(name string, f Action) error { 77 if _, ok := a.actions[name]; ok { 78 return fmt.Errorf("cannot register action %s, it already exist", name) 79 } 80 81 a.actions[name] = f 82 83 return nil 84 } 85 86 // MustRegisterAction registers an action and panics if it fails 87 func (a *Agent) MustRegisterAction(name string, f Action) { 88 if _, ok := a.actions[name]; ok { 89 panic(fmt.Errorf("cannot register action %s, it already exist", name)) 90 } 91 92 a.actions[name] = f 93 } 94 95 // HandleMessage attempts to parse a choria.Message as a MCollective SimpleRPC request and calls 96 // the agents and actions associated with it 97 func (a *Agent) HandleMessage(ctx context.Context, msg inter.Message, request protocol.Request, conn inter.ConnectorInfo, outbox chan *agents.AgentReply) { 98 var err error 99 100 reply := a.newReply() 101 defer a.publish(reply, msg, request, outbox) 102 103 rpcrequest, err := a.parseIncomingMessage(msg.Payload(), request) 104 if err != nil { 105 reply.Statuscode = InvalidData 106 reply.Statusmsg = fmt.Sprintf("Could not process request: %s", err) 107 return 108 } 109 110 reply.Action = rpcrequest.Action 111 112 action, found := a.actions[rpcrequest.Action] 113 if !found { 114 reply.Statuscode = UnknownAction 115 reply.Statusmsg = fmt.Sprintf("Unknown action %s for agent %s", rpcrequest.Action, a.Name()) 116 return 117 } 118 119 if a.Config.RPCAuthorization { 120 if !a.authorize(rpcrequest) { 121 a.Log.Warnf("Denying %s access to %s#%s based on authorization policy for request %s", request.CallerID(), rpcrequest.Agent, rpcrequest.Action, request.RequestID()) 122 reply.Statuscode = Aborted 123 reply.Statusmsg = "You are not authorized to call this agent or action" 124 return 125 } 126 } 127 128 if a.Config.RPCAudit { 129 audit.Request(request, rpcrequest.Agent, rpcrequest.Action, rpcrequest.Data, a.Config) 130 } 131 132 a.Log.Infof("Handling message %s for %s#%s from %s", msg.RequestID(), a.Name(), rpcrequest.Action, request.CallerID()) 133 134 action(ctx, rpcrequest, reply, a, conn) 135 } 136 137 // Name retrieves the name of the agent 138 func (a *Agent) Name() string { 139 return a.meta.Name 140 } 141 142 // ActionNames returns a list of known actions in the agent 143 func (a *Agent) ActionNames() []string { 144 var actions []string 145 146 for k := range a.actions { 147 actions = append(actions, k) 148 } 149 150 sort.Strings(actions) 151 152 return actions 153 } 154 155 // Metadata retrieves the agent metadata 156 func (a *Agent) Metadata() *agents.Metadata { 157 return a.meta 158 } 159 160 func (a *Agent) publish(rpcreply *Reply, msg inter.Message, request protocol.Request, outbox chan *agents.AgentReply) { 161 if rpcreply.DisableResponse { 162 return 163 } 164 165 reply := &agents.AgentReply{ 166 Message: msg, 167 Request: request, 168 } 169 170 if rpcreply.Data == nil { 171 rpcreply.Data = "{}" 172 } 173 174 j, err := json.Marshal(rpcreply) 175 if err != nil { 176 a.Log.Errorf("Could not JSON encode reply: %s", err) 177 reply.Error = err 178 } 179 180 reply.Body = j 181 182 outbox <- reply 183 } 184 185 func (a *Agent) newReply() *Reply { 186 reply := &Reply{ 187 Statuscode: OK, 188 Statusmsg: "OK", 189 Data: json.RawMessage(`{}`), 190 } 191 192 return reply 193 } 194 195 func (a *Agent) parseIncomingMessage(msg []byte, request protocol.Request) (*Request, error) { 196 r := &Request{} 197 198 err := json.Unmarshal(msg, r) 199 if err != nil { 200 return nil, fmt.Errorf("could not parse incoming message as a MCollective SimpleRPC Request: %s", err) 201 } 202 203 r.CallerID = request.CallerID() 204 r.RequestID = request.RequestID() 205 r.SenderID = request.SenderID() 206 r.Collective = request.Collective() 207 r.CallerPublicData = request.CallerPublicData() 208 r.SignerPublicData = request.SignerPublicData() 209 r.TTL = request.TTL() 210 r.Time = request.Time() 211 r.Filter, _ = request.Filter() 212 213 if r.Data == nil { 214 r.Data = json.RawMessage(`{}`) 215 } 216 217 return r, nil 218 } 219 220 func (a *Agent) authorize(req *Request) bool { 221 if !a.Config.RPCAuthorization { 222 return true 223 } 224 225 prov := strings.ToLower(a.Config.RPCAuthorizationProvider) 226 227 switch prov { 228 case "action_policy": 229 return actionPolicyAuthorize(req, a, a.Log) 230 231 case "rego_policy": 232 auth, err := regoPolicyAuthorize(req, a, a.Log) 233 if err != nil { 234 a.Log.Errorf("Could not process Open Policy Agent policy: %v", err) 235 return false 236 } 237 return auth 238 239 case "aaasvc", "aaasvc_policy": 240 auth, err := aaasvcPolicyAuthorize(req, a, a.Log) 241 if err != nil { 242 a.Log.Errorf("Could not process JWT policy: %v", err) 243 return false 244 } 245 return auth 246 247 default: 248 a.Log.Errorf("Unsupported authorization provider: %s", prov) 249 250 } 251 252 return false 253 }