github.com/w3security/vervet/v5@v5.3.1-0.20230618081846-5bd9b5d799dc/internal/compiler/compiler.go (about)

     1  package compiler
     2  
import (
	"context"
	"errors"
	"fmt"
	"html/template"
	"io/fs"
	"log"
	"os"
	"path/filepath"
	"sort"

	"github.com/bmatcuk/doublestar/v4"
	"github.com/getkin/kin-openapi/openapi3"
	"github.com/ghodss/yaml"
	"go.uber.org/multierr"

	"github.com/w3security/vervet/v5"
	"github.com/w3security/vervet/v5/config"
	"github.com/w3security/vervet/v5/internal/files"
	"github.com/w3security/vervet/v5/internal/linter"
	"github.com/w3security/vervet/v5/internal/linter/optic"
	"github.com/w3security/vervet/v5/internal/linter/spectral"
)
    24  
// A Compiler checks and builds versioned API resource inputs into aggregated
// OpenAPI versioned outputs, as determined by an API project configuration.
//
// Construct one with New; the zero value has nil maps and no linter factory.
type Compiler struct {
	// apis holds per-API build state, keyed by API name from the project config.
	apis    map[string]*api
	// linters holds instantiated linters keyed by linter name from the project
	// config. New only populates it when called with lint=true.
	linters map[string]linter.Linter

	// newLinter constructs a Linter from a linter configuration. New sets it to
	// defaultLinterFactory; the LinterFactory option replaces it.
	newLinter func(ctx context.Context, lc *config.Linter) (linter.Linter, error)
}
    33  
// CompilerOption applies a configuration option to a Compiler.
//
// New applies options in the order given, before any linters are
// constructed; an option returning a non-nil error makes New fail with that
// error.
type CompilerOption func(*Compiler) error
    36  
    37  // LinterFactory configures a Compiler to use a custom factory function for
    38  // instantiating Linters.
    39  func LinterFactory(f func(ctx context.Context, lc *config.Linter) (linter.Linter, error)) CompilerOption {
    40  	return func(c *Compiler) error {
    41  		c.newLinter = f
    42  		return nil
    43  	}
    44  }
    45  
    46  func defaultLinterFactory(ctx context.Context, lc *config.Linter) (linter.Linter, error) {
    47  	if lc.Spectral != nil {
    48  		return spectral.New(ctx, lc.Spectral)
    49  	} else if lc.SweaterComb != nil {
    50  		return optic.New(ctx, lc.SweaterComb)
    51  	} else if lc.OpticCI != nil {
    52  		return optic.New(ctx, lc.OpticCI)
    53  	}
    54  	return nil, fmt.Errorf("invalid linter (linters.%s)", lc.Name)
    55  }
    56  
// api holds the resolved build inputs and output settings for a single API
// in the project, as assembled by New.
type api struct {
	// resources are the API's resource sets, in configuration order.
	resources       []*resourceSet
	// overlayIncludes are overlay documents loaded from files, with their
	// references localized.
	overlayIncludes []*vervet.Document
	// overlayInlines are overlay documents parsed from inline configuration
	// text after environment-variable expansion.
	overlayInlines  []*openapi3.T
	// output is nil when the API has no output path configured.
	output          *output
}
    63  
// resourceSet is the build and lint state for one configured resource set of
// an API.
type resourceSet struct {
	// path is the resource set's configured path.
	path            string
	// linter is nil when linting is disabled or the configured linter name is
	// not a known linter.
	linter          linter.Linter
	// linterOverrides maps resource name -> version name -> linter config
	// override. The keys are matched against the parent and grandparent
	// directory names of each lint file.
	linterOverrides map[string]map[string]config.Linter
	// sourceFiles are the spec files matched by ResourceSpecFiles.
	sourceFiles     []string
	// lintFiles are the files matched by the linter; only set when linting.
	lintFiles       []string
}
    71  
// output describes where an API's compiled specs are written.
type output struct {
	// paths are the output directories. Specs are compiled into paths[0] and
	// the result is copied to every remaining path.
	paths  []string
	// linter lints the compiled output. It is nil when linting is disabled or
	// the configured output linter name is not a known linter.
	linter linter.Linter
}
    76  
    77  // New returns a new Compiler for a given project configuration.
    78  func New(ctx context.Context, proj *config.Project, lint bool, options ...CompilerOption) (*Compiler, error) {
    79  	compiler := &Compiler{
    80  		apis:      map[string]*api{},
    81  		linters:   map[string]linter.Linter{},
    82  		newLinter: defaultLinterFactory,
    83  	}
    84  
    85  	for i := range options {
    86  		err := options[i](compiler)
    87  		if err != nil {
    88  			return nil, err
    89  		}
    90  	}
    91  
    92  	if lint {
    93  		// set up linters
    94  		for linterName, linterConfig := range proj.Linters {
    95  			linter, err := compiler.newLinter(ctx, linterConfig)
    96  			if err != nil {
    97  				return nil, fmt.Errorf("%w (linters.%s)", err, linterName)
    98  			}
    99  			compiler.linters[linterName] = linter
   100  		}
   101  	}
   102  
   103  	// set up APIs
   104  	for apiName, apiConfig := range proj.APIs {
   105  		a := api{}
   106  
   107  		// Build resources
   108  		for rcIndex, rcConfig := range apiConfig.Resources {
   109  			var err error
   110  			r := &resourceSet{
   111  				path:            rcConfig.Path,
   112  				linter:          compiler.linters[rcConfig.Linter],
   113  				linterOverrides: map[string]map[string]config.Linter{},
   114  			}
   115  			if lint && r.linter != nil {
   116  				r.lintFiles, err = r.linter.Match(rcConfig)
   117  				if err != nil {
   118  					return nil, fmt.Errorf("%w: (apis.%s.resources[%d].path)", err, apiName, rcIndex)
   119  				}
   120  				// TODO: overrides are deprecated by Optic CI, remove soon
   121  				linterOverrides := map[string]map[string]config.Linter{}
   122  				for rcName, versionMap := range rcConfig.LinterOverrides {
   123  					linterOverrides[rcName] = map[string]config.Linter{}
   124  					for version, linter := range versionMap {
   125  						linterOverrides[rcName][version] = *linter
   126  					}
   127  				}
   128  				r.linterOverrides = linterOverrides
   129  			}
   130  			r.sourceFiles, err = ResourceSpecFiles(rcConfig)
   131  			if err != nil {
   132  				return nil, fmt.Errorf("%w: (apis.%s.resources[%d].path)", err, apiName, rcIndex)
   133  			}
   134  			a.resources = append(a.resources, r)
   135  		}
   136  
   137  		// Build overlays
   138  		for overlayIndex, overlayConfig := range apiConfig.Overlays {
   139  			if overlayConfig.Include != "" {
   140  				doc, err := vervet.NewDocumentFile(overlayConfig.Include)
   141  				if err != nil {
   142  					return nil, fmt.Errorf("failed to load overlay %q: %w (apis.%s.overlays[%d])",
   143  						overlayConfig.Include, err, apiName, overlayIndex)
   144  				}
   145  				err = vervet.Localize(ctx, doc)
   146  				if err != nil {
   147  					return nil, fmt.Errorf("failed to localize references in %q: %w (apis.%s.overlays[%d]",
   148  						overlayConfig.Include, err, apiName, overlayIndex)
   149  				}
   150  				a.overlayIncludes = append(a.overlayIncludes, doc)
   151  			} else if overlayConfig.Inline != "" {
   152  				docString := os.ExpandEnv(overlayConfig.Inline)
   153  				l := openapi3.NewLoader()
   154  				doc, err := l.LoadFromData([]byte(docString))
   155  				if err != nil {
   156  					return nil, fmt.Errorf("failed to load template: %w (apis.%s.overlays[%d].template)",
   157  						err, apiName, overlayIndex)
   158  				}
   159  				a.overlayInlines = append(a.overlayInlines, doc)
   160  			}
   161  		}
   162  
   163  		// Build output
   164  		if apiConfig.Output != nil {
   165  			paths := apiConfig.Output.Paths
   166  			if len(paths) == 0 && apiConfig.Output.Path != "" {
   167  				paths = []string{apiConfig.Output.Path}
   168  			}
   169  			if len(paths) > 0 {
   170  				a.output = &output{
   171  					paths:  paths,
   172  					linter: compiler.linters[apiConfig.Output.Linter],
   173  				}
   174  			}
   175  		}
   176  
   177  		compiler.apis[apiName] = &a
   178  	}
   179  	return compiler, nil
   180  }
   181  
   182  // ResourceSpecFiles returns all matching spec files for a config.Resource.
   183  func ResourceSpecFiles(rcConfig *config.ResourceSet) ([]string, error) {
   184  	return files.LocalFSSource{}.Match(rcConfig)
   185  }
   186  
   187  // LintResources checks the inputs of an API's resources with the configured linter.
   188  func (c *Compiler) LintResources(ctx context.Context, apiName string) error {
   189  	api, ok := c.apis[apiName]
   190  	if !ok {
   191  		return fmt.Errorf("api not found (apis.%s)", apiName)
   192  	}
   193  	var errs error
   194  	for rcIndex, rc := range api.resources {
   195  		if rc.linter == nil {
   196  			continue
   197  		}
   198  		if len(rc.linterOverrides) > 0 {
   199  			err := c.lintWithOverrides(ctx, rc, apiName, rcIndex)
   200  			if err != nil {
   201  				errs = multierr.Append(errs, fmt.Errorf("%w (apis.%s.resources[%d])", err, apiName, rcIndex))
   202  			}
   203  		} else {
   204  			err := rc.linter.Run(ctx, rc.path, rc.lintFiles...)
   205  			if err != nil {
   206  				errs = multierr.Append(errs, fmt.Errorf("%w (apis.%s.resources[%d])", err, apiName, rcIndex))
   207  			}
   208  		}
   209  	}
   210  	return errs
   211  }
   212  
   213  func (c *Compiler) lintWithOverrides(ctx context.Context, rc *resourceSet, apiName string, rcIndex int) error {
   214  	var pending []string
   215  	for _, matchedFile := range rc.lintFiles {
   216  		versionDir := filepath.Dir(matchedFile)
   217  		rcDir := filepath.Dir(versionDir)
   218  		versionName := filepath.Base(versionDir)
   219  		rcName := filepath.Base(rcDir)
   220  		if linter, ok := rc.linterOverrides[rcName][versionName]; ok {
   221  			linter, err := rc.linter.WithOverride(ctx, &linter)
   222  			if err != nil {
   223  				return fmt.Errorf("failed to apply overrides to linter: %w (apis.%s.resources[%d].linter-overrides.%s.%s)",
   224  					err, apiName, rcIndex, rcName, versionName)
   225  			}
   226  			err = linter.Run(ctx, matchedFile)
   227  			if err != nil {
   228  				return fmt.Errorf("lint failed on %q: %w (apis.%s.resources[%d])", matchedFile, err, apiName, rcIndex)
   229  			}
   230  		} else {
   231  			pending = append(pending, matchedFile)
   232  		}
   233  	}
   234  	if len(pending) == 0 {
   235  		return nil
   236  	}
   237  	err := rc.linter.Run(ctx, rc.path, pending...)
   238  	if err != nil {
   239  		return fmt.Errorf("lint failed (apis.%s.resources[%d])", apiName, rcIndex)
   240  	}
   241  	return nil
   242  }
   243  
   244  func (c *Compiler) apisEach(ctx context.Context, f func(ctx context.Context, apiName string) error) error {
   245  	var errs error
   246  	for apiName := range c.apis {
   247  		err := f(ctx, apiName)
   248  		if err != nil {
   249  			errs = multierr.Append(errs, err)
   250  		}
   251  	}
   252  	return errs
   253  }
   254  
   255  // Build builds an aggregate versioned OpenAPI spec for a specific API by name
   256  // in the project.
   257  func (c *Compiler) Build(ctx context.Context, apiName string) error {
   258  	api, ok := c.apis[apiName]
   259  	if !ok {
   260  		return fmt.Errorf("api not found (apis.%s)", apiName)
   261  	}
   262  	if api.output == nil || len(api.output.paths) == 0 {
   263  		return nil
   264  	}
   265  	for _, path := range api.output.paths {
   266  		err := os.RemoveAll(path)
   267  		if err != nil {
   268  			return fmt.Errorf("failed to clear output directory: %w", err)
   269  		}
   270  	}
   271  	err := os.MkdirAll(api.output.paths[0], 0777)
   272  	if err != nil {
   273  		return fmt.Errorf("failed to create output directory: %w", err)
   274  	}
   275  	log.Printf("compiling API %s to output versions", apiName)
   276  	var versionSpecFiles []string
   277  	for rcIndex, rc := range api.resources {
   278  		specVersions, err := vervet.LoadSpecVersionsFileset(rc.sourceFiles) //nolint:contextcheck //acked
   279  		if err != nil {
   280  			return fmt.Errorf("failed to load spec versions: %+v (apis.%s.resources[%d])",
   281  				err, apiName, rcIndex)
   282  		}
   283  		buildErr := func(err error) error {
   284  			return fmt.Errorf("%w (apis.%s.resources[%d])", err, apiName, rcIndex)
   285  		}
   286  		versions := specVersions.Versions()
   287  		for _, version := range versions {
   288  			spec, err := specVersions.At(version)
   289  			if err == vervet.ErrNoMatchingVersion {
   290  				continue
   291  			} else if err != nil {
   292  				return buildErr(err)
   293  			}
   294  
   295  			// Create the directories, but only if a spec file exists for it.
   296  			versionDir := api.output.paths[0] + "/" + version.String()
   297  
   298  			if spec != nil {
   299  				err = os.MkdirAll(versionDir, 0755)
   300  				if err != nil {
   301  					return buildErr(err)
   302  				}
   303  			}
   304  
   305  			// Merge all overlays
   306  			for _, doc := range api.overlayIncludes {
   307  				vervet.Merge(spec, doc.T, true)
   308  			}
   309  			for _, doc := range api.overlayInlines {
   310  				vervet.Merge(spec, doc, true)
   311  			}
   312  
   313  			// Write the compiled spec to JSON and YAML
   314  			jsonBuf, err := vervet.ToSpecJSON(spec)
   315  			if err != nil {
   316  				return buildErr(err)
   317  			}
   318  			jsonSpecPath := versionDir + "/spec.json"
   319  			jsonEmbedPath, err := filepath.Rel(api.output.paths[0], jsonSpecPath)
   320  			if err != nil {
   321  				return buildErr(err)
   322  			}
   323  			versionSpecFiles = append(versionSpecFiles, jsonEmbedPath)
   324  			err = os.WriteFile(jsonSpecPath, jsonBuf, 0644)
   325  			if err != nil {
   326  				return buildErr(err)
   327  			}
   328  			log.Println(jsonSpecPath)
   329  			yamlBuf, err := yaml.JSONToYAML(jsonBuf)
   330  			if err != nil {
   331  				return buildErr(err)
   332  			}
   333  			yamlBuf, err = vervet.WithGeneratedComment(yamlBuf)
   334  			if err != nil {
   335  				return buildErr(err)
   336  			}
   337  			yamlSpecPath := versionDir + "/spec.yaml"
   338  			yamlEmbedPath, err := filepath.Rel(api.output.paths[0], yamlSpecPath)
   339  			if err != nil {
   340  				return buildErr(err)
   341  			}
   342  			versionSpecFiles = append(versionSpecFiles, yamlEmbedPath)
   343  			err = os.WriteFile(yamlSpecPath, yamlBuf, 0644)
   344  			if err != nil {
   345  				return buildErr(err)
   346  			}
   347  			log.Println(yamlSpecPath)
   348  		}
   349  	}
   350  	err = c.writeEmbedGo(filepath.Base(api.output.paths[0]), api, versionSpecFiles)
   351  	if err != nil {
   352  		return fmt.Errorf("failed to create embed.go: %w", err)
   353  	}
   354  	// Copy output to multiple paths if specified
   355  	src := api.output.paths[0]
   356  	for _, dst := range api.output.paths[1:] {
   357  		if err := files.CopyDir(dst, src, true); err != nil {
   358  			return fmt.Errorf("failed to copy %q to %q: %w", src, dst, err)
   359  		}
   360  	}
   361  	return nil
   362  }
   363  
   364  func (c *Compiler) writeEmbedGo(pkgName string, a *api, versionSpecFiles []string) error {
   365  	embedPath := filepath.Join(a.output.paths[0], "embed.go")
   366  	f, err := os.Create(embedPath)
   367  	if err != nil {
   368  		return err
   369  	}
   370  	defer f.Close()
   371  	err = embedGoTmpl.Execute(f, struct {
   372  		Package          string
   373  		API              *api
   374  		VersionSpecFiles []string
   375  	}{
   376  		Package:          pkgName,
   377  		API:              a,
   378  		VersionSpecFiles: versionSpecFiles,
   379  	})
   380  	if err != nil {
   381  		return err
   382  	}
   383  	for _, dst := range a.output.paths[1:] {
   384  		if err := files.CopyFile(filepath.Join(dst, "embed.go"), embedPath, true); err != nil {
   385  			return err
   386  		}
   387  	}
   388  	return nil
   389  }
   390  
// embedGoTmpl renders the embed.go file written next to compiled specs. Its
// data provides .Package (the generated package name), .API, and
// .VersionSpecFiles, which produces one embed directive line per entry.
// The trailing [1:] drops the raw string's leading newline so the output
// starts with the generated-code header.
//
// NOTE(review): this parses with html/template (per the file's imports),
// which HTML-escapes interpolated values (e.g. &, <, '). Go source is
// normally rendered with text/template; confirm package names and spec paths
// can never contain such characters, or switch packages.
var embedGoTmpl = template.Must(template.New("embed.go").Parse(`
// Code generated by Vervet. DO NOT EDIT.

package {{ .Package }}

import "embed"

// Embed compiled OpenAPI specs in Go projects.

{{ range .VersionSpecFiles -}}
//go:embed {{ . }}
{{ end }}
// Versions contains OpenAPI specs for each distinct release version.
var Versions embed.FS
`[1:]))
   406  
   407  // BuildAll builds all APIs in the project.
   408  func (c *Compiler) BuildAll(ctx context.Context) error {
   409  	if err := c.apisEach(ctx, c.LintResources); err != nil {
   410  		return err
   411  	}
   412  	if err := c.apisEach(ctx, c.Build); err != nil {
   413  		return err
   414  	}
   415  	return c.apisEach(ctx, c.LintOutput)
   416  }
   417  
   418  // LintOutput applies configured linting rules to the build output.
   419  func (c *Compiler) LintOutput(ctx context.Context, apiName string) error {
   420  	api, ok := c.apis[apiName]
   421  	if !ok {
   422  		return fmt.Errorf("api not found (apis.%s)", apiName)
   423  	}
   424  	if api.output != nil && len(api.output.paths) > 0 && api.output.linter != nil {
   425  		var outputFiles []string
   426  		err := doublestar.GlobWalk(os.DirFS(api.output.paths[0]), "**/spec.{json,yaml}",
   427  			func(path string, d fs.DirEntry) error {
   428  				outputFiles = append(outputFiles, filepath.Join(api.output.paths[0], path))
   429  				return nil
   430  			})
   431  		if err != nil {
   432  			return fmt.Errorf("failed to match output files for linting: %w (apis.%s.output)",
   433  				err, apiName)
   434  		}
   435  		if len(outputFiles) == 0 {
   436  			return fmt.Errorf("lint failed: no output files were produced")
   437  		}
   438  		err = api.output.linter.Run(ctx, api.output.paths[0], outputFiles...)
   439  		if err != nil {
   440  			return fmt.Errorf("lint failed (apis.%s.output)", apiName)
   441  		}
   442  	}
   443  	return nil
   444  }