github.com/grafana/tanka@v0.26.1-0.20240506093700-c22cfc35c21a/pkg/jsonnet/imports.go (about)

     1  package jsonnet
     2  
     3  import (
     4  	"crypto/sha256"
     5  	"encoding/base64"
     6  	"fmt"
     7  	"os"
     8  	"path/filepath"
     9  	"regexp"
    10  	"sort"
    11  	"sync"
    12  
    13  	jsonnet "github.com/google/go-jsonnet"
    14  	"github.com/google/go-jsonnet/ast"
    15  	"github.com/google/go-jsonnet/toolutils"
    16  	"github.com/pkg/errors"
    17  
    18  	"github.com/grafana/tanka/pkg/jsonnet/implementations/goimpl"
    19  	"github.com/grafana/tanka/pkg/jsonnet/jpath"
    20  )
    21  
    22  var importsRegexp = regexp.MustCompile(`import(str)?\s+['"]([^'"%()]+)['"]`)
    23  
    24  // TransitiveImports returns all recursive imports of an environment
    25  func TransitiveImports(dir string) ([]string, error) {
    26  	dir, err := filepath.Abs(dir)
    27  	if err != nil {
    28  		return nil, err
    29  	}
    30  
    31  	dir, err = filepath.EvalSymlinks(dir)
    32  	if err != nil {
    33  		return nil, err
    34  	}
    35  
    36  	entrypoint, err := jpath.Entrypoint(dir)
    37  	if err != nil {
    38  		return nil, err
    39  	}
    40  
    41  	sonnet, err := os.ReadFile(entrypoint)
    42  	if err != nil {
    43  		return nil, errors.Wrap(err, "opening file")
    44  	}
    45  
    46  	jpath, _, rootDir, err := jpath.Resolve(dir, false)
    47  	if err != nil {
    48  		return nil, errors.Wrap(err, "resolving JPATH")
    49  	}
    50  
    51  	vm := goimpl.MakeRawVM(jpath, nil, nil, 0)
    52  	node, err := jsonnet.SnippetToAST(filepath.Base(entrypoint), string(sonnet))
    53  	if err != nil {
    54  		return nil, errors.Wrap(err, "creating Jsonnet AST")
    55  	}
    56  
    57  	imports := make(map[string]bool)
    58  	if err = importRecursiveStrict(imports, vm, node, filepath.Base(entrypoint)); err != nil {
    59  		return nil, err
    60  	}
    61  
    62  	paths := make([]string, 0, len(imports)+1)
    63  	for k := range imports {
    64  		// Try to resolve any symlinks; use the original path as a last resort
    65  		p, err := filepath.EvalSymlinks(k)
    66  		if err != nil {
    67  			return nil, errors.Wrap(err, "resolving symlinks")
    68  		}
    69  		paths = append(paths, p)
    70  	}
    71  	paths = append(paths, entrypoint)
    72  
    73  	for i := range paths {
    74  		paths[i], _ = filepath.Rel(rootDir, paths[i])
    75  
    76  		// Normalize path separators for windows
    77  		paths[i] = filepath.ToSlash(paths[i])
    78  	}
    79  	sort.Strings(paths)
    80  
    81  	return paths, nil
    82  }
    83  
    84  // importRecursiveStrict does the same as importRecursive, but returns an error
    85  // if a file is not found during when importing
    86  func importRecursiveStrict(list map[string]bool, vm *jsonnet.VM, node ast.Node, currentPath string) error {
    87  	return importRecursive(list, vm, node, currentPath, false)
    88  }
    89  
    90  // importRecursive takes a Jsonnet VM and recursively imports the AST. Every
    91  // found import is added to the `list` string slice, which will ultimately
    92  // contain all recursive imports
    93  func importRecursive(list map[string]bool, vm *jsonnet.VM, node ast.Node, currentPath string, ignoreMissing bool) error {
    94  	switch node := node.(type) {
    95  	// we have an `import`
    96  	case *ast.Import:
    97  		p := node.File.Value
    98  
    99  		contents, foundAt, err := vm.ImportAST(currentPath, p)
   100  		if err != nil {
   101  			if ignoreMissing {
   102  				return nil
   103  			}
   104  			return fmt.Errorf("importing '%s' from '%s': %w", p, currentPath, err)
   105  		}
   106  
   107  		abs, _ := filepath.Abs(foundAt)
   108  		if list[abs] {
   109  			return nil
   110  		}
   111  
   112  		list[abs] = true
   113  
   114  		if err := importRecursive(list, vm, contents, foundAt, ignoreMissing); err != nil {
   115  			return err
   116  		}
   117  
   118  	// we have an `importstr`
   119  	case *ast.ImportStr:
   120  		p := node.File.Value
   121  
   122  		foundAt, err := vm.ResolveImport(currentPath, p)
   123  		if err != nil {
   124  			if ignoreMissing {
   125  				return nil
   126  			}
   127  			return errors.Wrap(err, "importing string")
   128  		}
   129  
   130  		abs, _ := filepath.Abs(foundAt)
   131  		if list[abs] {
   132  			return nil
   133  		}
   134  
   135  		list[abs] = true
   136  
   137  	// neither `import` nor `importstr`, probably object or similar: try children
   138  	default:
   139  		for _, child := range toolutils.Children(node) {
   140  			if err := importRecursive(list, vm, child, currentPath, ignoreMissing); err != nil {
   141  				return err
   142  			}
   143  		}
   144  	}
   145  	return nil
   146  }
   147  
   148  var fileHashes sync.Map
   149  
   150  // getSnippetHash takes a jsonnet snippet and calculates a hash from its content
   151  // and the content of all of its dependencies.
   152  // File hashes are cached in-memory to optimize multiple executions of this function in a process
   153  func getSnippetHash(vm *jsonnet.VM, path, data string) (string, error) {
   154  	result := map[string]bool{}
   155  	if err := findImportRecursiveRegexp(result, vm, path, data); err != nil {
   156  		return "", err
   157  	}
   158  	fileNames := []string{}
   159  	for file := range result {
   160  		fileNames = append(fileNames, file)
   161  	}
   162  	sort.Strings(fileNames)
   163  
   164  	fullHasher := sha256.New()
   165  	fullHasher.Write([]byte(data))
   166  	for _, file := range fileNames {
   167  		var fileHash []byte
   168  		if got, ok := fileHashes.Load(file); ok {
   169  			fileHash = got.([]byte)
   170  		} else {
   171  			bytes, err := os.ReadFile(file)
   172  			if err != nil {
   173  				return "", err
   174  			}
   175  			hash := sha256.New()
   176  			fileHash = hash.Sum(bytes)
   177  			fileHashes.Store(file, fileHash)
   178  		}
   179  		fullHasher.Write(fileHash)
   180  	}
   181  
   182  	return base64.URLEncoding.EncodeToString(fullHasher.Sum(nil)), nil
   183  }
   184  
   185  // findImportRecursiveRegexp does the same as `importRecursive` but uses a regexp
   186  // rather than parsing the AST of all files. This is much faster, but can lead to
   187  // false positives (e.g. if a string contains `import "foo"`).
   188  func findImportRecursiveRegexp(list map[string]bool, vm *jsonnet.VM, filename, content string) error {
   189  	matches := importsRegexp.FindAllStringSubmatch(content, -1)
   190  
   191  	for _, match := range matches {
   192  		importContents, foundAt, err := vm.ImportData(filename, match[2])
   193  		if err != nil {
   194  			continue
   195  		}
   196  		abs, err := filepath.Abs(foundAt)
   197  		if err != nil {
   198  			return err
   199  		}
   200  
   201  		if list[abs] {
   202  			continue
   203  		}
   204  		list[abs] = true
   205  
   206  		if match[1] == "str" {
   207  			continue
   208  		}
   209  
   210  		if err := findImportRecursiveRegexp(list, vm, abs, importContents); err != nil {
   211  			return err
   212  		}
   213  	}
   214  	return nil
   215  }