github.com/blend/go-sdk@v1.20220411.3/copyright/copyright.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package copyright
     9  
    10  import (
    11  	"bytes"
    12  	"context"
    13  	"fmt"
    14  	"io"
    15  	"os"
    16  	"path/filepath"
    17  	"regexp"
    18  	"strings"
    19  	"text/template"
    20  	"unicode"
    21  
    22  	"github.com/blend/go-sdk/diff"
    23  	"github.com/blend/go-sdk/stringutil"
    24  )
    25  
    26  // New creates a new copyright engine with a given set of config options.
    27  func New(options ...Option) *Copyright {
    28  	var c Copyright
    29  	for _, option := range options {
    30  		option(&c)
    31  	}
    32  	return &c
    33  }
    34  
    35  // Copyright is the main type that injects, removes and verifies copyright headers.
    36  type Copyright struct {
    37  	Config // Config holds the configuration opitons.
    38  
    39  	// Stdout is the writer for Verbose and Debug output.
    40  	// If it is unset, `os.Stdout` will be used.
    41  	Stdout io.Writer
    42  	// Stderr is the writer for Error output.
    43  	// If it is unset, `os.Stderr` will be used.
    44  	Stderr io.Writer
    45  }
    46  
    47  // Inject inserts the copyright header in any matching files that don't already
    48  // have the copyright header.
    49  func (c Copyright) Inject(ctx context.Context, root string) error {
    50  	return c.Walk(ctx, c.inject, root)
    51  }
    52  
    53  // Remove removes the copyright header in any matching files that
    54  // have the copyright header.
    55  func (c Copyright) Remove(ctx context.Context, root string) error {
    56  	return c.Walk(ctx, c.remove, root)
    57  }
    58  
    59  // Verify asserts that the files found during walk
    60  // have the copyright header.
    61  func (c Copyright) Verify(ctx context.Context, root string) error {
    62  	return c.Walk(ctx, c.verify, root)
    63  }
    64  
    65  // Walk traverses the tree recursively from the root and applies the given action.
    66  //
    67  // If the root is a file, it is handled singly and then walk will return.
    68  func (c Copyright) Walk(ctx context.Context, action Action, root string) error {
    69  	noticeBody, err := c.compileNoticeBodyTemplate(c.NoticeBodyTemplateOrDefault())
    70  	if err != nil {
    71  		return err
    72  	}
    73  
    74  	c.Verbosef("using root: %s", root)
    75  	c.Verbosef("using excludes: %s", strings.Join(c.Config.Excludes, ", "))
    76  	c.Verbosef("using include files: %s", strings.Join(c.Config.IncludeFiles, ", "))
    77  	c.Verbosef("using notice body:\n%s", noticeBody)
    78  
    79  	// if the root is a file, just handle the file itself
    80  	// otherwise walk the full tree
    81  	if info, err := os.Stat(root); err != nil {
    82  		return err
    83  	} else if !info.IsDir() {
    84  		c.Debugf("root is a file, processing and returning")
    85  		return c.processFile(action, noticeBody, root, info)
    86  	}
    87  
    88  	var didFail bool
    89  	err = filepath.Walk(root, func(path string, info os.FileInfo, fileErr error) error {
    90  		if fileErr != nil {
    91  			return fileErr
    92  		}
    93  
    94  		if skipErr := c.includeOrExclude(root, path, info); skipErr != nil {
    95  			if skipErr == ErrWalkSkip {
    96  				return nil
    97  			}
    98  			return skipErr
    99  		}
   100  
   101  		walkErr := c.processFile(action, noticeBody, path, info)
   102  		if walkErr != nil {
   103  			// if we don't exit on the first failure
   104  			// check if the error is just a verification error
   105  			// if so, mark that we've had a failure
   106  			// and return nil so the walk continues
   107  			if !c.Config.ExitFirstOrDefault() {
   108  				// if it's a sentinel error
   109  				// mark we've failed and return nil
   110  				if walkErr == ErrFailure {
   111  					didFail = true
   112  					return nil
   113  				}
   114  
   115  				// this error might be an os issue / something else
   116  				// return it
   117  				return walkErr
   118  			}
   119  
   120  			// otherwise always return the error
   121  			// this will abort the walk
   122  			return walkErr
   123  		}
   124  
   125  		// no error no problem
   126  		return nil
   127  	})
   128  
   129  	// if we had an error
   130  	// return it
   131  	if err != nil {
   132  		return err
   133  	}
   134  
   135  	// if we failed at some point, ideally
   136  	// because we're set to not exit first
   137  	// return the sentinel error
   138  	if didFail {
   139  		return ErrFailure
   140  	}
   141  	return nil
   142  }
   143  
   144  // GetStdout returns standard out.
   145  func (c Copyright) GetStdout() io.Writer {
   146  	if c.QuietOrDefault() {
   147  		return io.Discard
   148  	}
   149  	if c.Stdout != nil {
   150  		return c.Stdout
   151  	}
   152  	return os.Stdout
   153  }
   154  
   155  // GetStderr returns standard error.
   156  func (c Copyright) GetStderr() io.Writer {
   157  	if c.QuietOrDefault() {
   158  		return io.Discard
   159  	}
   160  	if c.Stderr != nil {
   161  		return c.Stderr
   162  	}
   163  	return os.Stderr
   164  }
   165  
   166  // Errorf writes to stderr.
   167  func (c Copyright) Errorf(format string, args ...interface{}) {
   168  	fmt.Fprintf(c.GetStderr(), format+"\n", args...)
   169  }
   170  
   171  // Verbosef writes to stdout if the `Verbose` flag is true.
   172  func (c Copyright) Verbosef(format string, args ...interface{}) {
   173  	if !c.VerboseOrDefault() {
   174  		return
   175  	}
   176  	fmt.Fprintf(c.GetStdout(), format+"\n", args...)
   177  }
   178  
   179  // Debugf writes to stdout if the `Debug` flag is true.
   180  func (c Copyright) Debugf(format string, args ...interface{}) {
   181  	if !c.DebugOrDefault() {
   182  		return
   183  	}
   184  	fmt.Fprintf(c.GetStdout(), format+"\n", args...)
   185  }
   186  
   187  //
   188  // actions
   189  //
   190  
   191  func (c Copyright) inject(path string, info os.FileInfo, file, notice []byte) error {
   192  	injectedContents := c.injectedContents(path, file, notice)
   193  	if injectedContents == nil {
   194  		return nil
   195  	}
   196  	return os.WriteFile(path, injectedContents, info.Mode().Perm())
   197  }
   198  
   199  func (c Copyright) remove(path string, info os.FileInfo, file, notice []byte) error {
   200  	removedContents := c.removedContents(path, file, notice)
   201  	if removedContents == nil {
   202  		return nil
   203  	}
   204  	return os.WriteFile(path, removedContents, info.Mode().Perm())
   205  }
   206  
   207  func (c Copyright) verify(path string, _ os.FileInfo, file, notice []byte) error {
   208  	fileExtension := filepath.Ext(path)
   209  	var err error
   210  	if c.hasShebang(file) {
   211  		err = c.shebangVerifyNotice(path, file, notice)
   212  	} else if fileExtension == ExtensionGo { // we have to treat go files specially because of build tags
   213  		err = c.goVerifyNotice(path, file, notice)
   214  	} else if fileExtension == ExtensionTS {
   215  		err = c.tsVerifyNotice(path, file, notice)
   216  	} else {
   217  		err = c.verifyNotice(path, file, notice)
   218  	}
   219  
   220  	if err != nil {
   221  		// verify prints the file that had the issue
   222  		// as part of the normal action
   223  		c.Errorf("%+v", err)
   224  		if c.Config.ShowDiffOrDefault() {
   225  			c.showDiff(path, file, notice)
   226  		}
   227  		return ErrFailure
   228  	}
   229  	return nil
   230  }
   231  
   232  //
   233  // internal helpers
   234  //
   235  
   236  // includeOrExclude makes the determination if we should process a path (file or directory).
   237  func (c Copyright) includeOrExclude(root, path string, info os.FileInfo) error {
   238  	if info.IsDir() {
   239  		if path == root {
   240  			return ErrWalkSkip
   241  		}
   242  	}
   243  
   244  	if c.Config.Excludes != nil {
   245  		for _, exclude := range c.Config.Excludes {
   246  			if stringutil.Glob(path, exclude) {
   247  				c.Debugf("path %s matches exclude %s", path, exclude)
   248  				if info.IsDir() {
   249  					return filepath.SkipDir
   250  				}
   251  				return ErrWalkSkip
   252  			}
   253  		}
   254  	}
   255  
   256  	if c.Config.IncludeFiles != nil {
   257  		var includePath bool
   258  		for _, include := range c.Config.IncludeFiles {
   259  			if stringutil.Glob(path, include) {
   260  				includePath = true
   261  				break
   262  			}
   263  		}
   264  		if !includePath {
   265  			c.Debugf("path %s does not match any includes", path)
   266  			return ErrWalkSkip
   267  		}
   268  	}
   269  
   270  	if info.IsDir() {
   271  		return ErrWalkSkip
   272  	}
   273  
   274  	return nil
   275  }
   276  
   277  // processFile processes a single file with the action
   278  func (c Copyright) processFile(action Action, noticeBody, path string, info os.FileInfo) error {
   279  	fileExtension := filepath.Ext(path)
   280  	noticeTemplate, ok := c.noticeTemplateByExtension(fileExtension)
   281  	if !ok {
   282  		return fmt.Errorf("invalid copyright injection file; %s", filepath.Base(path))
   283  	}
   284  	notice, err := c.compileNoticeTemplate(noticeTemplate, noticeBody)
   285  	if err != nil {
   286  		return err
   287  	}
   288  	fileContents, err := os.ReadFile(path)
   289  	if err != nil {
   290  		return err
   291  	}
   292  	return action(path, info, fileContents, []byte(notice))
   293  }
   294  
   295  // noticeTemplateByExtension gets a notice template by extension or the default.
   296  func (c Copyright) noticeTemplateByExtension(fileExtension string) (noticeTemplate string, ok bool) {
   297  	if !strings.HasPrefix(fileExtension, ".") {
   298  		fileExtension = "." + fileExtension
   299  	}
   300  
   301  	// check if there is a filetype specific notice template
   302  	extensionNoticeTemplates := c.ExtensionNoticeTemplatesOrDefault()
   303  	if noticeTemplate, ok = extensionNoticeTemplates[fileExtension]; ok {
   304  		return
   305  	}
   306  
   307  	// check if we have a fallback notice template
   308  	if c.FallbackNoticeTemplate != "" {
   309  		noticeTemplate = c.FallbackNoticeTemplate
   310  		ok = true
   311  		return
   312  	}
   313  
   314  	// fail
   315  	return
   316  }
   317  
   318  func (c Copyright) injectedContents(path string, file, notice []byte) []byte {
   319  	fileExtension := filepath.Ext(path)
   320  	if c.hasShebang(file) {
   321  		return c.shebangInjectNotice(path, file, notice)
   322  	}
   323  
   324  	if fileExtension == ExtensionGo {
   325  		return c.goInjectNotice(path, file, notice)
   326  	} else if fileExtension == ExtensionTS {
   327  		return c.tsInjectNotice(path, file, notice)
   328  	}
   329  
   330  	return c.injectNotice(path, file, notice)
   331  }
   332  
   333  func (Copyright) hasShebang(file []byte) bool {
   334  	return shebangMatch.Match(file)
   335  }
   336  
   337  func (c Copyright) removedContents(path string, file, notice []byte) []byte {
   338  	fileExtension := filepath.Ext(path)
   339  	if c.hasShebang(file) {
   340  		return c.shebangRemoveNotice(path, file, notice)
   341  	}
   342  
   343  	if fileExtension == ExtensionGo { // we have to treat go files specially because of build tags
   344  		return c.goRemoveNotice(path, file, notice)
   345  	} else if fileExtension == ExtensionTS {
   346  		return c.tsRemoveNotice(path, file, notice)
   347  	}
   348  
   349  	return c.removeNotice(path, file, notice)
   350  }
   351  
   352  // shebangInjectNotice explicitly handles files that start with a shebang line.
   353  // This assumes these are not `*.go` source files so has more in common with
   354  // `injectNotice()` than with `goInjectNotice()`.
   355  func (c Copyright) shebangInjectNotice(path string, file, notice []byte) []byte {
   356  	// Strip shebang lines from beginning of file
   357  	shebangLines := shebangMatch.Find(file)
   358  	file = shebangMatch.ReplaceAll(file, nil)
   359  
   360  	if c.fileHasCopyrightHeader(file, notice) {
   361  		return nil
   362  	}
   363  	c.Verbosef("injecting notice: %s", path)
   364  
   365  	// remove any existing notice-ish looking text ...
   366  	file = c.removeCopyrightHeader(file, notice)
   367  	return c.mergeFileSections(shebangLines, notice, file)
   368  }
   369  
   370  // goInjectNotice handles go files differently because they may contain build tags.
   371  func (c Copyright) goInjectNotice(path string, file, notice []byte) []byte {
   372  	goBuildTag := goBuildTagMatch.Find(file)
   373  	file = goBuildTagMatch.ReplaceAll(file, nil)
   374  	if c.fileHasCopyrightHeader(file, notice) {
   375  		return nil
   376  	}
   377  
   378  	c.Verbosef("injecting notice: %s", path)
   379  	file = c.removeCopyrightHeader(file, notice)
   380  	return c.mergeFileSections(goBuildTag, notice, file)
   381  }
   382  
   383  // goInjectNotice handles ts files differently because they may contain build tags.
   384  func (c Copyright) tsInjectNotice(path string, file, notice []byte) []byte {
   385  	tsReferenceTags := tsReferenceTagsMatch.Find(file)
   386  	file = tsReferenceTagsMatch.ReplaceAll(file, nil)
   387  	if c.fileHasCopyrightHeader(file, notice) {
   388  		return nil
   389  	}
   390  
   391  	c.Verbosef("injecting notice: %s", path)
   392  	file = c.removeCopyrightHeader(file, notice)
   393  	return c.mergeFileSections(tsReferenceTags, notice, file)
   394  }
   395  
   396  func (c Copyright) injectNotice(path string, file, notice []byte) []byte {
   397  	if c.fileHasCopyrightHeader(file, notice) {
   398  		return nil
   399  	}
   400  	c.Verbosef("injecting notice: %s", path)
   401  
   402  	// remove any existing notice-ish looking text ...
   403  	file = c.removeCopyrightHeader(file, notice)
   404  	return c.mergeFileSections(notice, file)
   405  }
   406  
   407  // shebangRemoveNotice explicitly handles files that start with a shebang line.
   408  // This assumes these are not `*.go` source files so has more in common with
   409  // `removeNotice()` than with `goRemoveNotice()`.
   410  func (c Copyright) shebangRemoveNotice(path string, file, notice []byte) []byte {
   411  	// Strip shebang lines from beginning of file
   412  	shebangLines := shebangMatch.Find(file)
   413  	file = shebangMatch.ReplaceAll(file, nil)
   414  
   415  	if !c.fileHasCopyrightHeader(file, notice) {
   416  		return nil
   417  	}
   418  	c.Verbosef("removing notice: %s", path)
   419  	removed := c.removeCopyrightHeader(file, notice)
   420  	return c.mergeFileSections(shebangLines, removed)
   421  }
   422  
   423  func (c Copyright) goRemoveNotice(path string, file, notice []byte) []byte {
   424  	goBuildTag := goBuildTagMatch.FindString(string(file))
   425  	file = goBuildTagMatch.ReplaceAll(file, nil)
   426  	if !c.fileHasCopyrightHeader(file, notice) {
   427  		return nil
   428  	}
   429  	c.Verbosef("removing notice: %s", path)
   430  	return c.mergeFileSections([]byte(goBuildTag), c.removeCopyrightHeader(file, notice))
   431  }
   432  
   433  func (c Copyright) tsRemoveNotice(path string, file, notice []byte) []byte {
   434  	tsImportTags := tsReferenceTagsMatch.FindString(string(file))
   435  	file = tsReferenceTagsMatch.ReplaceAll(file, nil)
   436  	if !c.fileHasCopyrightHeader(file, notice) {
   437  		return nil
   438  	}
   439  	c.Verbosef("removing notice: %s", path)
   440  	return c.mergeFileSections([]byte(tsImportTags), c.removeCopyrightHeader(file, notice))
   441  }
   442  
   443  func (c Copyright) removeNotice(path string, file, notice []byte) []byte {
   444  	if !c.fileHasCopyrightHeader(file, notice) {
   445  		return nil
   446  	}
   447  	c.Verbosef("removing notice: %s", path)
   448  	return c.removeCopyrightHeader(file, notice)
   449  }
   450  
   451  // shebangVerifyNotice explicitly handles files that start with a shebang line.
   452  // This assumes these are not `*.go` source files so has more in common with
   453  // `verifyNotice()` than with `goVerifyNotice()`.
   454  func (c Copyright) shebangVerifyNotice(path string, file, notice []byte) error {
   455  	// Strip and ignore shebang lines from beginning of file
   456  	file = shebangMatch.ReplaceAll(file, nil)
   457  
   458  	c.Debugf("verifying (shebang): %s", path)
   459  	if !c.fileHasCopyrightHeader(file, notice) {
   460  		return fmt.Errorf(VerifyErrorFormat, path)
   461  	}
   462  	return nil
   463  }
   464  
   465  func (c Copyright) goVerifyNotice(path string, file, notice []byte) error {
   466  	c.Debugf("verifying (go): %s", path)
   467  	fileLessTags := goBuildTagMatch.ReplaceAll(file, nil)
   468  	if !c.fileHasCopyrightHeader(fileLessTags, notice) {
   469  		return fmt.Errorf(VerifyErrorFormat, path)
   470  	}
   471  	return nil
   472  }
   473  
   474  func (c Copyright) tsVerifyNotice(path string, file, notice []byte) error {
   475  	c.Debugf("verifying (ts): %s", path)
   476  	fileLessTags := tsReferenceTagsMatch.ReplaceAll(file, nil)
   477  	if !c.fileHasCopyrightHeader(fileLessTags, notice) {
   478  		return fmt.Errorf(VerifyErrorFormat, path)
   479  	}
   480  	return nil
   481  }
   482  
   483  func (c Copyright) verifyNotice(path string, file, notice []byte) error {
   484  	c.Debugf("verifying: %s", path)
   485  	if !c.fileHasCopyrightHeader(file, notice) {
   486  		return fmt.Errorf(VerifyErrorFormat, path)
   487  	}
   488  	return nil
   489  }
   490  
   491  func (c Copyright) createNoticeMatchExpression(notice []byte, trailingSpaceStrict bool) *regexp.Regexp {
   492  	noticeString := string(notice)
   493  	noticeExpr := yearMatch.ReplaceAllString(regexp.QuoteMeta(noticeString), yearExpr)
   494  	noticeExpr = `^(\s*)` + noticeExpr
   495  	if !trailingSpaceStrict {
   496  		// remove trailing space
   497  		noticeExpr = strings.TrimRightFunc(noticeExpr, unicode.IsSpace)
   498  		// match trailing space
   499  		noticeExpr = noticeExpr + `(\s*)`
   500  	}
   501  	return regexp.MustCompile(noticeExpr)
   502  }
   503  
   504  func (c Copyright) fileHasCopyrightHeader(fileContents, notice []byte) bool {
   505  	return c.createNoticeMatchExpression(notice, true).Match(fileContents)
   506  }
   507  
   508  func (c Copyright) removeCopyrightHeader(fileContents []byte, notice []byte) []byte {
   509  	return c.createNoticeMatchExpression(notice, false).ReplaceAll(fileContents, nil)
   510  }
   511  
   512  func (c Copyright) mergeFileSections(sections ...[]byte) []byte {
   513  	var fullLength int
   514  	for _, section := range sections {
   515  		fullLength += len(section)
   516  	}
   517  
   518  	combined := make([]byte, fullLength)
   519  
   520  	var written int
   521  	for _, section := range sections {
   522  		copy(combined[written:], section)
   523  		written += len(section)
   524  	}
   525  	return combined
   526  }
   527  
   528  func (c Copyright) prefix(prefix string, s string) string {
   529  	lines := strings.Split(s, "\n")
   530  	var output []string
   531  	for _, l := range lines {
   532  		output = append(output, prefix+l)
   533  	}
   534  	return strings.Join(output, "\n")
   535  }
   536  
   537  func (c Copyright) compileNoticeTemplate(noticeTemplate, noticeBody string) (string, error) {
   538  	return c.processTemplate(noticeTemplate, c.templateViewModel(map[string]interface{}{
   539  		"Notice": noticeBody,
   540  	}))
   541  }
   542  
   543  func (c Copyright) templateViewModel(extra ...map[string]interface{}) map[string]interface{} {
   544  	base := map[string]interface{}{
   545  		"Year":    c.YearOrDefault(),
   546  		"Company": c.CompanyOrDefault(),
   547  		"License": c.LicenseOrDefault(),
   548  	}
   549  	for _, m := range extra {
   550  		for key, value := range m {
   551  			base[key] = value
   552  		}
   553  	}
   554  	return base
   555  }
   556  
   557  func (c Copyright) compileRestrictionsTemplate(restrictionsTemplate string) (string, error) {
   558  	return c.processTemplate(restrictionsTemplate, c.templateViewModel())
   559  }
   560  
   561  func (c Copyright) compileNoticeBodyTemplate(noticeBodyTemplate string) (string, error) {
   562  	restrictions, err := c.compileRestrictionsTemplate(c.RestrictionsOrDefault())
   563  	if err != nil {
   564  		return "", err
   565  	}
   566  	viewModel := c.templateViewModel(map[string]interface{}{
   567  		"Restrictions": restrictions,
   568  	})
   569  	output, err := c.processTemplate(noticeBodyTemplate, viewModel)
   570  	if err != nil {
   571  		return "", err
   572  	}
   573  	return output, nil
   574  }
   575  
   576  func (c Copyright) processTemplate(text string, viewmodel interface{}) (string, error) {
   577  	tmpl := template.New("output")
   578  	tmpl = tmpl.Funcs(template.FuncMap{
   579  		"prefix": c.prefix,
   580  	})
   581  	compiled, err := tmpl.Parse(text)
   582  	if err != nil {
   583  		return "", err
   584  	}
   585  
   586  	output := new(bytes.Buffer)
   587  	if err = compiled.Execute(output, viewmodel); err != nil {
   588  		return "", err
   589  	}
   590  	return output.String(), nil
   591  }
   592  
   593  func (c Copyright) showDiff(path string, file, notice []byte) {
   594  	noticeLineCount := len(stringutil.SplitLines(string(notice),
   595  		stringutil.OptSplitLinesIncludeEmptyLines(true),
   596  		stringutil.OptSplitLinesIncludeNewLine(true),
   597  	))
   598  	fileLines := stringutil.SplitLines(string(file),
   599  		stringutil.OptSplitLinesIncludeEmptyLines(true),
   600  		stringutil.OptSplitLinesIncludeNewLine(true),
   601  	)
   602  	if len(fileLines) < noticeLineCount {
   603  		noticeLineCount = len(fileLines)
   604  	}
   605  	fileTruncated := strings.Join(fileLines[:noticeLineCount], "")
   606  	fileDiff := diff.New().Diff(string(notice), fileTruncated, true /*checklines*/)
   607  	prettyDiff := diff.PrettyText(fileDiff)
   608  	if strings.TrimSpace(prettyDiff) != "" {
   609  		fmt.Fprintf(c.GetStderr(), "%s: diff\n", path)
   610  		fmt.Fprintln(c.GetStderr(), prettyDiff)
   611  	}
   612  }