github.com/choria-io/go-choria@v0.28.1-0.20240416190746-b3bf9c7d5a45/providers/agent/mcorpc/authz_jwt.go (about)

     1  // Copyright (c) 2022, R.I. Pienaar and the Choria Project contributors
     2  //
     3  // SPDX-License-Identifier: Apache-2.0
     4  
     5  package mcorpc
     6  
     7  import (
     8  	"context"
     9  	"encoding/json"
    10  	"fmt"
    11  	"strings"
    12  
    13  	"github.com/choria-io/go-choria/client/client"
    14  	"github.com/choria-io/go-choria/config"
    15  	"github.com/choria-io/go-choria/opa"
    16  	"github.com/choria-io/tokens"
    17  	"github.com/open-policy-agent/opa/ast"
    18  	"github.com/open-policy-agent/opa/rego"
    19  	"github.com/open-policy-agent/opa/types"
    20  	"github.com/sirupsen/logrus"
    21  )
    22  
// aaasvcPolicy holds the state needed to authorize a single RPC request
// against the policy embedded in the caller's client ID token. One is
// constructed per call to aaasvcPolicyAuthorize.
//
// Fields:
//   - cfg: the agent's configuration (not read by authorize)
//   - req: the request being authorized
//   - agent: the agent handling the request (not read by authorize)
//   - log: logger pre-populated with authorizer, agent and request fields
type aaasvcPolicy struct {
	cfg   *config.Config
	req   *Request
	agent *Agent
	log   *logrus.Entry
}
    29  
    30  func aaasvcPolicyAuthorize(req *Request, agent *Agent, log *logrus.Entry) (bool, error) {
    31  	logger := log.WithFields(logrus.Fields{
    32  		"authorizer": "aaasvc",
    33  		"agent":      agent.Name(),
    34  		"request":    req.RequestID,
    35  	})
    36  
    37  	authz := &aaasvcPolicy{
    38  		cfg:   agent.Config,
    39  		req:   req,
    40  		agent: agent,
    41  		log:   logger,
    42  	}
    43  
    44  	return authz.authorize()
    45  }
    46  
    47  func (r *aaasvcPolicy) authorize() (bool, error) {
    48  	if r.req.CallerPublicData == "" {
    49  		return false, fmt.Errorf("no policy received in request")
    50  	}
    51  
    52  	claims, err := tokens.ParseClientIDTokenUnverified(r.req.CallerPublicData)
    53  	if err != nil {
    54  		return false, fmt.Errorf("invalid token in request: %v", err)
    55  	}
    56  
    57  	if r.req.Agent == "discovery" {
    58  		r.log.Debugf("Allowing discovery request")
    59  		return true, nil
    60  	}
    61  
    62  	allowed := false
    63  	hasAgents := len(claims.AllowedAgents) > 0
    64  	hasOpa := claims.OPAPolicy != ""
    65  
    66  	switch {
    67  	case !(hasAgents || hasOpa):
    68  		return false, fmt.Errorf("no policy received in token")
    69  	case hasAgents && hasOpa:
    70  		return false, fmt.Errorf("received agent list and rego policy")
    71  	case hasAgents:
    72  		r.log.Debugf("Processing using agent list")
    73  
    74  		allowed, err = EvaluateAgentListPolicy(r.req.Agent, r.req.Action, claims.AllowedAgents, r.log)
    75  	case hasOpa:
    76  		r.log.Debugf("Processing using opa policy")
    77  
    78  		allowed, err = EvaluateOpenPolicyAgentPolicy(r.req, claims.OPAPolicy, claims, "server", r.log)
    79  	}
    80  
    81  	return allowed, err
    82  }
    83  
    84  func EvaluateAgentListPolicy(agent string, action string, policy []string, _ *logrus.Entry) (bool, error) {
    85  	if len(policy) == 0 {
    86  		return false, nil
    87  	}
    88  
    89  	for _, allow := range policy {
    90  		// all things are allowed
    91  		if allow == "*" {
    92  			return true, nil
    93  		}
    94  
    95  		parts := strings.Split(allow, ".")
    96  		if len(parts) != 2 {
    97  			return false, fmt.Errorf("invalid agent policy: %s", allow)
    98  		}
    99  
   100  		// it's a claim for a different agent so pass, no need to check it here
   101  		if agent != parts[0] {
   102  			continue
   103  		}
   104  
   105  		// agent matches, action is * so allow it
   106  		if parts[1] == "*" {
   107  			return true, nil
   108  		}
   109  
   110  		// agent matches, action matches, allow it
   111  		if action == parts[1] {
   112  			return true, nil
   113  		}
   114  	}
   115  
   116  	return false, nil
   117  
   118  }
   119  
   120  // EvaluateOpenPolicyAgentPolicy evaluates a rego policy document, typically embedded in a JWT token, against a request.  Shared by Choria and AAA Service
   121  func EvaluateOpenPolicyAgentPolicy(req *Request, policy string, claims *tokens.ClientIDClaims, site string, log *logrus.Entry) (allowed bool, err error) {
   122  	if policy == "" {
   123  		return false, fmt.Errorf("invalid policy given")
   124  	}
   125  
   126  	eopts := []opa.Option{
   127  		opa.Logger(log),
   128  		opa.Policy([]byte(policy)),
   129  		opa.Function(opaFunctionsMap(req)...),
   130  	}
   131  
   132  	if log.Logger.GetLevel() == logrus.DebugLevel {
   133  		eopts = append(eopts, opa.Trace())
   134  	}
   135  
   136  	evaluator, err := opa.New("io.choria.aaasvc", "data.io.choria.aaasvc.allow", eopts...)
   137  	if err != nil {
   138  		return false, fmt.Errorf("could not initialize opa evaluator: %v", err)
   139  	}
   140  
   141  	inputs, err := opaInputs(req, req.Data, site, claims)
   142  	if err != nil {
   143  		return false, err
   144  	}
   145  
   146  	allowed, err = evaluator.Evaluate(context.Background(), inputs)
   147  	if err != nil {
   148  		return false, err
   149  	}
   150  
   151  	return allowed, nil
   152  }
   153  
   154  func opaInputs(req *Request, data json.RawMessage, site string, claims *tokens.ClientIDClaims) (map[string]any, error) {
   155  	dat := map[string]any{}
   156  	err := json.Unmarshal(data, &dat)
   157  	if err != nil {
   158  		return nil, err
   159  	}
   160  
   161  	// lame deep copy/data convert thing happening here
   162  	jclaims, err := json.Marshal(claims)
   163  	if err != nil {
   164  		return nil, fmt.Errorf("could not JSON encode claims")
   165  	}
   166  
   167  	cdat := new(map[string]any)
   168  	err = json.Unmarshal(jclaims, &cdat)
   169  	if err != nil {
   170  		return nil, fmt.Errorf("could not JSON encode claims")
   171  	}
   172  
   173  	return map[string]any{
   174  		"agent":      req.Agent,
   175  		"action":     req.Action,
   176  		"data":       data,
   177  		"sender":     req.SenderID,
   178  		"collective": req.Collective,
   179  		"ttl":        req.TTL,
   180  		"time":       req.Time,
   181  		"site":       site,
   182  		"claims":     cdat,
   183  	}, nil
   184  }
   185  
   186  func opaFunctionsMap(req *Request) []func(r *rego.Rego) {
   187  	return []func(r *rego.Rego){
   188  		rego.Function1(&rego.Function{Name: "requires_filter", Decl: types.NewFunction(types.Args(), types.B)}, opaFuncRequiresFilter(req)),
   189  		rego.Function1(&rego.Function{Name: "requires_fact_filter", Decl: types.NewFunction(types.Args(types.S), types.B)}, opaFuncRequiresFactFilter(req)),
   190  		rego.Function1(&rego.Function{Name: "requires_class_filter", Decl: types.NewFunction(types.Args(types.S), types.B)}, opaFuncRequiresClassFilter(req)),
   191  		rego.Function1(&rego.Function{Name: "requires_identity_filter", Decl: types.NewFunction(types.Args(types.S), types.B)}, opaFuncRequiresIdentityFilter(req)),
   192  	}
   193  }
   194  
   195  func opaFuncRequiresFilter(req *Request) func(_ rego.BuiltinContext, a *ast.Term) (*ast.Term, error) {
   196  	return func(_ rego.BuiltinContext, a *ast.Term) (*ast.Term, error) {
   197  		// agent is always set, so we don't check it else it will always be true
   198  		if len(req.Filter.ClassFilters()) > 0 || len(req.Filter.IdentityFilters()) > 0 || len(req.Filter.FactFilters()) > 0 || len(req.Filter.CompoundFilters()) > 0 {
   199  			return ast.BooleanTerm(true), nil
   200  		}
   201  
   202  		return ast.BooleanTerm(false), nil
   203  	}
   204  
   205  }
   206  
   207  func opaFuncRequiresIdentityFilter(req *Request) func(_ rego.BuiltinContext, a *ast.Term) (*ast.Term, error) {
   208  	return func(_ rego.BuiltinContext, a *ast.Term) (*ast.Term, error) {
   209  		str, ok := a.Value.(ast.String)
   210  		if !ok {
   211  			return ast.BooleanTerm(false), fmt.Errorf("invalid identity matcher received")
   212  		}
   213  
   214  		want := string(str)
   215  		for _, f := range req.Filter.IdentityFilters() {
   216  			if f == want {
   217  				return ast.BooleanTerm(true), nil
   218  			}
   219  		}
   220  
   221  		return ast.BooleanTerm(false), nil
   222  	}
   223  }
   224  
   225  func opaFuncRequiresClassFilter(req *Request) func(_ rego.BuiltinContext, a *ast.Term) (*ast.Term, error) {
   226  	return func(_ rego.BuiltinContext, a *ast.Term) (*ast.Term, error) {
   227  		str, ok := a.Value.(ast.String)
   228  		if !ok {
   229  			return ast.BooleanTerm(false), fmt.Errorf("invalid class matcher received")
   230  		}
   231  
   232  		want := string(str)
   233  
   234  		for _, f := range req.Filter.ClassFilters() {
   235  			if f == want {
   236  				return ast.BooleanTerm(true), nil
   237  			}
   238  		}
   239  
   240  		return ast.BooleanTerm(false), nil
   241  	}
   242  }
   243  
   244  func opaFuncRequiresFactFilter(req *Request) func(_ rego.BuiltinContext, a *ast.Term) (*ast.Term, error) {
   245  	return func(_ rego.BuiltinContext, a *ast.Term) (*ast.Term, error) {
   246  		str, ok := a.Value.(ast.String)
   247  		if !ok {
   248  			return ast.BooleanTerm(false), fmt.Errorf("invalid fact matcher received")
   249  		}
   250  
   251  		want, err := client.ParseFactFilterString(string(str))
   252  		if err != nil {
   253  			return ast.BooleanTerm(false), fmt.Errorf("invalid fact matcher received: %s", err)
   254  		}
   255  
   256  		for _, f := range req.Filter.Fact {
   257  			if want.Fact == f.Fact && want.Operator == f.Operator && want.Value == f.Value {
   258  				return ast.BooleanTerm(true), nil
   259  			}
   260  		}
   261  
   262  		return ast.BooleanTerm(false), nil
   263  	}
   264  }