github.com/blend/go-sdk@v1.20220411.3/sourceutil/copy_rewriter.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package sourceutil
     9  
    10  import (
    11  	"context"
    12  	"fmt"
    13  	"go/ast"
    14  	"go/parser"
    15  	"go/printer"
    16  	"go/token"
    17  	"io"
    18  	"os"
    19  	"path/filepath"
    20  	"strings"
    21  
    22  	"github.com/blend/go-sdk/stringutil"
    23  )
    24  
// CopyRewriter copies a source to a destination, and applies rewrite rules to the file(s) it copies.
type CopyRewriter struct {
	// Source is the path that is walked and copied from; Execute fails if it cannot be stat'ed.
	Source string
	// Destination is the path the staged (rewritten) tree is finally copied to.
	Destination string
	// SkipGlobs are patterns matched against each entry's path relative to Source;
	// matching files are skipped and matching directories are skipped entirely.
	SkipGlobs []string
	// GoImportVisitors are called, in order, for every import of every ".go" file.
	GoImportVisitors []GoImportVisitor
	// GoAstVistiors are applied, in order, to the parsed tree of every ".go" file.
	// NOTE(review): the name looks like a misspelling of "Visitors"; it is an
	// exported field, so renaming it would break callers.
	GoAstVistiors []GoAstVisitor
	// StringSubstitutions are applied to the contents of each ".go" file after
	// the ast has been rewritten and printed.
	StringSubstitutions []StringSubstitution
	// DryRun, when true, walks and parses the source but does not write to Destination.
	DryRun bool
	// RemoveDestination, when true, removes Destination before the final copy.
	RemoveDestination bool
	// KeepTemporary, when true, leaves the temporary staging directory in place.
	KeepTemporary bool

	// Quiet, Verbose and Debug are optional flags; a nil pointer is treated as false.
	Quiet   *bool
	Verbose *bool
	Debug   *bool

	// Stdout and Stderr are optional output writers; when nil, os.Stdout and
	// os.Stderr are used respectively.
	Stdout io.Writer
	Stderr io.Writer
}
    44  
    45  // Execute is the command body.
    46  func (cr CopyRewriter) Execute(ctx context.Context) error {
    47  	if _, err := os.Stat(cr.Source); err != nil {
    48  		return fmt.Errorf("source not found at %s", cr.Source)
    49  	}
    50  	tempDir, err := os.MkdirTemp("", "repoctl")
    51  	if err != nil {
    52  		return err
    53  	}
    54  	if !cr.KeepTemporary {
    55  		defer func() {
    56  			if _, err = os.Stat(tempDir); err == nil {
    57  				cr.Verbosef("cleaning up temp dir %s", tempDir)
    58  				os.RemoveAll(tempDir)
    59  			}
    60  		}()
    61  	}
    62  
    63  	// walk files
    64  	err = filepath.Walk(cr.Source, func(path string, info os.FileInfo, err error) error {
    65  		if err != nil {
    66  			return err
    67  		}
    68  
    69  		base := strings.TrimPrefix(strings.TrimPrefix(path, cr.Source), "/")
    70  		destination := filepath.Join(tempDir, base)
    71  
    72  		for _, skipGlob := range cr.SkipGlobs {
    73  			if stringutil.Glob(base, skipGlob) {
    74  				if info.IsDir() {
    75  					cr.Verbosef("%s: skipping dir", base)
    76  					return filepath.SkipDir
    77  				}
    78  				cr.Verbosef("%s: skipping", base)
    79  				return nil
    80  			}
    81  		}
    82  
    83  		if info.IsDir() {
    84  			if _, err := os.Stat(destination); err != nil {
    85  				cr.Verbosef("%s", base)
    86  				if !cr.DryRun {
    87  					cr.Debugf("%s: creating %s", base, destination)
    88  					if err = os.MkdirAll(destination, DefaultDirPerms); err != nil {
    89  						return err
    90  					}
    91  				} else {
    92  					cr.Debugf("%s: dry-run; creating dir %s", base, destination)
    93  				}
    94  			}
    95  			return nil
    96  		}
    97  
    98  		cr.Verbosef("%s", base)
    99  		if filepath.Ext(path) == ".go" {
   100  			if err := cr.copyGoSourceFile(ctx, destination, path); err != nil {
   101  				return err
   102  			}
   103  		} else {
   104  			if !cr.DryRun {
   105  				if err := Copy(ctx, destination, path); err != nil {
   106  					return err
   107  				}
   108  			}
   109  		}
   110  		return nil
   111  	})
   112  
   113  	if !cr.DryRun {
   114  		if cr.RemoveDestination {
   115  			cr.Verbosef("removing destination dir %s", cr.Destination)
   116  			if err := os.RemoveAll(cr.Destination); err != nil {
   117  				return err
   118  			}
   119  		}
   120  		cr.Verbosef("recursively copying %s to %s", tempDir, cr.Destination)
   121  		if err := CopyAll(cr.Destination, tempDir); err != nil {
   122  			return err
   123  		}
   124  	} else {
   125  		cr.Verbosef("%s", "dry-run; skipping final copy")
   126  	}
   127  	return nil
   128  }
   129  
   130  // copyGoSourceFile rewrites the imports for a golang file at a given path
   131  func (cr CopyRewriter) copyGoSourceFile(ctx context.Context, destinationPath, sourcePath string) error {
   132  	contents, err := os.ReadFile(sourcePath)
   133  	if err != nil {
   134  		return err
   135  	}
   136  	var writer io.WriteCloser
   137  	if cr.DryRun {
   138  		writer = nopWriteCloser{io.Discard}
   139  	} else {
   140  		writer, err = os.Create(destinationPath)
   141  		if err != nil {
   142  			return err
   143  		}
   144  		defer writer.Close()
   145  	}
   146  	if err = cr.rewriteGoAst(ctx, sourcePath, contents, writer); err != nil {
   147  		return err
   148  	}
   149  	return cr.rewriteContents(ctx, destinationPath)
   150  }
   151  
   152  func (cr CopyRewriter) rewriteGoAst(ctx context.Context, sourcePath string, contents []byte, writer io.Writer) error {
   153  	fset := token.NewFileSet()
   154  	fileAst, err := parser.ParseFile(fset, sourcePath, contents, parser.AllErrors|parser.ParseComments)
   155  	if err != nil {
   156  		return err
   157  	}
   158  
   159  	for importIndex := range fileAst.Imports { // foreach file import
   160  		cr.Debugf("processing import %s", fileAst.Imports[importIndex].Path.Value)
   161  		for _, rewriteRule := range cr.GoImportVisitors { // foreach import rule
   162  			if err := rewriteRule(ctx, fileAst.Imports[importIndex]); err != nil {
   163  				return err
   164  			}
   165  		}
   166  	}
   167  	for _, rewrite := range cr.GoAstVistiors {
   168  		ast.Inspect(fileAst, func(n ast.Node) bool {
   169  			if n == nil {
   170  				return false
   171  			}
   172  			return rewrite(ctx, n)
   173  		})
   174  	}
   175  	return printer.Fprint(writer, fset, fileAst)
   176  }
   177  
   178  func (cr CopyRewriter) rewriteContents(ctx context.Context, sourcePath string) error {
   179  	if len(cr.StringSubstitutions) == 0 {
   180  		return nil
   181  	}
   182  
   183  	stat, err := os.Stat(sourcePath)
   184  	if err != nil {
   185  		return err
   186  	}
   187  
   188  	contents, err := os.ReadFile(sourcePath)
   189  	if err != nil {
   190  		return err
   191  	}
   192  
   193  	var output string
   194  	var ok bool
   195  	for _, rule := range cr.StringSubstitutions {
   196  		output, ok = rule(ctx, string(contents))
   197  		if ok {
   198  			contents = []byte(output)
   199  		}
   200  	}
   201  	if cr.DryRun {
   202  		cr.Debugf("dry-run; skipping rewriting file %s", sourcePath)
   203  		return nil
   204  	}
   205  	cr.Debugf("rewriting file %s", sourcePath)
   206  	return os.WriteFile(sourcePath, contents, stat.Mode())
   207  }
   208  
   209  // QuietOrDefault returns a value or a default.
   210  func (cr CopyRewriter) QuietOrDefault() bool {
   211  	if cr.Quiet != nil {
   212  		return *cr.Quiet
   213  	}
   214  	return false
   215  }
   216  
   217  // VerboseOrDefault returns a value or a default.
   218  func (cr CopyRewriter) VerboseOrDefault() bool {
   219  	if cr.Verbose != nil {
   220  		return *cr.Verbose
   221  	}
   222  	return false
   223  }
   224  
   225  // DebugOrDefault returns a value or a default.
   226  func (cr CopyRewriter) DebugOrDefault() bool {
   227  	if cr.Debug != nil {
   228  		return *cr.Debug
   229  	}
   230  	return false
   231  }
   232  
   233  // GetStdout returns standard out.
   234  func (cr CopyRewriter) GetStdout() io.Writer {
   235  	if cr.QuietOrDefault() {
   236  		return io.Discard
   237  	}
   238  	if cr.Stdout != nil {
   239  		return cr.Stdout
   240  	}
   241  	return os.Stdout
   242  }
   243  
   244  // GetStderr returns standard error.
   245  func (cr CopyRewriter) GetStderr() io.Writer {
   246  	if cr.QuietOrDefault() {
   247  		return io.Discard
   248  	}
   249  	if cr.Stderr != nil {
   250  		return cr.Stderr
   251  	}
   252  	return os.Stderr
   253  }
   254  
   255  // Verbosef writes to stdout if the `Verbose` flag is true.
   256  func (cr CopyRewriter) Verbosef(format string, args ...interface{}) {
   257  	if !cr.VerboseOrDefault() {
   258  		return
   259  	}
   260  	fmt.Fprintf(cr.GetStdout(), format+"\n", args...)
   261  }
   262  
   263  // Debugf writes to stdout if the `Debug` flag is true.
   264  func (cr CopyRewriter) Debugf(format string, args ...interface{}) {
   265  	if !cr.DebugOrDefault() {
   266  		return
   267  	}
   268  	fmt.Fprintf(cr.GetStdout(), format+"\n", args...)
   269  }
   270  
   271  type nopWriteCloser struct {
   272  	io.Writer
   273  }
   274  
   275  func (nopWriteCloser) Close() error { return nil }