go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/sdk/copyright/copyright.go (about)

     1  /*
     2  
     3  Copyright (c) 2023 - Present. Will Charczuk. All rights reserved.
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file at the root of the repository.
     5  
     6  */
     7  
     8  package copyright
     9  
    10  import (
    11  	"context"
    12  	"fmt"
    13  	"io"
    14  	"os"
    15  	"path/filepath"
    16  	"regexp"
    17  	"strings"
    18  	"unicode"
    19  
    20  	"go.charczuk.com/sdk/diff"
    21  	"go.charczuk.com/sdk/stringutil"
    22  )
    23  
    24  // Copyright is the main type that injects, removes and verifies copyright headers.
    25  type Copyright struct {
    26  	// Config holds the configuration options.
    27  	Config
    28  	// Stdout is the writer for Printf, Verbose and Debug output.
    29  	// If it is unset, `os.Stdout` will be used.
    30  	Stdout io.Writer
    31  	// Stderr is the writer for Error output.
    32  	// If it is unset, `os.Stderr` will be used.
    33  	Stderr io.Writer
    34  }
    35  
    36  // Inject inserts the copyright header in any matching files that don't already
    37  // have the copyright header.
    38  func (c Copyright) Inject(ctx context.Context, root string) error {
    39  	return c.Walk(ctx, c.inject, root)
    40  }
    41  
    42  // Remove removes the copyright header in any matching files that
    43  // have the copyright header.
    44  func (c Copyright) Remove(ctx context.Context, root string) error {
    45  	return c.Walk(ctx, c.remove, root)
    46  }
    47  
    48  // Verify asserts that the files found during walk
    49  // have the copyright header.
    50  func (c Copyright) Verify(ctx context.Context, root string) error {
    51  	if err := c.Walk(ctx, c.verify, root); err != nil {
    52  		return err
    53  	}
    54  	c.Printf("copyright ok!")
    55  	return nil
    56  }
    57  
    58  // Walk traverses the tree recursively from the root and applies the given action.
    59  //
    60  // If the root is a file, it is handled singly and then walk will return.
    61  func (c Copyright) Walk(ctx context.Context, action Action, root string) error {
    62  	if c.Config.CopyrightHolder == "" {
    63  		return fmt.Errorf("copyright holder not configured; cannot continue")
    64  	}
    65  
    66  	noticeBody, err := c.compileNoticeTemplate(c.NoticeTemplateOrDefault())
    67  	if err != nil {
    68  		return err
    69  	}
    70  
    71  	c.Verbosef("using root: %s", root)
    72  	c.Verbosef("using excludes: %s", strings.Join(c.Config.Excludes, ", "))
    73  	c.Verbosef("using include files: %s", strings.Join(c.Config.Includes, ", "))
    74  	c.Verbosef("using notice body:\n%s", noticeBody)
    75  
    76  	// if the root is a file, just handle the file itself
    77  	// otherwise walk the full tree
    78  	if info, err := os.Stat(root); err != nil {
    79  		return err
    80  	} else if !info.IsDir() {
    81  		c.Debugf("root is a file, processing and returning")
    82  		return c.processFile(action, noticeBody, root, info)
    83  	}
    84  
    85  	var didFail bool
    86  	err = filepath.Walk(root, func(path string, info os.FileInfo, fileErr error) error {
    87  		if fileErr != nil {
    88  			return fileErr
    89  		}
    90  
    91  		if skipErr := c.includeOrExclude(root, path, info); skipErr != nil {
    92  			if skipErr == ErrWalkSkip {
    93  				return nil
    94  			}
    95  			return skipErr
    96  		}
    97  
    98  		walkErr := c.processFile(action, noticeBody, path, info)
    99  		if walkErr != nil {
   100  			// if we don't exit on the first failure
   101  			// check if the error is just a verification error
   102  			// if so, mark that we've had a failure
   103  			// and return nil so the walk continues
   104  			if !c.Config.ExitFirst {
   105  				// if it's a sentinel error
   106  				// mark we've failed and return nil
   107  				if walkErr == ErrFailure {
   108  					didFail = true
   109  					return nil
   110  				}
   111  
   112  				// this error might be an os issue / something else
   113  				// return it
   114  				return walkErr
   115  			}
   116  
   117  			// otherwise always return the error
   118  			// this will abort the walk
   119  			return walkErr
   120  		}
   121  
   122  		// no error no problem
   123  		return nil
   124  	})
   125  
   126  	// if we had an error
   127  	// return it
   128  	if err != nil {
   129  		return err
   130  	}
   131  
   132  	// if we failed at some point, ideally
   133  	// because we're set to not exit first
   134  	// return the sentinel error
   135  	if didFail {
   136  		return ErrFailure
   137  	}
   138  	return nil
   139  }
   140  
   141  // GetStdout returns standard out.
   142  func (c Copyright) GetStdout() io.Writer {
   143  	if c.Quiet {
   144  		return io.Discard
   145  	}
   146  	if c.Stdout != nil {
   147  		return c.Stdout
   148  	}
   149  	return os.Stdout
   150  }
   151  
   152  // GetStderr returns standard error.
   153  func (c Copyright) GetStderr() io.Writer {
   154  	if c.Quiet {
   155  		return io.Discard
   156  	}
   157  	if c.Stderr != nil {
   158  		return c.Stderr
   159  	}
   160  	return os.Stderr
   161  }
   162  
   163  // Printf writes to stdout.
   164  func (c Copyright) Printf(format string, args ...interface{}) {
   165  	fmt.Fprintf(c.GetStdout(), format+"\n", args...)
   166  }
   167  
   168  // Errorf writes to stderr.
   169  func (c Copyright) Errorf(format string, args ...interface{}) {
   170  	fmt.Fprintf(c.GetStderr(), format+"\n", args...)
   171  }
   172  
   173  // Verbosef writes to stdout if the `Verbose` flag is true.
   174  func (c Copyright) Verbosef(format string, args ...interface{}) {
   175  	if !c.Verbose {
   176  		return
   177  	}
   178  	fmt.Fprintf(c.GetStdout(), format+"\n", args...)
   179  }
   180  
   181  // Debugf writes to stdout if the `Debug` flag is true.
   182  func (c Copyright) Debugf(format string, args ...interface{}) {
   183  	if !c.Debug {
   184  		return
   185  	}
   186  	fmt.Fprintf(c.GetStdout(), format+"\n", args...)
   187  }
   188  
   189  //
   190  // actions
   191  //
   192  
   193  func (c Copyright) inject(path string, info os.FileInfo, file, notice []byte) error {
   194  	injectedContents := c.injectedContents(path, file, notice)
   195  	if injectedContents == nil {
   196  		return nil
   197  	}
   198  	if c.DryRun {
   199  		c.Verbosef("[dry-run] skipping writing %s", path)
   200  		return nil
   201  	}
   202  	return os.WriteFile(path, injectedContents, info.Mode().Perm())
   203  }
   204  
   205  func (c Copyright) remove(path string, info os.FileInfo, file, notice []byte) error {
   206  	removedContents := c.removedContents(path, file, notice)
   207  	if removedContents == nil {
   208  		return nil
   209  	}
   210  	if c.DryRun {
   211  		c.Verbosef("[dry-run] skipping writing %s", path)
   212  		return nil
   213  	}
   214  	return os.WriteFile(path, removedContents, info.Mode().Perm())
   215  }
   216  
   217  func (c Copyright) verify(path string, _ os.FileInfo, file, notice []byte) error {
   218  	fileExtension := Extension(filepath.Ext(path))
   219  	var err error
   220  	if c.hasShebang(file) {
   221  		err = c.shebangVerifyNotice(path, file, notice)
   222  	} else if fileExtension == ExtensionGo { // we have to treat go files specially because of build tags
   223  		err = c.goVerifyNotice(path, file, notice)
   224  	} else if fileExtension == ExtensionTS {
   225  		err = c.tsVerifyNotice(path, file, notice)
   226  	} else {
   227  		err = c.verifyNotice(path, file, notice)
   228  	}
   229  
   230  	if err != nil {
   231  		// verify prints the file that had the issue
   232  		// as part of the normal action
   233  		c.Errorf("%+v", err)
   234  		if c.Config.ShowDiff {
   235  			c.showDiff(path, file, notice)
   236  		}
   237  		return ErrFailure
   238  	}
   239  	return nil
   240  }
   241  
   242  //
   243  // internal helpers
   244  //
   245  
   246  // includeOrExclude makes the determination if we should process a path (file or directory).
   247  func (c Copyright) includeOrExclude(root, path string, info os.FileInfo) error {
   248  	if info.IsDir() {
   249  		if path == root {
   250  			return ErrWalkSkip
   251  		}
   252  	}
   253  
   254  	if c.Config.Excludes != nil {
   255  		for _, exclude := range c.Config.Excludes {
   256  			if stringutil.Glob(path, exclude) {
   257  				c.Debugf("path %s matches exclude %s", path, exclude)
   258  				if info.IsDir() {
   259  					return filepath.SkipDir
   260  				}
   261  				return ErrWalkSkip
   262  			}
   263  		}
   264  	}
   265  
   266  	if c.Config.Includes != nil {
   267  		var includePath bool
   268  		for _, include := range c.Config.Includes {
   269  			if stringutil.Glob(path, include) {
   270  				includePath = true
   271  				break
   272  			}
   273  		}
   274  		if !includePath {
   275  			c.Debugf("path %s does not match any includes", path)
   276  			return ErrWalkSkip
   277  		}
   278  	}
   279  
   280  	if info.IsDir() {
   281  		return ErrWalkSkip
   282  	}
   283  
   284  	return nil
   285  }
   286  
   287  // processFile processes a single file with the action
   288  func (c Copyright) processFile(action Action, noticeBody, path string, info os.FileInfo) error {
   289  	fileExtension := Extension(filepath.Ext(path))
   290  	injectionTemplate, ok := c.injectionTemplateByExtension(fileExtension)
   291  	if !ok {
   292  		return fmt.Errorf("copyright injection template for file not found, cannot continue; %s", filepath.Base(path))
   293  	}
   294  
   295  	injectedNotice, err := c.compileInjectionTemplate(injectionTemplate, noticeBody)
   296  	if err != nil {
   297  		return err
   298  	}
   299  	fileContents, err := os.ReadFile(path)
   300  	if err != nil {
   301  		return err
   302  	}
   303  	return action(path, info, fileContents, []byte(injectedNotice))
   304  }
   305  
   306  // injectionTemplateByExtension gets a notice template by extension or the default.
   307  func (c Copyright) injectionTemplateByExtension(fileExtension Extension) (noticeTemplate string, ok bool) {
   308  	if !strings.HasPrefix(string(fileExtension), ".") {
   309  		fileExtension = "." + fileExtension
   310  	}
   311  
   312  	// check if there is a filetype specific notice template
   313  	extensionNoticeTemplates := c.ExtensionInjectionTemplatesOrDefault()
   314  	if noticeTemplate, ok = extensionNoticeTemplates[fileExtension]; ok {
   315  		return
   316  	}
   317  	// fail
   318  	return
   319  }
   320  
   321  func (c Copyright) injectedContents(path string, file, notice []byte) []byte {
   322  	fileExtension := Extension(filepath.Ext(path))
   323  	if c.hasShebang(file) {
   324  		return c.shebangInjectNotice(path, file, notice)
   325  	}
   326  
   327  	switch fileExtension {
   328  	case ExtensionGo:
   329  		return c.goInjectNotice(path, file, notice)
   330  	case ExtensionTS:
   331  		return c.tsInjectNotice(path, file, notice)
   332  	default:
   333  		return c.injectNotice(path, file, notice)
   334  	}
   335  }
   336  
   337  func (Copyright) hasShebang(file []byte) bool {
   338  	return shebangMatch.Match(file)
   339  }
   340  
   341  func (c Copyright) removedContents(path string, file, notice []byte) []byte {
   342  	fileExtension := Extension(filepath.Ext(path))
   343  	if c.hasShebang(file) {
   344  		return c.shebangRemoveNotice(path, file, notice)
   345  	}
   346  
   347  	switch fileExtension {
   348  	case ExtensionGo:
   349  		return c.goRemoveNotice(path, file, notice)
   350  	case ExtensionTS:
   351  		return c.tsRemoveNotice(path, file, notice)
   352  	default:
   353  		return c.removeNotice(path, file, notice)
   354  	}
   355  }
   356  
   357  // shebangInjectNotice explicitly handles files that start with a shebang line.
   358  //
   359  // This assumes these are not `*.go` source files (which would be handled separately)
   360  // so has more in common in implementation with `injectNotice()` than with `goInjectNotice()`.
   361  func (c Copyright) shebangInjectNotice(path string, file, notice []byte) []byte {
   362  	// Strip shebang lines from beginning of file
   363  	shebangLines := shebangMatch.Find(file)
   364  	file = shebangMatch.ReplaceAll(file, nil)
   365  
   366  	if c.fileHasCopyrightHeader(file, notice) {
   367  		return nil
   368  	}
   369  	c.Verbosef("injecting notice: %s", path)
   370  
   371  	// remove any existing notice-ish looking text ...
   372  	file = c.removeCopyrightHeader(file, notice)
   373  	return c.mergeFileSections(shebangLines, notice, file)
   374  }
   375  
   376  // goInjectNotice handles go files differently because they may contain build tags.
   377  func (c Copyright) goInjectNotice(path string, file, notice []byte) []byte {
   378  	goBuildTag := goBuildTagMatch.Find(file)
   379  	file = goBuildTagMatch.ReplaceAll(file, nil)
   380  	if c.fileHasCopyrightHeader(file, notice) {
   381  		return nil
   382  	}
   383  
   384  	c.Verbosef("injecting notice: %s", path)
   385  	file = c.removeCopyrightHeader(file, notice)
   386  	return c.mergeFileSections(goBuildTag, notice, file)
   387  }
   388  
   389  // tsInjectNotice handles ts files differently because they may have pragma comments.
   390  func (c Copyright) tsInjectNotice(path string, file, notice []byte) []byte {
   391  	tsReferenceTags := tsReferenceTagsMatch.Find(file)
   392  	file = tsReferenceTagsMatch.ReplaceAll(file, nil)
   393  	if c.fileHasCopyrightHeader(file, notice) {
   394  		return nil
   395  	}
   396  
   397  	c.Verbosef("injecting notice: %s", path)
   398  	file = c.removeCopyrightHeader(file, notice)
   399  	return c.mergeFileSections(tsReferenceTags, notice, file)
   400  }
   401  
   402  // injectNotice is the default notice injector.
   403  func (c Copyright) injectNotice(path string, file, notice []byte) []byte {
   404  	if c.fileHasCopyrightHeader(file, notice) {
   405  		return nil
   406  	}
   407  
   408  	c.Verbosef("injecting notice: %s", path)
   409  	file = c.removeCopyrightHeader(file, notice)
   410  	return c.mergeFileSections(notice, file)
   411  }
   412  
   413  // shebangRemoveNotice explicitly handles files that start with a shebang line.
   414  //
   415  // This assumes these are not `*.go` source files so has more in common with
   416  // `removeNotice()` than with `goRemoveNotice()`.
   417  func (c Copyright) shebangRemoveNotice(path string, file, notice []byte) []byte {
   418  	// Strip shebang lines from beginning of file
   419  	shebangLines := shebangMatch.Find(file)
   420  	file = shebangMatch.ReplaceAll(file, nil)
   421  
   422  	if !c.fileHasCopyrightHeader(file, notice) {
   423  		return nil
   424  	}
   425  	c.Verbosef("removing notice: %s", path)
   426  	removed := c.removeCopyrightHeader(file, notice)
   427  	return c.mergeFileSections(shebangLines, removed)
   428  }
   429  
   430  func (c Copyright) goRemoveNotice(path string, file, notice []byte) []byte {
   431  	goBuildTag := goBuildTagMatch.FindString(string(file))
   432  	file = goBuildTagMatch.ReplaceAll(file, nil)
   433  	if !c.fileHasCopyrightHeader(file, notice) {
   434  		return nil
   435  	}
   436  	c.Verbosef("removing notice: %s", path)
   437  	return c.mergeFileSections([]byte(goBuildTag), c.removeCopyrightHeader(file, notice))
   438  }
   439  
   440  func (c Copyright) tsRemoveNotice(path string, file, notice []byte) []byte {
   441  	tsImportTags := tsReferenceTagsMatch.FindString(string(file))
   442  	file = tsReferenceTagsMatch.ReplaceAll(file, nil)
   443  	if !c.fileHasCopyrightHeader(file, notice) {
   444  		return nil
   445  	}
   446  	c.Verbosef("removing notice: %s", path)
   447  	return c.mergeFileSections([]byte(tsImportTags), c.removeCopyrightHeader(file, notice))
   448  }
   449  
   450  func (c Copyright) removeNotice(path string, file, notice []byte) []byte {
   451  	if !c.fileHasCopyrightHeader(file, notice) {
   452  		return nil
   453  	}
   454  	c.Verbosef("removing notice: %s", path)
   455  	return c.removeCopyrightHeader(file, notice)
   456  }
   457  
   458  // shebangVerifyNotice explicitly handles files that start with a shebang line.
   459  // This assumes these are not `*.go` source files so has more in common with
   460  // `verifyNotice()` than with `goVerifyNotice()`.
   461  func (c Copyright) shebangVerifyNotice(path string, file, notice []byte) error {
   462  	// Strip and ignore shebang lines from beginning of file
   463  	file = shebangMatch.ReplaceAll(file, nil)
   464  
   465  	c.Debugf("verifying (shebang): %s", path)
   466  	if !c.fileHasCopyrightHeader(file, notice) {
   467  		return fmt.Errorf(ErrVerifyFormat, path)
   468  	}
   469  	return nil
   470  }
   471  
   472  func (c Copyright) goVerifyNotice(path string, file, notice []byte) error {
   473  	c.Debugf("verifying (go): %s", path)
   474  	fileLessTags := goBuildTagMatch.ReplaceAll(file, nil)
   475  	if !c.fileHasCopyrightHeader(fileLessTags, notice) {
   476  		return fmt.Errorf(ErrVerifyFormat, path)
   477  	}
   478  	return nil
   479  }
   480  
   481  func (c Copyright) tsVerifyNotice(path string, file, notice []byte) error {
   482  	c.Debugf("verifying (ts): %s", path)
   483  	fileLessTags := tsReferenceTagsMatch.ReplaceAll(file, nil)
   484  	if !c.fileHasCopyrightHeader(fileLessTags, notice) {
   485  		return fmt.Errorf(ErrVerifyFormat, path)
   486  	}
   487  	return nil
   488  }
   489  
   490  func (c Copyright) verifyNotice(path string, file, notice []byte) error {
   491  	c.Debugf("verifying: %s", path)
   492  	if !c.fileHasCopyrightHeader(file, notice) {
   493  		return fmt.Errorf(ErrVerifyFormat, path)
   494  	}
   495  	return nil
   496  }
   497  
   498  func (c Copyright) createNoticeMatchExpression(notice []byte, trailingSpaceStrict bool) *regexp.Regexp {
   499  	noticeString := string(notice)
   500  	noticeExpr := yearMatch.ReplaceAllString(regexp.QuoteMeta(noticeString), yearExpr)
   501  	noticeExpr = `^(\s*)` + noticeExpr
   502  	if !trailingSpaceStrict {
   503  		// remove trailing space
   504  		noticeExpr = strings.TrimRightFunc(noticeExpr, unicode.IsSpace)
   505  		// match trailing space
   506  		noticeExpr = noticeExpr + `(\s*)`
   507  	}
   508  	return regexp.MustCompile(noticeExpr)
   509  }
   510  
   511  func (c Copyright) fileHasCopyrightHeader(fileContents, notice []byte) bool {
   512  	return c.createNoticeMatchExpression(notice, true).Match(fileContents)
   513  }
   514  
   515  func (c Copyright) removeCopyrightHeader(fileContents []byte, notice []byte) []byte {
   516  	return c.createNoticeMatchExpression(notice, false).ReplaceAll(fileContents, nil)
   517  }
   518  
   519  func (c Copyright) mergeFileSections(sections ...[]byte) []byte {
   520  	var fullLength int
   521  	for _, section := range sections {
   522  		fullLength += len(section)
   523  	}
   524  
   525  	combined := make([]byte, fullLength)
   526  
   527  	var written int
   528  	for _, section := range sections {
   529  		copy(combined[written:], section)
   530  		written += len(section)
   531  	}
   532  	return combined
   533  }
   534  
   535  func (c Copyright) compileInjectionTemplate(noticeTemplate, noticeBody string) (string, error) {
   536  	return renderTemplate(noticeTemplate, c.templateViewModel(map[string]interface{}{
   537  		"Notice": noticeBody,
   538  	}))
   539  }
   540  
   541  func (c Copyright) templateViewModel(extra ...map[string]interface{}) map[string]interface{} {
   542  	base := map[string]interface{}{
   543  		TemplateFieldYear:            c.YearOrDefault(),
   544  		TemplateFieldCopyrightHolder: c.CopyrightHolderOrDefault(),
   545  		TemplateFieldLicense:         c.LicenseOrDefault(),
   546  	}
   547  	for _, m := range extra {
   548  		for key, value := range m {
   549  			base[key] = value
   550  		}
   551  	}
   552  	return base
   553  }
   554  
   555  func (c Copyright) compileRestrictionsTemplate(restrictionsTemplate string) (string, error) {
   556  	return renderTemplate(restrictionsTemplate, c.templateViewModel())
   557  }
   558  
   559  func (c Copyright) compileNoticeTemplate(noticeBodyTemplate string) (string, error) {
   560  	restrictions, err := c.compileRestrictionsTemplate(c.RestrictionsOrDefault())
   561  	if err != nil {
   562  		return "", err
   563  	}
   564  	viewModel := c.templateViewModel(map[string]interface{}{
   565  		"Restrictions": restrictions,
   566  	})
   567  	output, err := renderTemplate(noticeBodyTemplate, viewModel)
   568  	if err != nil {
   569  		return "", err
   570  	}
   571  	return output, nil
   572  }
   573  
   574  func (c Copyright) showDiff(path string, file, notice []byte) {
   575  	noticeLineCount := len(stringutil.SplitLines(string(notice)))
   576  	fileLines := stringutil.SplitLines(string(file))
   577  	if len(fileLines) < noticeLineCount {
   578  		noticeLineCount = len(fileLines)
   579  	}
   580  	fileTruncated := strings.Join(fileLines[:noticeLineCount], "")
   581  	fileDiff := diff.New().Diff(string(notice), fileTruncated, true /*checklines*/)
   582  	prettyDiff := diff.PrettyText(fileDiff)
   583  	if strings.TrimSpace(prettyDiff) != "" {
   584  		fmt.Fprintf(c.GetStderr(), "%s: diff\n", path)
   585  		fmt.Fprintln(c.GetStderr(), prettyDiff)
   586  	}
   587  }