github.com/rajeev159/opa@v0.45.0/topdown/exported_test.go (about) 1 // Copyright 2020 The OPA Authors. All rights reserved. 2 // Use of this source code is governed by an Apache2 3 // license that can be found in the LICENSE file. 4 5 package topdown 6 7 import ( 8 "context" 9 "fmt" 10 "os" 11 "runtime" 12 "sort" 13 "strings" 14 "testing" 15 16 "github.com/open-policy-agent/opa/ast" 17 "github.com/open-policy-agent/opa/storage" 18 "github.com/open-policy-agent/opa/storage/inmem" 19 "github.com/open-policy-agent/opa/test/cases" 20 ) 21 22 var x508Exceptions = []string{ 23 "cryptox509parsecertificates/invalid DER or PEM data, b64", 24 } 25 26 func isException(note string) bool { 27 for _, exc := range x508Exceptions { 28 if note == exc { 29 return true 30 } 31 } 32 return false 33 } 34 35 func TestRego(t *testing.T) { 36 for _, tc := range cases.MustLoad("../test/cases/testdata").Sorted().Cases { 37 if strings.HasPrefix(runtime.Version(), "go1.16") && isException(tc.Note) { 38 t.Run(tc.Note, func(t *testing.T) { 39 t.Skip("skipped for go1.16, x509 errors differ") 40 }) 41 continue 42 } 43 t.Run(tc.Note, func(t *testing.T) { 44 testRun(t, tc) 45 }) 46 } 47 } 48 49 func TestOPARego(t *testing.T) { 50 for _, tc := range cases.MustLoad("testdata/cases").Sorted().Cases { 51 t.Run(tc.Note, func(t *testing.T) { 52 testRun(t, tc) 53 }) 54 } 55 } 56 57 func testRun(t *testing.T, tc cases.TestCase) { 58 59 ctx := context.Background() 60 61 modules := map[string]string{} 62 for i, module := range tc.Modules { 63 modules[fmt.Sprintf("test-%d.rego", i)] = module 64 } 65 66 compiler := ast.MustCompileModules(modules) 67 query, err := compiler.QueryCompiler().Compile(ast.MustParseBody(tc.Query)) 68 69 if err != nil { 70 t.Fatal(err) 71 } 72 73 var store storage.Store 74 75 if tc.Data != nil { 76 store = inmem.NewFromObject(*tc.Data) 77 } else { 78 store = inmem.New() 79 } 80 81 txn := storage.NewTransactionOrDie(ctx, store) 82 83 var input *ast.Term 84 85 if tc.InputTerm != nil { 86 input = ast.MustParseTerm(*tc.InputTerm) 87 } else if tc.Input != nil { 88 input = ast.NewTerm(ast.MustInterfaceToValue(*tc.Input)) 89 } 90 91 buf := NewBufferTracer() 92 rs, err := NewQuery(query). 93 WithCompiler(compiler). 94 WithStore(store). 95 WithTransaction(txn). 96 WithInput(input). 97 WithStrictBuiltinErrors(tc.StrictError). 98 WithTracer(buf). 99 Run(ctx) 100 101 if tc.WantError != nil { 102 testAssertErrorText(t, *tc.WantError, err) 103 } 104 105 if tc.WantErrorCode != nil { 106 testAssertErrorCode(t, *tc.WantErrorCode, err) 107 } 108 109 if err != nil && tc.WantErrorCode == nil && tc.WantError == nil { 110 t.Fatalf("unexpected error: %v", err) 111 } 112 113 if tc.WantResult != nil { 114 testAssertResultSet(t, *tc.WantResult, rs, tc.SortBindings) 115 } 116 117 if tc.WantResult == nil && tc.WantErrorCode == nil && tc.WantError == nil { 118 t.Fatal("expected one of: 'want_result', 'want_error_code', or 'want_error'") 119 } 120 121 if testing.Verbose() { 122 PrettyTrace(os.Stderr, *buf) 123 } 124 } 125 126 func testAssertResultSet(t *testing.T, wantResult []map[string]interface{}, rs QueryResultSet, sortBindings bool) { 127 128 exp := ast.NewSet() 129 130 for _, b := range wantResult { 131 obj := ast.NewObject() 132 for k, v := range b { 133 obj.Insert(ast.StringTerm(k), ast.NewTerm(ast.MustInterfaceToValue(v))) 134 } 135 exp.Add(ast.NewTerm(obj)) 136 } 137 138 got := ast.NewSet() 139 140 for _, b := range rs { 141 obj := ast.NewObject() 142 for k, term := range b { 143 v, err := ast.JSON(term.Value) 144 if err != nil { 145 t.Fatal(err) 146 } 147 if sortBindings { 148 sort.Sort(resultSet(v.([]interface{}))) 149 } 150 obj.Insert(ast.StringTerm(string(k)), ast.NewTerm(ast.MustInterfaceToValue(v))) 151 } 152 got.Add(ast.NewTerm(obj)) 153 } 154 155 if exp.Compare(got) != 0 { 156 t.Fatalf("unexpected query result:\nexp: %v\ngot: %v", exp, got) 157 } 158 } 159 160 func testAssertErrorCode(t *testing.T, wantErrorCode string, err error) { 161 e, ok := err.(*Error) 162 if !ok { 163 t.Fatal("expected topdown error but got:", err) 164 } 165 166 if e.Code != wantErrorCode { 167 t.Fatalf("expected error code %q but got %q", wantErrorCode, e.Code) 168 } 169 } 170 171 func testAssertErrorText(t *testing.T, wantText string, err error) { 172 if err == nil { 173 t.Fatal("expected error but got success") 174 } 175 if !strings.Contains(err.Error(), wantText) { 176 t.Fatalf("expected topdown error text %q but got: %q", wantText, err.Error()) 177 } 178 }