github.com/khulnasoft-lab/defsec@v1.0.5-0.20230827010352-5e9f46893d95/pkg/rego/load.go (about) 1 package rego 2 3 import ( 4 "context" 5 "fmt" 6 "io" 7 "io/fs" 8 "path/filepath" 9 "strings" 10 11 "github.com/open-policy-agent/opa/ast" 12 "github.com/open-policy-agent/opa/bundle" 13 ) 14 15 func isRegoFile(name string) bool { 16 return strings.HasSuffix(name, bundle.RegoExt) && !strings.HasSuffix(name, "_test"+bundle.RegoExt) 17 } 18 19 func isJSONFile(name string) bool { 20 return strings.HasSuffix(name, ".json") 21 } 22 23 func sanitisePath(path string) string { 24 vol := filepath.VolumeName(path) 25 path = strings.TrimPrefix(path, vol) 26 27 return strings.TrimPrefix(strings.TrimPrefix(filepath.ToSlash(path), "./"), "/") 28 } 29 30 func (s *Scanner) loadPoliciesFromDirs(target fs.FS, paths []string) (map[string]*ast.Module, error) { 31 modules := make(map[string]*ast.Module) 32 for _, path := range paths { 33 if err := fs.WalkDir(target, sanitisePath(path), func(path string, info fs.DirEntry, err error) error { 34 if err != nil { 35 return err 36 } 37 if info.IsDir() { 38 return nil 39 } 40 if !isRegoFile(info.Name()) { 41 return nil 42 } 43 data, err := fs.ReadFile(target, filepath.ToSlash(path)) 44 if err != nil { 45 return err 46 } 47 module, err := ast.ParseModuleWithOpts(path, string(data), ast.ParserOptions{ 48 ProcessAnnotation: true, 49 }) 50 if err != nil { 51 s.debug.Log("Failed to load module: %s, err: %s", filepath.ToSlash(path), err.Error()) 52 return nil 53 } 54 modules[path] = module 55 return nil 56 }); err != nil { 57 return nil, err 58 } 59 } 60 return modules, nil 61 } 62 63 func (s *Scanner) loadPoliciesFromReaders(readers []io.Reader) (map[string]*ast.Module, error) { 64 modules := make(map[string]*ast.Module) 65 for i, r := range readers { 66 moduleName := fmt.Sprintf("reader_%d", i) 67 data, err := io.ReadAll(r) 68 if err != nil { 69 return nil, err 70 } 71 module, err := ast.ParseModuleWithOpts(moduleName, string(data), ast.ParserOptions{ 72 ProcessAnnotation: true, 73 }) 74 if err != nil { 75 return nil, err 76 } 77 modules[moduleName] = module 78 } 79 return modules, nil 80 } 81 82 func (s *Scanner) LoadEmbeddedLibraries() error { 83 if s.policies == nil { 84 s.policies = make(map[string]*ast.Module) 85 } 86 loadedLibs, err := loadEmbeddedLibraries() 87 if err != nil { 88 return fmt.Errorf("failed to load embedded rego libraries: %w", err) 89 } 90 for name, policy := range loadedLibs { 91 s.policies[name] = policy 92 } 93 s.debug.Log("Loaded %d embedded libraries (without embedded policies).", len(loadedLibs)) 94 return nil 95 } 96 97 func (s *Scanner) loadEmbedded(enableEmbeddedLibraries, enableEmbeddedPolicies bool) error { 98 if enableEmbeddedLibraries { 99 loadedLibs, errLoad := loadEmbeddedLibraries() 100 if errLoad != nil { 101 return fmt.Errorf("failed to load embedded rego libraries: %w", errLoad) 102 } 103 for name, policy := range loadedLibs { 104 s.policies[name] = policy 105 } 106 s.debug.Log("Loaded %d embedded libraries.", len(loadedLibs)) 107 } 108 109 if enableEmbeddedPolicies { 110 loaded, err := loadEmbeddedPolicies() 111 if err != nil { 112 return fmt.Errorf("failed to load embedded rego policies: %w", err) 113 } 114 for name, policy := range loaded { 115 s.policies[name] = policy 116 } 117 s.debug.Log("Loaded %d embedded policies.", len(loaded)) 118 } 119 120 return nil 121 } 122 123 func (s *Scanner) LoadPolicies(enableEmbeddedLibraries, enableEmbeddedPolicies bool, srcFS fs.FS, paths []string, readers []io.Reader) error { 124 125 if s.policies == nil { 126 s.policies = make(map[string]*ast.Module) 127 } 128 129 if s.policyFS != nil { 130 s.debug.Log("Overriding filesystem for policies!") 131 srcFS = s.policyFS 132 } 133 134 if err := s.loadEmbedded(enableEmbeddedLibraries, enableEmbeddedPolicies); err != nil { 135 return err 136 } 137 138 var err error 139 if len(paths) > 0 { 140 loaded, err := s.loadPoliciesFromDirs(srcFS, paths) 141 if err != nil { 142 return fmt.Errorf("failed to load rego policies from %s: %w", paths, err) 143 } 144 for name, policy := range loaded { 145 s.policies[name] = policy 146 } 147 s.debug.Log("Loaded %d policies from disk.", len(loaded)) 148 } 149 150 if len(readers) > 0 { 151 loaded, err := s.loadPoliciesFromReaders(readers) 152 if err != nil { 153 return fmt.Errorf("failed to load rego policies from reader(s): %w", err) 154 } 155 for name, policy := range loaded { 156 s.policies[name] = policy 157 } 158 s.debug.Log("Loaded %d policies from reader(s).", len(loaded)) 159 } 160 161 // gather namespaces 162 uniq := make(map[string]struct{}) 163 for _, module := range s.policies { 164 namespace := getModuleNamespace(module) 165 uniq[namespace] = struct{}{} 166 } 167 var namespaces []string 168 for namespace := range uniq { 169 namespaces = append(namespaces, namespace) 170 } 171 172 dataFS := srcFS 173 if s.dataFS != nil { 174 s.debug.Log("Overriding filesystem for data!") 175 dataFS = s.dataFS 176 } 177 store, err := initStore(dataFS, s.dataDirs, namespaces) 178 if err != nil { 179 return fmt.Errorf("unable to load data: %w", err) 180 } 181 s.store = store 182 183 return s.compilePolicies(srcFS, paths) 184 } 185 186 func (s *Scanner) prunePoliciesWithError(compiler *ast.Compiler) error { 187 if len(compiler.Errors) > s.regoErrorLimit { 188 s.debug.Log("Error(s) occurred while loading policies") 189 return compiler.Errors 190 } 191 192 for _, e := range compiler.Errors { 193 s.debug.Log("Error occurred while parsing: %s, %s", e.Location.File, e.Error()) 194 delete(s.policies, e.Location.File) 195 } 196 return nil 197 } 198 199 func (s *Scanner) compilePolicies(srcFS fs.FS, paths []string) error { 200 compiler := ast.NewCompiler() 201 schemaSet, custom, err := BuildSchemaSetFromPolicies(s.policies, paths, srcFS) 202 if err != nil { 203 return err 204 } 205 if custom { 206 s.inputSchema = nil // discard auto detected input schema in favour of policy defined schema 207 } 208 209 compiler.WithSchemas(schemaSet) 210 compiler.WithCapabilities(ast.CapabilitiesForThisVersion()) 211 compiler.Compile(s.policies) 212 if compiler.Failed() { 213 if err := s.prunePoliciesWithError(compiler); err != nil { 214 return err 215 } 216 return s.compilePolicies(srcFS, paths) 217 } 218 retriever := NewMetadataRetriever(compiler) 219 220 if err := s.filterModules(retriever); err != nil { 221 return err 222 } 223 if s.inputSchema != nil { 224 schemaSet := ast.NewSchemaSet() 225 schemaSet.Put(ast.MustParseRef("schema.input"), s.inputSchema) 226 compiler.WithSchemas(schemaSet) 227 compiler.Compile(s.policies) 228 if compiler.Failed() { 229 if err := s.prunePoliciesWithError(compiler); err != nil { 230 return err 231 } 232 return s.compilePolicies(srcFS, paths) 233 } 234 } 235 s.compiler = compiler 236 s.retriever = retriever 237 return nil 238 } 239 240 func (s *Scanner) filterModules(retriever *MetadataRetriever) error { 241 242 filtered := make(map[string]*ast.Module) 243 for name, module := range s.policies { 244 meta, err := retriever.RetrieveMetadata(context.TODO(), module) 245 if err != nil { 246 return err 247 } 248 if len(meta.InputOptions.Selectors) == 0 { 249 s.debug.Log("WARNING: Module %s has no input selectors - it will be loaded for all inputs!", name) 250 filtered[name] = module 251 continue 252 } 253 for _, selector := range meta.InputOptions.Selectors { 254 if selector.Type == string(s.sourceType) { 255 filtered[name] = module 256 break 257 } 258 } 259 } 260 261 s.policies = filtered 262 return nil 263 }