github.com/khulnasoft-lab/defsec@v1.0.5-0.20230827010352-5e9f46893d95/pkg/rego/load.go (about)

     1  package rego
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"io/fs"
     8  	"path/filepath"
     9  	"strings"
    10  
    11  	"github.com/open-policy-agent/opa/ast"
    12  	"github.com/open-policy-agent/opa/bundle"
    13  )
    14  
    15  func isRegoFile(name string) bool {
    16  	return strings.HasSuffix(name, bundle.RegoExt) && !strings.HasSuffix(name, "_test"+bundle.RegoExt)
    17  }
    18  
    19  func isJSONFile(name string) bool {
    20  	return strings.HasSuffix(name, ".json")
    21  }
    22  
    23  func sanitisePath(path string) string {
    24  	vol := filepath.VolumeName(path)
    25  	path = strings.TrimPrefix(path, vol)
    26  
    27  	return strings.TrimPrefix(strings.TrimPrefix(filepath.ToSlash(path), "./"), "/")
    28  }
    29  
    30  func (s *Scanner) loadPoliciesFromDirs(target fs.FS, paths []string) (map[string]*ast.Module, error) {
    31  	modules := make(map[string]*ast.Module)
    32  	for _, path := range paths {
    33  		if err := fs.WalkDir(target, sanitisePath(path), func(path string, info fs.DirEntry, err error) error {
    34  			if err != nil {
    35  				return err
    36  			}
    37  			if info.IsDir() {
    38  				return nil
    39  			}
    40  			if !isRegoFile(info.Name()) {
    41  				return nil
    42  			}
    43  			data, err := fs.ReadFile(target, filepath.ToSlash(path))
    44  			if err != nil {
    45  				return err
    46  			}
    47  			module, err := ast.ParseModuleWithOpts(path, string(data), ast.ParserOptions{
    48  				ProcessAnnotation: true,
    49  			})
    50  			if err != nil {
    51  				s.debug.Log("Failed to load module: %s, err: %s", filepath.ToSlash(path), err.Error())
    52  				return nil
    53  			}
    54  			modules[path] = module
    55  			return nil
    56  		}); err != nil {
    57  			return nil, err
    58  		}
    59  	}
    60  	return modules, nil
    61  }
    62  
    63  func (s *Scanner) loadPoliciesFromReaders(readers []io.Reader) (map[string]*ast.Module, error) {
    64  	modules := make(map[string]*ast.Module)
    65  	for i, r := range readers {
    66  		moduleName := fmt.Sprintf("reader_%d", i)
    67  		data, err := io.ReadAll(r)
    68  		if err != nil {
    69  			return nil, err
    70  		}
    71  		module, err := ast.ParseModuleWithOpts(moduleName, string(data), ast.ParserOptions{
    72  			ProcessAnnotation: true,
    73  		})
    74  		if err != nil {
    75  			return nil, err
    76  		}
    77  		modules[moduleName] = module
    78  	}
    79  	return modules, nil
    80  }
    81  
    82  func (s *Scanner) LoadEmbeddedLibraries() error {
    83  	if s.policies == nil {
    84  		s.policies = make(map[string]*ast.Module)
    85  	}
    86  	loadedLibs, err := loadEmbeddedLibraries()
    87  	if err != nil {
    88  		return fmt.Errorf("failed to load embedded rego libraries: %w", err)
    89  	}
    90  	for name, policy := range loadedLibs {
    91  		s.policies[name] = policy
    92  	}
    93  	s.debug.Log("Loaded %d embedded libraries (without embedded policies).", len(loadedLibs))
    94  	return nil
    95  }
    96  
    97  func (s *Scanner) loadEmbedded(enableEmbeddedLibraries, enableEmbeddedPolicies bool) error {
    98  	if enableEmbeddedLibraries {
    99  		loadedLibs, errLoad := loadEmbeddedLibraries()
   100  		if errLoad != nil {
   101  			return fmt.Errorf("failed to load embedded rego libraries: %w", errLoad)
   102  		}
   103  		for name, policy := range loadedLibs {
   104  			s.policies[name] = policy
   105  		}
   106  		s.debug.Log("Loaded %d embedded libraries.", len(loadedLibs))
   107  	}
   108  
   109  	if enableEmbeddedPolicies {
   110  		loaded, err := loadEmbeddedPolicies()
   111  		if err != nil {
   112  			return fmt.Errorf("failed to load embedded rego policies: %w", err)
   113  		}
   114  		for name, policy := range loaded {
   115  			s.policies[name] = policy
   116  		}
   117  		s.debug.Log("Loaded %d embedded policies.", len(loaded))
   118  	}
   119  
   120  	return nil
   121  }
   122  
   123  func (s *Scanner) LoadPolicies(enableEmbeddedLibraries, enableEmbeddedPolicies bool, srcFS fs.FS, paths []string, readers []io.Reader) error {
   124  
   125  	if s.policies == nil {
   126  		s.policies = make(map[string]*ast.Module)
   127  	}
   128  
   129  	if s.policyFS != nil {
   130  		s.debug.Log("Overriding filesystem for policies!")
   131  		srcFS = s.policyFS
   132  	}
   133  
   134  	if err := s.loadEmbedded(enableEmbeddedLibraries, enableEmbeddedPolicies); err != nil {
   135  		return err
   136  	}
   137  
   138  	var err error
   139  	if len(paths) > 0 {
   140  		loaded, err := s.loadPoliciesFromDirs(srcFS, paths)
   141  		if err != nil {
   142  			return fmt.Errorf("failed to load rego policies from %s: %w", paths, err)
   143  		}
   144  		for name, policy := range loaded {
   145  			s.policies[name] = policy
   146  		}
   147  		s.debug.Log("Loaded %d policies from disk.", len(loaded))
   148  	}
   149  
   150  	if len(readers) > 0 {
   151  		loaded, err := s.loadPoliciesFromReaders(readers)
   152  		if err != nil {
   153  			return fmt.Errorf("failed to load rego policies from reader(s): %w", err)
   154  		}
   155  		for name, policy := range loaded {
   156  			s.policies[name] = policy
   157  		}
   158  		s.debug.Log("Loaded %d policies from reader(s).", len(loaded))
   159  	}
   160  
   161  	// gather namespaces
   162  	uniq := make(map[string]struct{})
   163  	for _, module := range s.policies {
   164  		namespace := getModuleNamespace(module)
   165  		uniq[namespace] = struct{}{}
   166  	}
   167  	var namespaces []string
   168  	for namespace := range uniq {
   169  		namespaces = append(namespaces, namespace)
   170  	}
   171  
   172  	dataFS := srcFS
   173  	if s.dataFS != nil {
   174  		s.debug.Log("Overriding filesystem for data!")
   175  		dataFS = s.dataFS
   176  	}
   177  	store, err := initStore(dataFS, s.dataDirs, namespaces)
   178  	if err != nil {
   179  		return fmt.Errorf("unable to load data: %w", err)
   180  	}
   181  	s.store = store
   182  
   183  	return s.compilePolicies(srcFS, paths)
   184  }
   185  
   186  func (s *Scanner) prunePoliciesWithError(compiler *ast.Compiler) error {
   187  	if len(compiler.Errors) > s.regoErrorLimit {
   188  		s.debug.Log("Error(s) occurred while loading policies")
   189  		return compiler.Errors
   190  	}
   191  
   192  	for _, e := range compiler.Errors {
   193  		s.debug.Log("Error occurred while parsing: %s, %s", e.Location.File, e.Error())
   194  		delete(s.policies, e.Location.File)
   195  	}
   196  	return nil
   197  }
   198  
   199  func (s *Scanner) compilePolicies(srcFS fs.FS, paths []string) error {
   200  	compiler := ast.NewCompiler()
   201  	schemaSet, custom, err := BuildSchemaSetFromPolicies(s.policies, paths, srcFS)
   202  	if err != nil {
   203  		return err
   204  	}
   205  	if custom {
   206  		s.inputSchema = nil // discard auto detected input schema in favour of policy defined schema
   207  	}
   208  
   209  	compiler.WithSchemas(schemaSet)
   210  	compiler.WithCapabilities(ast.CapabilitiesForThisVersion())
   211  	compiler.Compile(s.policies)
   212  	if compiler.Failed() {
   213  		if err := s.prunePoliciesWithError(compiler); err != nil {
   214  			return err
   215  		}
   216  		return s.compilePolicies(srcFS, paths)
   217  	}
   218  	retriever := NewMetadataRetriever(compiler)
   219  
   220  	if err := s.filterModules(retriever); err != nil {
   221  		return err
   222  	}
   223  	if s.inputSchema != nil {
   224  		schemaSet := ast.NewSchemaSet()
   225  		schemaSet.Put(ast.MustParseRef("schema.input"), s.inputSchema)
   226  		compiler.WithSchemas(schemaSet)
   227  		compiler.Compile(s.policies)
   228  		if compiler.Failed() {
   229  			if err := s.prunePoliciesWithError(compiler); err != nil {
   230  				return err
   231  			}
   232  			return s.compilePolicies(srcFS, paths)
   233  		}
   234  	}
   235  	s.compiler = compiler
   236  	s.retriever = retriever
   237  	return nil
   238  }
   239  
   240  func (s *Scanner) filterModules(retriever *MetadataRetriever) error {
   241  
   242  	filtered := make(map[string]*ast.Module)
   243  	for name, module := range s.policies {
   244  		meta, err := retriever.RetrieveMetadata(context.TODO(), module)
   245  		if err != nil {
   246  			return err
   247  		}
   248  		if len(meta.InputOptions.Selectors) == 0 {
   249  			s.debug.Log("WARNING: Module %s has no input selectors - it will be loaded for all inputs!", name)
   250  			filtered[name] = module
   251  			continue
   252  		}
   253  		for _, selector := range meta.InputOptions.Selectors {
   254  			if selector.Type == string(s.sourceType) {
   255  				filtered[name] = module
   256  				break
   257  			}
   258  		}
   259  	}
   260  
   261  	s.policies = filtered
   262  	return nil
   263  }