github.com/rajeev159/opa@v0.45.0/topdown/exported_test.go (about)

     1  // Copyright 2020 The OPA Authors.  All rights reserved.
     2  // Use of this source code is governed by an Apache2
     3  // license that can be found in the LICENSE file.
     4  
     5  package topdown
     6  
import (
	"context"
	"errors"
	"fmt"
	"os"
	"runtime"
	"sort"
	"strings"
	"testing"

	"github.com/open-policy-agent/opa/ast"
	"github.com/open-policy-agent/opa/storage"
	"github.com/open-policy-agent/opa/storage/inmem"
	"github.com/open-policy-agent/opa/test/cases"
)
    21  
    22  var x508Exceptions = []string{
    23  	"cryptox509parsecertificates/invalid DER or PEM data, b64",
    24  }
    25  
    26  func isException(note string) bool {
    27  	for _, exc := range x508Exceptions {
    28  		if note == exc {
    29  			return true
    30  		}
    31  	}
    32  	return false
    33  }
    34  
    35  func TestRego(t *testing.T) {
    36  	for _, tc := range cases.MustLoad("../test/cases/testdata").Sorted().Cases {
    37  		if strings.HasPrefix(runtime.Version(), "go1.16") && isException(tc.Note) {
    38  			t.Run(tc.Note, func(t *testing.T) {
    39  				t.Skip("skipped for go1.16, x509 errors differ")
    40  			})
    41  			continue
    42  		}
    43  		t.Run(tc.Note, func(t *testing.T) {
    44  			testRun(t, tc)
    45  		})
    46  	}
    47  }
    48  
    49  func TestOPARego(t *testing.T) {
    50  	for _, tc := range cases.MustLoad("testdata/cases").Sorted().Cases {
    51  		t.Run(tc.Note, func(t *testing.T) {
    52  			testRun(t, tc)
    53  		})
    54  	}
    55  }
    56  
    57  func testRun(t *testing.T, tc cases.TestCase) {
    58  
    59  	ctx := context.Background()
    60  
    61  	modules := map[string]string{}
    62  	for i, module := range tc.Modules {
    63  		modules[fmt.Sprintf("test-%d.rego", i)] = module
    64  	}
    65  
    66  	compiler := ast.MustCompileModules(modules)
    67  	query, err := compiler.QueryCompiler().Compile(ast.MustParseBody(tc.Query))
    68  
    69  	if err != nil {
    70  		t.Fatal(err)
    71  	}
    72  
    73  	var store storage.Store
    74  
    75  	if tc.Data != nil {
    76  		store = inmem.NewFromObject(*tc.Data)
    77  	} else {
    78  		store = inmem.New()
    79  	}
    80  
    81  	txn := storage.NewTransactionOrDie(ctx, store)
    82  
    83  	var input *ast.Term
    84  
    85  	if tc.InputTerm != nil {
    86  		input = ast.MustParseTerm(*tc.InputTerm)
    87  	} else if tc.Input != nil {
    88  		input = ast.NewTerm(ast.MustInterfaceToValue(*tc.Input))
    89  	}
    90  
    91  	buf := NewBufferTracer()
    92  	rs, err := NewQuery(query).
    93  		WithCompiler(compiler).
    94  		WithStore(store).
    95  		WithTransaction(txn).
    96  		WithInput(input).
    97  		WithStrictBuiltinErrors(tc.StrictError).
    98  		WithTracer(buf).
    99  		Run(ctx)
   100  
   101  	if tc.WantError != nil {
   102  		testAssertErrorText(t, *tc.WantError, err)
   103  	}
   104  
   105  	if tc.WantErrorCode != nil {
   106  		testAssertErrorCode(t, *tc.WantErrorCode, err)
   107  	}
   108  
   109  	if err != nil && tc.WantErrorCode == nil && tc.WantError == nil {
   110  		t.Fatalf("unexpected error: %v", err)
   111  	}
   112  
   113  	if tc.WantResult != nil {
   114  		testAssertResultSet(t, *tc.WantResult, rs, tc.SortBindings)
   115  	}
   116  
   117  	if tc.WantResult == nil && tc.WantErrorCode == nil && tc.WantError == nil {
   118  		t.Fatal("expected one of: 'want_result', 'want_error_code', or 'want_error'")
   119  	}
   120  
   121  	if testing.Verbose() {
   122  		PrettyTrace(os.Stderr, *buf)
   123  	}
   124  }
   125  
   126  func testAssertResultSet(t *testing.T, wantResult []map[string]interface{}, rs QueryResultSet, sortBindings bool) {
   127  
   128  	exp := ast.NewSet()
   129  
   130  	for _, b := range wantResult {
   131  		obj := ast.NewObject()
   132  		for k, v := range b {
   133  			obj.Insert(ast.StringTerm(k), ast.NewTerm(ast.MustInterfaceToValue(v)))
   134  		}
   135  		exp.Add(ast.NewTerm(obj))
   136  	}
   137  
   138  	got := ast.NewSet()
   139  
   140  	for _, b := range rs {
   141  		obj := ast.NewObject()
   142  		for k, term := range b {
   143  			v, err := ast.JSON(term.Value)
   144  			if err != nil {
   145  				t.Fatal(err)
   146  			}
   147  			if sortBindings {
   148  				sort.Sort(resultSet(v.([]interface{})))
   149  			}
   150  			obj.Insert(ast.StringTerm(string(k)), ast.NewTerm(ast.MustInterfaceToValue(v)))
   151  		}
   152  		got.Add(ast.NewTerm(obj))
   153  	}
   154  
   155  	if exp.Compare(got) != 0 {
   156  		t.Fatalf("unexpected query result:\nexp: %v\ngot: %v", exp, got)
   157  	}
   158  }
   159  
   160  func testAssertErrorCode(t *testing.T, wantErrorCode string, err error) {
   161  	e, ok := err.(*Error)
   162  	if !ok {
   163  		t.Fatal("expected topdown error but got:", err)
   164  	}
   165  
   166  	if e.Code != wantErrorCode {
   167  		t.Fatalf("expected error code %q but got %q", wantErrorCode, e.Code)
   168  	}
   169  }
   170  
   171  func testAssertErrorText(t *testing.T, wantText string, err error) {
   172  	if err == nil {
   173  		t.Fatal("expected error but got success")
   174  	}
   175  	if !strings.Contains(err.Error(), wantText) {
   176  		t.Fatalf("expected topdown error text %q but got: %q", wantText, err.Error())
   177  	}
   178  }