github.com/snyk/vervet/v4@v4.27.2/internal/compiler/compiler.go (about)

     1  package compiler
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"html/template"
     7  	"io/fs"
     8  	"io/ioutil"
     9  	"log"
    10  	"os"
    11  	"path/filepath"
    12  
    13  	"github.com/bmatcuk/doublestar/v4"
    14  	"github.com/getkin/kin-openapi/openapi3"
    15  	"github.com/ghodss/yaml"
    16  	"go.uber.org/multierr"
    17  
    18  	"github.com/snyk/vervet/v4"
    19  	"github.com/snyk/vervet/v4/config"
    20  	"github.com/snyk/vervet/v4/internal/files"
    21  	"github.com/snyk/vervet/v4/internal/linter"
    22  	"github.com/snyk/vervet/v4/internal/linter/optic"
    23  	"github.com/snyk/vervet/v4/internal/linter/spectral"
    24  )
    25  
// A Compiler checks and builds versioned API resource inputs into aggregated
// OpenAPI versioned outputs, as determined by an API project configuration.
type Compiler struct {
	// apis holds the per-API compiler state, keyed by API name.
	apis    map[string]*api
	// linters holds the instantiated linters, keyed by linter name.
	linters map[string]linter.Linter

	// newLinter instantiates a linter from its configuration. New sets it to
	// defaultLinterFactory unless a LinterFactory option replaces it.
	newLinter func(ctx context.Context, lc *config.Linter) (linter.Linter, error)
}
    34  
// CompilerOption applies a configuration option to a Compiler. New calls each
// option in the order given and aborts with the option's error if one fails.
type CompilerOption func(*Compiler) error
    37  
    38  // LinterFactory configures a Compiler to use a custom factory function for
    39  // instantiating Linters.
    40  func LinterFactory(f func(ctx context.Context, lc *config.Linter) (linter.Linter, error)) CompilerOption {
    41  	return func(c *Compiler) error {
    42  		c.newLinter = f
    43  		return nil
    44  	}
    45  }
    46  
    47  func defaultLinterFactory(ctx context.Context, lc *config.Linter) (linter.Linter, error) {
    48  	if lc.Spectral != nil {
    49  		return spectral.New(ctx, lc.Spectral)
    50  	} else if lc.SweaterComb != nil {
    51  		return optic.New(ctx, lc.SweaterComb)
    52  	} else if lc.OpticCI != nil {
    53  		return optic.New(ctx, lc.OpticCI)
    54  	}
    55  	return nil, fmt.Errorf("invalid linter (linters.%s)", lc.Name)
    56  }
    57  
// api is the compiler's internal state for a single API of the project, as
// assembled by New.
type api struct {
	// resources are the resource sets built from the API's configuration.
	resources       []*resourceSet
	// overlayIncludes are overlay documents loaded from included files.
	overlayIncludes []*vervet.Document
	// overlayInlines are overlay documents parsed from inline configuration.
	overlayInlines  []*openapi3.T
	// output describes where compiled specs are written. It stays nil when
	// the API configures no output paths.
	output          *output
}
    64  
// resourceSet is the compiler's internal state for one configured set of
// resources within an API.
type resourceSet struct {
	// path is the resource set's configured path; it is passed to the
	// linter's Run method ahead of the files to lint.
	path            string
	// linter is the linter looked up by the configured linter name; nil when
	// no linter is registered under that name.
	linter          linter.Linter
	// linterOverrides maps resource name -> version name -> linter override.
	linterOverrides map[string]map[string]config.Linter
	// sourceFiles are the spec files matched by ResourceSpecFiles.
	sourceFiles     []string
	// lintFiles are the files matched by the linter; only populated when a
	// linter is set.
	lintFiles       []string
}
    72  
// output describes where an API's compiled specs are written and how they
// are linted.
type output struct {
	// paths are the output directories. Build compiles into the first one
	// and copies the result to the remaining ones.
	paths  []string
	// linter is the linter looked up by the configured output linter name;
	// nil when no linter is registered under that name.
	linter linter.Linter
}
    77  
    78  // New returns a new Compiler for a given project configuration.
    79  func New(ctx context.Context, proj *config.Project, options ...CompilerOption) (*Compiler, error) {
    80  	compiler := &Compiler{
    81  		apis:      map[string]*api{},
    82  		linters:   map[string]linter.Linter{},
    83  		newLinter: defaultLinterFactory,
    84  	}
    85  	for i := range options {
    86  		err := options[i](compiler)
    87  		if err != nil {
    88  			return nil, err
    89  		}
    90  	}
    91  	// set up linters
    92  	for linterName, linterConfig := range proj.Linters {
    93  		linter, err := compiler.newLinter(ctx, linterConfig)
    94  		if err != nil {
    95  			return nil, fmt.Errorf("%w (linters.%s)", err, linterName)
    96  		}
    97  		compiler.linters[linterName] = linter
    98  	}
    99  	// set up APIs
   100  	for apiName, apiConfig := range proj.APIs {
   101  		a := api{}
   102  
   103  		// Build resources
   104  		for rcIndex, rcConfig := range apiConfig.Resources {
   105  			var err error
   106  			r := &resourceSet{
   107  				path:            rcConfig.Path,
   108  				linter:          compiler.linters[rcConfig.Linter],
   109  				linterOverrides: map[string]map[string]config.Linter{},
   110  			}
   111  			if r.linter != nil {
   112  				r.lintFiles, err = r.linter.Match(rcConfig)
   113  				if err != nil {
   114  					return nil, fmt.Errorf("%w: (apis.%s.resources[%d].path)", err, apiName, rcIndex)
   115  				}
   116  				// TODO: overrides are deprecated by Optic CI, remove soon
   117  				linterOverrides := map[string]map[string]config.Linter{}
   118  				for rcName, versionMap := range rcConfig.LinterOverrides {
   119  					linterOverrides[rcName] = map[string]config.Linter{}
   120  					for version, linter := range versionMap {
   121  						linterOverrides[rcName][version] = *linter
   122  					}
   123  				}
   124  				r.linterOverrides = linterOverrides
   125  			}
   126  			r.sourceFiles, err = ResourceSpecFiles(rcConfig)
   127  			if err != nil {
   128  				return nil, fmt.Errorf("%w: (apis.%s.resources[%d].path)", err, apiName, rcIndex)
   129  			}
   130  			a.resources = append(a.resources, r)
   131  		}
   132  
   133  		// Build overlays
   134  		for overlayIndex, overlayConfig := range apiConfig.Overlays {
   135  			if overlayConfig.Include != "" {
   136  				doc, err := vervet.NewDocumentFile(overlayConfig.Include)
   137  				if err != nil {
   138  					return nil, fmt.Errorf("failed to load overlay %q: %w (apis.%s.overlays[%d])",
   139  						overlayConfig.Include, err, apiName, overlayIndex)
   140  				}
   141  				err = vervet.Localize(doc)
   142  				if err != nil {
   143  					return nil, fmt.Errorf("failed to localize references in %q: %w (apis.%s.overlays[%d]",
   144  						overlayConfig.Include, err, apiName, overlayIndex)
   145  				}
   146  				a.overlayIncludes = append(a.overlayIncludes, doc)
   147  			} else if overlayConfig.Inline != "" {
   148  				docString := os.ExpandEnv(overlayConfig.Inline)
   149  				l := openapi3.NewLoader()
   150  				doc, err := l.LoadFromData([]byte(docString))
   151  				if err != nil {
   152  					return nil, fmt.Errorf("failed to load template: %w (apis.%s.overlays[%d].template)",
   153  						err, apiName, overlayIndex)
   154  				}
   155  				a.overlayInlines = append(a.overlayInlines, doc)
   156  			}
   157  		}
   158  
   159  		// Build output
   160  		if apiConfig.Output != nil {
   161  			paths := apiConfig.Output.Paths
   162  			if len(paths) == 0 && apiConfig.Output.Path != "" {
   163  				paths = []string{apiConfig.Output.Path}
   164  			}
   165  			if len(paths) > 0 {
   166  				a.output = &output{
   167  					paths:  paths,
   168  					linter: compiler.linters[apiConfig.Output.Linter],
   169  				}
   170  			}
   171  		}
   172  
   173  		compiler.apis[apiName] = &a
   174  	}
   175  	return compiler, nil
   176  }
   177  
// ResourceSpecFiles returns all matching spec files for a config.ResourceSet.
// Matching is delegated to files.LocalFSSource.
func ResourceSpecFiles(rcConfig *config.ResourceSet) ([]string, error) {
	return files.LocalFSSource{}.Match(rcConfig)
}
   182  
   183  // LintResources checks the inputs of an API's resources with the configured linter.
   184  func (c *Compiler) LintResources(ctx context.Context, apiName string) error {
   185  	api, ok := c.apis[apiName]
   186  	if !ok {
   187  		return fmt.Errorf("api not found (apis.%s)", apiName)
   188  	}
   189  	var errs error
   190  	for rcIndex, rc := range api.resources {
   191  		if rc.linter == nil {
   192  			continue
   193  		}
   194  		if len(rc.linterOverrides) > 0 {
   195  			err := c.lintWithOverrides(ctx, rc, apiName, rcIndex)
   196  			if err != nil {
   197  				errs = multierr.Append(errs, fmt.Errorf("%w (apis.%s.resources[%d])", err, apiName, rcIndex))
   198  			}
   199  		} else {
   200  			err := rc.linter.Run(ctx, rc.path, rc.lintFiles...)
   201  			if err != nil {
   202  				errs = multierr.Append(errs, fmt.Errorf("%w (apis.%s.resources[%d])", err, apiName, rcIndex))
   203  			}
   204  		}
   205  	}
   206  	return errs
   207  }
   208  
   209  func (c *Compiler) lintWithOverrides(ctx context.Context, rc *resourceSet, apiName string, rcIndex int) error {
   210  	var pending []string
   211  	for _, matchedFile := range rc.lintFiles {
   212  		versionDir := filepath.Dir(matchedFile)
   213  		rcDir := filepath.Dir(versionDir)
   214  		versionName := filepath.Base(versionDir)
   215  		rcName := filepath.Base(rcDir)
   216  		if linter, ok := rc.linterOverrides[rcName][versionName]; ok {
   217  			linter, err := rc.linter.WithOverride(ctx, &linter)
   218  			if err != nil {
   219  				return fmt.Errorf("failed to apply overrides to linter: %w (apis.%s.resources[%d].linter-overrides.%s.%s)",
   220  					err, apiName, rcIndex, rcName, versionName)
   221  			}
   222  			err = linter.Run(ctx, matchedFile)
   223  			if err != nil {
   224  				return fmt.Errorf("lint failed on %q: %w (apis.%s.resources[%d])", matchedFile, err, apiName, rcIndex)
   225  			}
   226  		} else {
   227  			pending = append(pending, matchedFile)
   228  		}
   229  	}
   230  	if len(pending) == 0 {
   231  		return nil
   232  	}
   233  	err := rc.linter.Run(ctx, rc.path, pending...)
   234  	if err != nil {
   235  		return fmt.Errorf("lint failed (apis.%s.resources[%d])", apiName, rcIndex)
   236  	}
   237  	return nil
   238  }
   239  
// LintResourcesAll lints resources in all APIs in the project. Every API is
// linted even if an earlier one fails; the failures are combined into the
// returned error.
func (c *Compiler) LintResourcesAll(ctx context.Context) error {
	return c.apisEach(ctx, c.LintResources)
}
   244  
   245  func (c *Compiler) apisEach(ctx context.Context, f func(ctx context.Context, apiName string) error) error {
   246  	var errs error
   247  	for apiName := range c.apis {
   248  		err := f(ctx, apiName)
   249  		if err != nil {
   250  			errs = multierr.Append(errs, err)
   251  		}
   252  	}
   253  	return errs
   254  }
   255  
   256  // Build builds an aggregate versioned OpenAPI spec for a specific API by name
   257  // in the project.
   258  func (c *Compiler) Build(ctx context.Context, apiName string) error {
   259  	api, ok := c.apis[apiName]
   260  	if !ok {
   261  		return fmt.Errorf("api not found (apis.%s)", apiName)
   262  	}
   263  	if api.output == nil || len(api.output.paths) == 0 {
   264  		return nil
   265  	}
   266  	for _, path := range api.output.paths {
   267  		err := os.RemoveAll(path)
   268  		if err != nil {
   269  			return fmt.Errorf("failed to clear output directory: %w", err)
   270  		}
   271  	}
   272  	err := os.MkdirAll(api.output.paths[0], 0777)
   273  	if err != nil {
   274  		return fmt.Errorf("failed to create output directory: %w", err)
   275  	}
   276  	log.Printf("compiling API %s to output versions", apiName)
   277  	var versionSpecFiles []string
   278  	for rcIndex, rc := range api.resources {
   279  		specVersions, err := vervet.LoadSpecVersionsFileset(rc.sourceFiles)
   280  		if err != nil {
   281  			return fmt.Errorf("failed to load spec versions: %+v (apis.%s.resources[%d])",
   282  				err, apiName, rcIndex)
   283  		}
   284  		buildErr := func(err error) error {
   285  			return fmt.Errorf("%w (apis.%s.resources[%d])", err, apiName, rcIndex)
   286  		}
   287  		versions := specVersions.Versions()
   288  		for _, version := range versions {
   289  			spec, err := specVersions.At(version)
   290  			if err == vervet.ErrNoMatchingVersion {
   291  				continue
   292  			} else if err != nil {
   293  				return buildErr(err)
   294  			}
   295  
   296  			// Create the directories, but only if a spec file exists for it.
   297  			versionDir := api.output.paths[0] + "/" + version.String()
   298  
   299  			if spec != nil {
   300  				err = os.MkdirAll(versionDir, 0755)
   301  				if err != nil {
   302  					return buildErr(err)
   303  				}
   304  			}
   305  
   306  			// Merge all overlays
   307  			for _, doc := range api.overlayIncludes {
   308  				vervet.Merge(spec, doc.T, true)
   309  			}
   310  			for _, doc := range api.overlayInlines {
   311  				vervet.Merge(spec, doc, true)
   312  			}
   313  
   314  			// Write the compiled spec to JSON and YAML
   315  			jsonBuf, err := vervet.ToSpecJSON(spec)
   316  			if err != nil {
   317  				return buildErr(err)
   318  			}
   319  			jsonSpecPath := versionDir + "/spec.json"
   320  			jsonEmbedPath, err := filepath.Rel(api.output.paths[0], jsonSpecPath)
   321  			if err != nil {
   322  				return buildErr(err)
   323  			}
   324  			versionSpecFiles = append(versionSpecFiles, jsonEmbedPath)
   325  			err = ioutil.WriteFile(jsonSpecPath, jsonBuf, 0644)
   326  			if err != nil {
   327  				return buildErr(err)
   328  			}
   329  			log.Println(jsonSpecPath)
   330  			yamlBuf, err := yaml.JSONToYAML(jsonBuf)
   331  			if err != nil {
   332  				return buildErr(err)
   333  			}
   334  			yamlBuf, err = vervet.WithGeneratedComment(yamlBuf)
   335  			if err != nil {
   336  				return buildErr(err)
   337  			}
   338  			yamlSpecPath := versionDir + "/spec.yaml"
   339  			yamlEmbedPath, err := filepath.Rel(api.output.paths[0], yamlSpecPath)
   340  			if err != nil {
   341  				return buildErr(err)
   342  			}
   343  			versionSpecFiles = append(versionSpecFiles, yamlEmbedPath)
   344  			err = ioutil.WriteFile(yamlSpecPath, yamlBuf, 0644)
   345  			if err != nil {
   346  				return buildErr(err)
   347  			}
   348  			log.Println(yamlSpecPath)
   349  		}
   350  	}
   351  	err = c.writeEmbedGo(filepath.Base(api.output.paths[0]), api, versionSpecFiles)
   352  	if err != nil {
   353  		return fmt.Errorf("failed to create embed.go: %w", err)
   354  	}
   355  	// Copy output to multiple paths if specified
   356  	src := api.output.paths[0]
   357  	for _, dst := range api.output.paths[1:] {
   358  		if err := files.CopyDir(dst, src, true); err != nil {
   359  			return fmt.Errorf("failed to copy %q to %q: %w", src, dst, err)
   360  		}
   361  	}
   362  	return nil
   363  }
   364  
   365  func (c *Compiler) writeEmbedGo(pkgName string, a *api, versionSpecFiles []string) error {
   366  	embedPath := filepath.Join(a.output.paths[0], "embed.go")
   367  	f, err := os.Create(embedPath)
   368  	if err != nil {
   369  		return err
   370  	}
   371  	defer f.Close()
   372  	err = embedGoTmpl.Execute(f, struct {
   373  		Package          string
   374  		API              *api
   375  		VersionSpecFiles []string
   376  	}{
   377  		Package:          pkgName,
   378  		API:              a,
   379  		VersionSpecFiles: versionSpecFiles,
   380  	})
   381  	if err != nil {
   382  		return err
   383  	}
   384  	for _, dst := range a.output.paths[1:] {
   385  		if err := files.CopyFile(filepath.Join(dst, "embed.go"), embedPath, true); err != nil {
   386  			return err
   387  		}
   388  	}
   389  	return nil
   390  }
   391  
// embedGoTmpl is the template for the generated embed.go file. It emits a
// package clause, one go:embed directive per entry of .VersionSpecFiles, and
// an embed.FS variable named Versions. The [1:] slice drops the newline that
// immediately follows the opening backquote, so the output starts with the
// package clause.
//
// NOTE(review): this file imports html/template, which applies HTML
// escaping to interpolated values; text/template is presumably the intended
// package for generating Go source — confirm before relying on paths that
// contain characters html/template would escape.
var embedGoTmpl = template.Must(template.New("embed.go").Parse(`
package {{ .Package }}

import "embed"

// Embed compiled OpenAPI specs in Go projects.

{{ range .VersionSpecFiles -}}
//go:embed {{ . }}
{{ end -}}
// Versions contains OpenAPI specs for each distinct release version.
var Versions embed.FS
`[1:]))
   405  
// BuildAll builds all APIs in the project. Every API is built even if an
// earlier one fails; the failures are combined into the returned error.
func (c *Compiler) BuildAll(ctx context.Context) error {
	return c.apisEach(ctx, c.Build)
}
   410  
   411  // LintOutput applies configured linting rules to the build output.
   412  func (c *Compiler) LintOutput(ctx context.Context, apiName string) error {
   413  	api, ok := c.apis[apiName]
   414  	if !ok {
   415  		return fmt.Errorf("api not found (apis.%s)", apiName)
   416  	}
   417  	if api.output != nil && len(api.output.paths) > 0 && api.output.linter != nil {
   418  		var outputFiles []string
   419  		err := doublestar.GlobWalk(os.DirFS(api.output.paths[0]), "**/spec.{json,yaml}",
   420  			func(path string, d fs.DirEntry) error {
   421  				outputFiles = append(outputFiles, filepath.Join(api.output.paths[0], path))
   422  				return nil
   423  			})
   424  		if err != nil {
   425  			return fmt.Errorf("failed to match output files for linting: %w (apis.%s.output)",
   426  				err, apiName)
   427  		}
   428  		if len(outputFiles) == 0 {
   429  			return fmt.Errorf("lint failed: no output files were produced")
   430  		}
   431  		err = api.output.linter.Run(ctx, api.output.paths[0], outputFiles...)
   432  		if err != nil {
   433  			return fmt.Errorf("lint failed (apis.%s.output)", apiName)
   434  		}
   435  	}
   436  	return nil
   437  }
   438  
// LintOutputAll lints output of all APIs in the project. Every API is linted
// even if an earlier one fails; the failures are combined into the returned
// error.
func (c *Compiler) LintOutputAll(ctx context.Context) error {
	return c.apisEach(ctx, c.LintOutput)
}