github.com/choria-io/go-choria@v0.28.1-0.20240416190746-b3bf9c7d5a45/providers/agent/mcorpc/authz_rego.go (about)

     1  // Copyright (c) 2020-2022, R.I. Pienaar and the Choria Project contributors
     2  //
     3  // SPDX-License-Identifier: Apache-2.0
     4  
     5  package mcorpc
     6  
     7  import (
     8  	"context"
     9  	"encoding/json"
    10  	"fmt"
    11  	"path/filepath"
    12  
    13  	"github.com/choria-io/go-choria/config"
    14  	"github.com/choria-io/go-choria/internal/util"
    15  	"github.com/choria-io/go-choria/opa"
    16  	"github.com/open-policy-agent/opa/ast"
    17  	"github.com/sirupsen/logrus"
    18  )
    19  
    20  type regoPolicy struct {
    21  	cfg   *config.Config
    22  	req   *Request
    23  	agent *Agent
    24  	log   *logrus.Entry
    25  }
    26  
    27  func regoPolicyAuthorize(req *Request, agent *Agent, log *logrus.Entry) (bool, error) {
    28  	logger := log.WithFields(logrus.Fields{
    29  		"authorizer": "regoPolicy",
    30  		"agent":      agent.Name(),
    31  		"request":    req.RequestID,
    32  	})
    33  
    34  	authz := &regoPolicy{
    35  		cfg:   agent.Config,
    36  		req:   req,
    37  		agent: agent,
    38  		log:   logger,
    39  	}
    40  
    41  	return authz.authorize()
    42  }
    43  
    44  func (r *regoPolicy) authorize() (bool, error) {
    45  	policyFile, err := r.lookupPolicyFile()
    46  	if err != nil {
    47  		return false, err
    48  	}
    49  
    50  	if policyFile == "" {
    51  		return false, fmt.Errorf("policy file could not be found")
    52  	}
    53  
    54  	eopts := []opa.Option{
    55  		opa.Logger(r.log),
    56  		opa.File(policyFile),
    57  	}
    58  
    59  	if r.log.Logger.GetLevel() == logrus.DebugLevel || r.enableTracing() {
    60  		r.log.Debugf("regoInputs: %v", r.regoInputs())
    61  		eopts = append(eopts, opa.Trace())
    62  	}
    63  
    64  	evaluator, err := opa.New("io.choria.mcorpc.authpolicy", "data.io.choria.mcorpc.authpolicy.allow", eopts...)
    65  	if err != nil {
    66  		return false, err
    67  	}
    68  
    69  	allowed, err := evaluator.Evaluate(context.Background(), r.regoInputs())
    70  	switch err := err.(type) {
    71  	case nil:
    72  		break
    73  	case ast.Errors:
    74  		for _, e := range err {
    75  			r.log.Info("code: ", e.Code)
    76  			r.log.Info("row: ", e.Location.Row)
    77  			r.log.Info("filename: ", policyFile)
    78  		}
    79  		return false, err
    80  	default:
    81  		return false, err
    82  	}
    83  
    84  	return allowed, nil
    85  }
    86  
    87  func (r *regoPolicy) lookupPolicyFile() (string, error) {
    88  	dir := filepath.Join(filepath.Dir(r.cfg.ConfigFile), "policies", "rego")
    89  
    90  	regoPolicy := filepath.Join(dir, r.agent.Name()+".rego")
    91  
    92  	r.log.Debugf("Looking up rego policy in %s", regoPolicy)
    93  	if util.FileExist(regoPolicy) {
    94  		r.log.Debugf("Using policy file: %s", regoPolicy)
    95  		return regoPolicy, nil
    96  	}
    97  
    98  	defaultPolicy := filepath.Join(dir, "default.rego")
    99  	if util.FileExist(defaultPolicy) {
   100  		r.log.Debugf("Using policy file: %s", defaultPolicy)
   101  		return defaultPolicy, nil
   102  	}
   103  	return "", fmt.Errorf("no policy %s found for %s in %s", defaultPolicy, r.agent.Name(), dir)
   104  
   105  }
   106  
   107  func (r *regoPolicy) regoInputs() map[string]any {
   108  	facts := map[string]any{}
   109  
   110  	sif := r.agent.ServerInfoSource.Facts()
   111  	err := json.Unmarshal(sif, &facts)
   112  	if err != nil {
   113  		r.log.Errorf("could not marshal facts for rego policy: %v", err)
   114  	}
   115  
   116  	data := make(map[string]any)
   117  	err = json.Unmarshal(r.req.Data, &data)
   118  	if err != nil {
   119  		r.log.Errorf("could not marshal data from request: %v", err)
   120  	}
   121  
   122  	return map[string]any{
   123  		"agent":          r.req.Agent,
   124  		"action":         r.req.Action,
   125  		"callerid":       r.req.CallerID,
   126  		"collective":     r.req.Collective,
   127  		"data":           data,
   128  		"ttl":            r.req.TTL,
   129  		"time":           r.req.Time,
   130  		"facts":          facts,
   131  		"classes":        r.agent.ServerInfoSource.Classes(),
   132  		"agents":         r.agent.ServerInfoSource.KnownAgents(),
   133  		"provision_mode": r.agent.Choria.ProvisionMode(),
   134  	}
   135  }
   136  
   137  func (r *regoPolicy) enableTracing() bool {
   138  	tracing, err := util.StrToBool(r.cfg.Option("plugin.regopolicy.tracing", "n"))
   139  	if err != nil {
   140  		return false
   141  	}
   142  
   143  	return tracing
   144  }