github.com/chainguard-dev/yam@v0.0.7/pkg/yam/formatted/encoder.go (about)

     1  package formatted
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"fmt"
     7  	"io"
     8  	"os"
     9  	"sort"
    10  	"strings"
    11  
    12  	"github.com/chainguard-dev/yam/pkg/util"
    13  	"github.com/chainguard-dev/yam/pkg/yam/formatted/path"
    14  	"gopkg.in/yaml.v3"
    15  )
    16  
    17  var (
    18  	newline   = []byte("\n")
    19  	colon     = []byte(":")
    20  	space     = []byte(" ")
    21  	dashSpace = []byte("- ")
    22  )
    23  
    24  const defaultIndentSize = 2
    25  
// EncodeOptions describes the set of configuration options used to adjust the
// behavior of yam's YAML encoder. ReadConfig and ReadConfigFrom decode a yam
// config file into this type, and UseOptions applies it to an Encoder.
type EncodeOptions struct {
	// Indent specifies how many spaces to use per indentation level.
	// NOTE(review): a config that omits `indent` leaves this at 0. Passing 0
	// straight to Encoder.SetIndent removes all indentation from nested
	// output, so callers should treat 0 as "not set".
	Indent int `yaml:"indent"`

	// GapExpressions specifies a list of yq-style paths whose YAML element's
	// children should be separated from each other by an empty line.
	GapExpressions []string `yaml:"gap"`

	// SortExpressions specifies a list of yq-style paths whose YAML element's
	// children should be sorted. Only sequences are sorted, ordered by each
	// item's scalar value.
	SortExpressions []string `yaml:"sort"`
}
    40  
// Encoder is an implementation of a YAML encoder that applies a configurable
// formatting to the YAML data as it's written out to the encoder's io.Writer.
//
// Encoder is used as a value. Its configuration methods (SetIndent,
// SetGapExpressions, SetSortExpressions, UseOptions, AutomaticConfig) return
// an updated copy, so callers must keep the returned Encoder. The
// *yaml.Encoder it holds is shared between copies.
type Encoder struct {
	// w receives the formatted output in Encode.
	w          io.Writer
	// indentSize is the number of spaces per indentation level.
	indentSize int
	// yamlEnc is created by NewEncoder and re-indented by SetIndent.
	// NOTE(review): nothing in this file encodes through it; output comes
	// from yaml.Marshal in marshal. Confirm whether any other file uses it.
	yamlEnc    *yaml.Encoder
	// gapPaths selects the nodes whose children are separated by empty lines.
	gapPaths   []path.Path
	// sortPaths selects the sequences whose items are sorted by value.
	sortPaths  []path.Path
}
    50  
    51  // NewEncoder returns a new encoder that can write formatted YAML to the given
    52  // io.Writer.
    53  func NewEncoder(w io.Writer) Encoder {
    54  	yamlEnc := yaml.NewEncoder(w)
    55  	yamlEnc.SetIndent(defaultIndentSize)
    56  
    57  	enc := Encoder{
    58  		w:          w,
    59  		yamlEnc:    yamlEnc,
    60  		indentSize: defaultIndentSize,
    61  	}
    62  
    63  	return enc
    64  }
    65  
    66  // AutomaticConfig configures the encoder using a `.yam.yaml` config file in the
    67  // current working directory, if one exists. This method is meant to work on a
    68  // "best effort" basis, and all errors are silently ignored.
    69  func (enc Encoder) AutomaticConfig() Encoder {
    70  	options, err := ReadConfig()
    71  	if err != nil {
    72  		// Didn't find a config to apply, but that's okay.
    73  		return enc
    74  	}
    75  
    76  	enc = enc.SetIndent(options.Indent)
    77  	enc, _ = enc.SetGapExpressions(options.GapExpressions...)
    78  
    79  	return enc
    80  }
    81  
    82  // ReadConfig tries to load a yam encoder config from a `.yam.yaml` file in the
    83  // current working directory. It returns an error if it wasn't able to open or
    84  // unmarshal the file.
    85  func ReadConfig() (*EncodeOptions, error) {
    86  	f, err := os.Open(util.ConfigFileName)
    87  	if err != nil {
    88  		return nil, fmt.Errorf("unable to open yam config: %w", err)
    89  	}
    90  
    91  	config, err := ReadConfigFrom(f)
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  
    96  	return config, nil
    97  }
    98  
    99  // ReadConfigFrom loads a yam encoder config from the given io.Reader. It
   100  // returns an error if it wasn't able to unmarshal the data.
   101  func ReadConfigFrom(r io.Reader) (*EncodeOptions, error) {
   102  	options := EncodeOptions{}
   103  
   104  	err := yaml.NewDecoder(r).Decode(&options)
   105  	if err != nil {
   106  		return nil, fmt.Errorf("unable to read yam config: %w", err)
   107  	}
   108  
   109  	return &options, nil
   110  }
   111  
   112  // SetIndent configures the encoder to use the provided number of spaces for
   113  // each indentation.
   114  func (enc Encoder) SetIndent(spaces int) Encoder {
   115  	enc.indentSize = spaces
   116  	enc.yamlEnc.SetIndent(spaces)
   117  	return enc
   118  }
   119  
   120  // SetGapExpressions takes 0 or more YAML path expressions (e.g. "." or
   121  // ".something.foo") and configures the encoder to insert empty lines ("gaps")
   122  // in between the children elements of the YAML nodes referenced by the path
   123  // expressions.
   124  func (enc Encoder) SetGapExpressions(expressions ...string) (Encoder, error) {
   125  	for _, expr := range expressions {
   126  		p, err := path.Parse(expr)
   127  		if err != nil {
   128  			return Encoder{}, fmt.Errorf("unable to parse expression %q: %w", expr, err)
   129  		}
   130  
   131  		enc.gapPaths = append(enc.gapPaths, p)
   132  	}
   133  
   134  	return enc, nil
   135  }
   136  
   137  // SetSortExpressions takes 0 or more YAML path expressions (e.g. "." or
   138  // ".something.foo") and configures the encoder to sort the arrays.
   139  func (enc Encoder) SetSortExpressions(expressions ...string) (Encoder, error) {
   140  	for _, expr := range expressions {
   141  		p, err := path.Parse(expr)
   142  		if err != nil {
   143  			return Encoder{}, fmt.Errorf("unable to parse expression %q: %w", expr, err)
   144  		}
   145  
   146  		enc.sortPaths = append(enc.sortPaths, p)
   147  	}
   148  
   149  	return enc, nil
   150  }
   151  
   152  // UseOptions configures the encoder to use the configuration from the given
   153  // EncodeOptions.
   154  func (enc Encoder) UseOptions(options EncodeOptions) (Encoder, error) {
   155  	enc = enc.SetIndent(options.Indent)
   156  	enc, err := enc.SetGapExpressions(options.GapExpressions...)
   157  	if err != nil {
   158  		return Encoder{}, err
   159  	}
   160  	enc, err = enc.SetSortExpressions(options.SortExpressions...)
   161  	if err != nil {
   162  		return Encoder{}, err
   163  	}
   164  
   165  	return enc, nil
   166  }
   167  
   168  // Encode writes out the formatted YAML from the given yaml.Node to the
   169  // encoder's io.Writer.
   170  func (enc Encoder) Encode(node *yaml.Node) error {
   171  	b, err := enc.marshalRoot(node)
   172  	if err != nil {
   173  		return err
   174  	}
   175  
   176  	_, err = enc.w.Write(b)
   177  	if err != nil {
   178  		return err
   179  	}
   180  
   181  	return nil
   182  }
   183  
   184  func (enc Encoder) marshalRoot(node *yaml.Node) ([]byte, error) {
   185  	return enc.marshal(node, path.Root())
   186  }
   187  
   188  func (enc Encoder) marshal(node *yaml.Node, nodePath path.Path) ([]byte, error) {
   189  	switch node.Kind {
   190  	case yaml.DocumentNode:
   191  		var bytes []byte
   192  		for _, inner := range node.Content {
   193  			innerBytes, err := enc.marshal(inner, nodePath)
   194  			if err != nil {
   195  				return nil, err
   196  			}
   197  			bytes = append(bytes, innerBytes...)
   198  		}
   199  		return bytes, nil
   200  
   201  	case yaml.MappingNode:
   202  		return enc.marshalMapping(node, nodePath)
   203  
   204  	case yaml.SequenceNode:
   205  		return enc.marshalSequence(node, nodePath)
   206  
   207  	case yaml.ScalarNode:
   208  		if node.Tag == "!!null" {
   209  			return nil, nil
   210  		}
   211  		return yaml.Marshal(node)
   212  
   213  	default:
   214  		return yaml.Marshal(node)
   215  
   216  	}
   217  }
   218  
   219  func (enc Encoder) marshalMapping(node *yaml.Node, nodePath path.Path) ([]byte, error) {
   220  	// Note: A mapping node's content items are laid out as key-value pairs!
   221  
   222  	var result []byte
   223  	var latestKey string
   224  	for i, item := range node.Content {
   225  		if isMapKeyIndex(i) {
   226  			rawKeyBytes, err := enc.marshal(item, nodePath)
   227  			if err != nil {
   228  				return nil, err
   229  			}
   230  
   231  			// assume the key can be a string (this isn't always true in YAML, but we'll see how far this gets us)
   232  			key := bytes.TrimSuffix(rawKeyBytes, newline)
   233  			latestKey = string(key)
   234  
   235  			keyBytes := bytes.Join([][]byte{
   236  				key,
   237  				colon,
   238  			}, nil)
   239  
   240  			if nextItem := node.Content[i+1]; nextItem.Kind == yaml.ScalarNode && nextItem.Tag != "!!null" { // TODO: check that there is a value node for this key node
   241  				// render in same line
   242  				keyBytes = append(keyBytes, space...)
   243  			} else {
   244  				keyBytes = append(keyBytes, newline...)
   245  			}
   246  
   247  			result = append(result, keyBytes...)
   248  			continue
   249  		}
   250  
   251  		nodePathForValue := nodePath.AppendMapPart(latestKey)
   252  
   253  		valueBytes, err := enc.marshal(item, nodePathForValue)
   254  		if err != nil {
   255  			return nil, err
   256  		}
   257  
   258  		isFinalMapValue := i == len(node.Content)-1
   259  
   260  		// This was the key's value node, so add a gap if configured to do so.
   261  		// We shouldn't add a newline after the final map value, though.
   262  		if enc.matchesAnyGapPath(nodePath) && !isFinalMapValue {
   263  			valueBytes = append(valueBytes, newline...)
   264  		}
   265  
   266  		if item.Kind == yaml.MappingNode || item.Kind == yaml.SequenceNode {
   267  			valueBytes = enc.applyIndent(valueBytes)
   268  		} else {
   269  			valueBytes = enc.handleMultilineStringIndentation(valueBytes)
   270  		}
   271  
   272  		result = append(result, valueBytes...)
   273  	}
   274  
   275  	return result, nil
   276  }
   277  
   278  func isMapKeyIndex(i int) bool {
   279  	return i%2 == 0
   280  }
   281  
   282  func (enc Encoder) marshalSequence(node *yaml.Node, nodePath path.Path) ([]byte, error) {
   283  	var lines [][]byte
   284  
   285  	// Sort the sequence if configured to do so before marshalling.
   286  	if node.Kind == yaml.SequenceNode && enc.matchesAnySortPath(nodePath) {
   287  		sort.Slice(node.Content, func(i int, j int) bool {
   288  			return node.Content[i].Value < node.Content[j].Value
   289  		})
   290  	}
   291  
   292  	for i, item := range node.Content {
   293  		// For scalar items, pull out the head comment, so we can control its encoding
   294  		// here, rather than delegate it to the underlying encoder.
   295  		var extractedHeadComment string
   296  		if item.HeadComment != "" {
   297  			extractedHeadComment = item.HeadComment + "\n"
   298  			item.HeadComment = ""
   299  		}
   300  
   301  		itemBytes, err := enc.marshal(item, nodePath.AppendSeqPart(i))
   302  		if err != nil {
   303  			return nil, err
   304  		}
   305  
   306  		if item.Kind == yaml.ScalarNode {
   307  			// Print head comment first. Then continue.
   308  			itemBytes = bytes.Join([][]byte{
   309  				[]byte(extractedHeadComment),
   310  				dashSpace,
   311  				itemBytes,
   312  			}, nil)
   313  		} else {
   314  			itemBytes = enc.applyIndentExceptFirstLine(itemBytes)
   315  
   316  			// Precede with a dash.
   317  			itemBytes = bytes.Join([][]byte{
   318  				[]byte(extractedHeadComment),
   319  				dashSpace,
   320  				itemBytes,
   321  			}, nil)
   322  		}
   323  
   324  		lines = append(lines, itemBytes)
   325  	}
   326  
   327  	var sep []byte
   328  	if enc.matchesAnyGapPath(nodePath) {
   329  		sep = newline
   330  	}
   331  
   332  	return bytes.Join(lines, sep), nil
   333  }
   334  
   335  func (enc Encoder) applyIndent(content []byte) []byte {
   336  	var processedLines []string
   337  
   338  	scanner := bufio.NewScanner(bytes.NewReader(content))
   339  	for scanner.Scan() {
   340  		line := scanner.Text()
   341  
   342  		// We don't indent empty lines.
   343  		if line != "" {
   344  			line = enc.indentString() + line
   345  		}
   346  		processedLines = append(processedLines, line)
   347  	}
   348  
   349  	result := []byte(strings.Join(processedLines, "\n") + "\n")
   350  
   351  	return result
   352  }
   353  
   354  func (enc Encoder) applyIndentExceptFirstLine(content []byte) []byte {
   355  	var processedLines []string
   356  
   357  	scanner := bufio.NewScanner(bytes.NewReader(content))
   358  	isFirstLine := true
   359  	for scanner.Scan() {
   360  		line := scanner.Text()
   361  
   362  		if isFirstLine {
   363  			processedLines = append(processedLines, line)
   364  			isFirstLine = false
   365  			continue
   366  		}
   367  
   368  		// We don't indent empty lines.
   369  		if line != "" {
   370  			line = enc.indentString() + line
   371  		}
   372  		processedLines = append(processedLines, line)
   373  	}
   374  
   375  	return []byte(strings.Join(processedLines, "\n") + "\n")
   376  }
   377  
   378  func (enc Encoder) matchesAnyGapPath(testSubject path.Path) bool {
   379  	for _, gp := range enc.gapPaths {
   380  		if gp.Matches(testSubject) {
   381  			return true
   382  		}
   383  	}
   384  
   385  	return false
   386  }
   387  
   388  func (enc Encoder) matchesAnySortPath(testSubject path.Path) bool {
   389  	for _, sp := range enc.sortPaths {
   390  		if sp.Matches(testSubject) {
   391  			return true
   392  		}
   393  	}
   394  	return false
   395  }
   396  
   397  func (enc Encoder) handleMultilineStringIndentation(content []byte) []byte {
   398  	// For some reason, yaml.Marshal seemed to be indenting non-first lines twice.
   399  
   400  	lines := bytes.Split(content, newline)
   401  	if len(lines) == 1 {
   402  		return content
   403  	}
   404  
   405  	for i := 1; i < len(lines); i++ { // i.e. starting with second line
   406  		lines[i] = bytes.TrimPrefix(lines[i], []byte(enc.indentString()))
   407  	}
   408  
   409  	return bytes.Join(lines, newline)
   410  }
   411  
   412  func (enc Encoder) indentString() string {
   413  	return strings.Repeat(" ", enc.indentSize)
   414  }