github.com/choria-io/go-choria@v0.28.1-0.20240416190746-b3bf9c7d5a45/providers/agent/mcorpc/authz_jwt.go (about) 1 // Copyright (c) 2022, R.I. Pienaar and the Choria Project contributors 2 // 3 // SPDX-License-Identifier: Apache-2.0 4 5 package mcorpc 6 7 import ( 8 "context" 9 "encoding/json" 10 "fmt" 11 "strings" 12 13 "github.com/choria-io/go-choria/client/client" 14 "github.com/choria-io/go-choria/config" 15 "github.com/choria-io/go-choria/opa" 16 "github.com/choria-io/tokens" 17 "github.com/open-policy-agent/opa/ast" 18 "github.com/open-policy-agent/opa/rego" 19 "github.com/open-policy-agent/opa/types" 20 "github.com/sirupsen/logrus" 21 ) 22 23 type aaasvcPolicy struct { 24 cfg *config.Config 25 req *Request 26 agent *Agent 27 log *logrus.Entry 28 } 29 30 func aaasvcPolicyAuthorize(req *Request, agent *Agent, log *logrus.Entry) (bool, error) { 31 logger := log.WithFields(logrus.Fields{ 32 "authorizer": "aaasvc", 33 "agent": agent.Name(), 34 "request": req.RequestID, 35 }) 36 37 authz := &aaasvcPolicy{ 38 cfg: agent.Config, 39 req: req, 40 agent: agent, 41 log: logger, 42 } 43 44 return authz.authorize() 45 } 46 47 func (r *aaasvcPolicy) authorize() (bool, error) { 48 if r.req.CallerPublicData == "" { 49 return false, fmt.Errorf("no policy received in request") 50 } 51 52 claims, err := tokens.ParseClientIDTokenUnverified(r.req.CallerPublicData) 53 if err != nil { 54 return false, fmt.Errorf("invalid token in request: %v", err) 55 } 56 57 if r.req.Agent == "discovery" { 58 r.log.Debugf("Allowing discovery request") 59 return true, nil 60 } 61 62 allowed := false 63 hasAgents := len(claims.AllowedAgents) > 0 64 hasOpa := claims.OPAPolicy != "" 65 66 switch { 67 case !(hasAgents || hasOpa): 68 return false, fmt.Errorf("no policy received in token") 69 case hasAgents && hasOpa: 70 return false, fmt.Errorf("received agent list and rego policy") 71 case hasAgents: 72 r.log.Debugf("Processing using agent list") 73 74 allowed, err = EvaluateAgentListPolicy(r.req.Agent, r.req.Action, claims.AllowedAgents, r.log) 75 case hasOpa: 76 r.log.Debugf("Processing using opa policy") 77 78 allowed, err = EvaluateOpenPolicyAgentPolicy(r.req, claims.OPAPolicy, claims, "server", r.log) 79 } 80 81 return allowed, err 82 } 83 84 func EvaluateAgentListPolicy(agent string, action string, policy []string, _ *logrus.Entry) (bool, error) { 85 if len(policy) == 0 { 86 return false, nil 87 } 88 89 for _, allow := range policy { 90 // all things are allowed 91 if allow == "*" { 92 return true, nil 93 } 94 95 parts := strings.Split(allow, ".") 96 if len(parts) != 2 { 97 return false, fmt.Errorf("invalid agent policy: %s", allow) 98 } 99 100 // it's a claim for a different agent so pass, no need to check it here 101 if agent != parts[0] { 102 continue 103 } 104 105 // agent matches, action is * so allow it 106 if parts[1] == "*" { 107 return true, nil 108 } 109 110 // agent matches, action matches, allow it 111 if action == parts[1] { 112 return true, nil 113 } 114 } 115 116 return false, nil 117 118 } 119 120 // EvaluateOpenPolicyAgentPolicy evaluates a rego policy document, typically embedded in a JWT token, against a request. Shared by Choria and AAA Service 121 func EvaluateOpenPolicyAgentPolicy(req *Request, policy string, claims *tokens.ClientIDClaims, site string, log *logrus.Entry) (allowed bool, err error) { 122 if policy == "" { 123 return false, fmt.Errorf("invalid policy given") 124 } 125 126 eopts := []opa.Option{ 127 opa.Logger(log), 128 opa.Policy([]byte(policy)), 129 opa.Function(opaFunctionsMap(req)...), 130 } 131 132 if log.Logger.GetLevel() == logrus.DebugLevel { 133 eopts = append(eopts, opa.Trace()) 134 } 135 136 evaluator, err := opa.New("io.choria.aaasvc", "data.io.choria.aaasvc.allow", eopts...) 137 if err != nil { 138 return false, fmt.Errorf("could not initialize opa evaluator: %v", err) 139 } 140 141 inputs, err := opaInputs(req, req.Data, site, claims) 142 if err != nil { 143 return false, err 144 } 145 146 allowed, err = evaluator.Evaluate(context.Background(), inputs) 147 if err != nil { 148 return false, err 149 } 150 151 return allowed, nil 152 } 153 154 func opaInputs(req *Request, data json.RawMessage, site string, claims *tokens.ClientIDClaims) (map[string]any, error) { 155 dat := map[string]any{} 156 err := json.Unmarshal(data, &dat) 157 if err != nil { 158 return nil, err 159 } 160 161 // lame deep copy/data convert thing happening here 162 jclaims, err := json.Marshal(claims) 163 if err != nil { 164 return nil, fmt.Errorf("could not JSON encode claims") 165 } 166 167 cdat := new(map[string]any) 168 err = json.Unmarshal(jclaims, &cdat) 169 if err != nil { 170 return nil, fmt.Errorf("could not JSON encode claims") 171 } 172 173 return map[string]any{ 174 "agent": req.Agent, 175 "action": req.Action, 176 "data": data, 177 "sender": req.SenderID, 178 "collective": req.Collective, 179 "ttl": req.TTL, 180 "time": req.Time, 181 "site": site, 182 "claims": cdat, 183 }, nil 184 } 185 186 func opaFunctionsMap(req *Request) []func(r *rego.Rego) { 187 return []func(r *rego.Rego){ 188 rego.Function1(®o.Function{Name: "requires_filter", Decl: types.NewFunction(types.Args(), types.B)}, opaFuncRequiresFilter(req)), 189 rego.Function1(®o.Function{Name: "requires_fact_filter", Decl: types.NewFunction(types.Args(types.S), types.B)}, opaFuncRequiresFactFilter(req)), 190 rego.Function1(®o.Function{Name: "requires_class_filter", Decl: types.NewFunction(types.Args(types.S), types.B)}, opaFuncRequiresClassFilter(req)), 191 rego.Function1(®o.Function{Name: "requires_identity_filter", Decl: types.NewFunction(types.Args(types.S), types.B)}, opaFuncRequiresIdentityFilter(req)), 192 } 193 } 194 195 func opaFuncRequiresFilter(req *Request) func(_ rego.BuiltinContext, a *ast.Term) (*ast.Term, error) { 196 return func(_ rego.BuiltinContext, a *ast.Term) (*ast.Term, error) { 197 // agent is always set, so we don't check it else it will always be true 198 if len(req.Filter.ClassFilters()) > 0 || len(req.Filter.IdentityFilters()) > 0 || len(req.Filter.FactFilters()) > 0 || len(req.Filter.CompoundFilters()) > 0 { 199 return ast.BooleanTerm(true), nil 200 } 201 202 return ast.BooleanTerm(false), nil 203 } 204 205 } 206 207 func opaFuncRequiresIdentityFilter(req *Request) func(_ rego.BuiltinContext, a *ast.Term) (*ast.Term, error) { 208 return func(_ rego.BuiltinContext, a *ast.Term) (*ast.Term, error) { 209 str, ok := a.Value.(ast.String) 210 if !ok { 211 return ast.BooleanTerm(false), fmt.Errorf("invalid identity matcher received") 212 } 213 214 want := string(str) 215 for _, f := range req.Filter.IdentityFilters() { 216 if f == want { 217 return ast.BooleanTerm(true), nil 218 } 219 } 220 221 return ast.BooleanTerm(false), nil 222 } 223 } 224 225 func opaFuncRequiresClassFilter(req *Request) func(_ rego.BuiltinContext, a *ast.Term) (*ast.Term, error) { 226 return func(_ rego.BuiltinContext, a *ast.Term) (*ast.Term, error) { 227 str, ok := a.Value.(ast.String) 228 if !ok { 229 return ast.BooleanTerm(false), fmt.Errorf("invalid class matcher received") 230 } 231 232 want := string(str) 233 234 for _, f := range req.Filter.ClassFilters() { 235 if f == want { 236 return ast.BooleanTerm(true), nil 237 } 238 } 239 240 return ast.BooleanTerm(false), nil 241 } 242 } 243 244 func opaFuncRequiresFactFilter(req *Request) func(_ rego.BuiltinContext, a *ast.Term) (*ast.Term, error) { 245 return func(_ rego.BuiltinContext, a *ast.Term) (*ast.Term, error) { 246 str, ok := a.Value.(ast.String) 247 if !ok { 248 return ast.BooleanTerm(false), fmt.Errorf("invalid fact matcher received") 249 } 250 251 want, err := client.ParseFactFilterString(string(str)) 252 if err != nil { 253 return ast.BooleanTerm(false), fmt.Errorf("invalid fact matcher received: %s", err) 254 } 255 256 for _, f := range req.Filter.Fact { 257 if want.Fact == f.Fact && want.Operator == f.Operator && want.Value == f.Value { 258 return ast.BooleanTerm(true), nil 259 } 260 } 261 262 return ast.BooleanTerm(false), nil 263 } 264 }