github.com/singlemusic/buffalo@v0.16.30/buffalo/cmd/fix/middleware.go (about)

     1  package fix
     2  
import (
	"go/ast"
	"go/parser"
	"go/printer"
	"go/token"
	"io/ioutil"
	"os"
	"path/filepath"
	"sort"
	"strings"

	"golang.org/x/tools/go/ast/astutil"
)
    15  
    16  //MiddlewareTransformer moves from our old middleware package to new one
    17  type MiddlewareTransformer struct {
    18  	PackagesReplacement map[string]string
    19  	Aliases             map[string]string
    20  }
    21  
    22  func (mw MiddlewareTransformer) transformPackages(r *Runner) error {
    23  	return filepath.Walk(".", mw.processFile)
    24  }
    25  
    26  func (mw MiddlewareTransformer) processFile(p string, fi os.FileInfo, err error) error {
    27  	er := onlyRelevantFiles(p, fi, err, func(p string) error {
    28  		if err := mw.rewriteMiddlewareUses(p); err != nil {
    29  			return err
    30  		}
    31  
    32  		fset, f, err := buildASTFor(p)
    33  		if err != nil {
    34  			if e := err.Error(); strings.Contains(e, "expected 'package', found 'EOF'") {
    35  				return nil
    36  			}
    37  
    38  			return err
    39  		}
    40  
    41  		//Replacing mw packages
    42  		for old, new := range mw.PackagesReplacement {
    43  			deleted := astutil.DeleteImport(fset, f, old)
    44  			if deleted {
    45  				astutil.AddNamedImport(fset, f, mw.Aliases[new], new)
    46  			}
    47  		}
    48  
    49  		if err := mw.addMissingRootMiddlewareImports(fset, f, p); err != nil {
    50  			return err
    51  		}
    52  
    53  		ast.SortImports(fset, f)
    54  
    55  		temp, err := writeTempResult(p, fset, f)
    56  		if err != nil {
    57  			return err
    58  		}
    59  
    60  		// rename the .temp to .go
    61  		return os.Rename(temp, p)
    62  	})
    63  
    64  	return er
    65  }
    66  
    67  func (mw MiddlewareTransformer) addMissingRootMiddlewareImports(fset *token.FileSet, f *ast.File, p string) error {
    68  	read, err := ioutil.ReadFile(p)
    69  	if err != nil {
    70  		return err
    71  	}
    72  
    73  	content := string(read)
    74  
    75  	astutil.DeleteImport(fset, f, "github.com/gobuffalo/buffalo/middleware")
    76  	if strings.Contains(content, "paramlogger.ParameterLogger") {
    77  		astutil.AddNamedImport(fset, f, "paramlogger", "github.com/gobuffalo/mw-paramlogger")
    78  	}
    79  
    80  	if strings.Contains(content, "popmw.Transaction") {
    81  		astutil.AddImport(fset, f, "github.com/gobuffalo/buffalo-pop/v2/pop/popmw")
    82  	}
    83  
    84  	if strings.Contains(content, "contenttype.Add") || strings.Contains(content, "contenttype.Set") {
    85  		astutil.AddNamedImport(fset, f, "contenttype", "github.com/gobuffalo/mw-contenttype")
    86  	}
    87  
    88  	return ioutil.WriteFile(p, []byte(content), 0)
    89  }
    90  
    91  func (mw MiddlewareTransformer) rewriteMiddlewareUses(p string) error {
    92  	read, err := ioutil.ReadFile(p)
    93  	if err != nil {
    94  		return err
    95  	}
    96  
    97  	newContents := string(read)
    98  	newContents = strings.Replace(newContents, "middleware.SetContentType", "contenttype.Set", -1)
    99  	newContents = strings.Replace(newContents, "middleware.AddContentType", "contenttype.Add", -1)
   100  	newContents = strings.Replace(newContents, "middleware.ParameterLogger", "paramlogger.ParameterLogger", -1)
   101  	newContents = strings.Replace(newContents, "middleware.PopTransaction", "popmw.Transaction", -1)
   102  	newContents = strings.Replace(newContents, "ssl.ForceSSL", "forcessl.Middleware", -1)
   103  
   104  	err = ioutil.WriteFile(p, []byte(newContents), 0)
   105  	return err
   106  }
   107  
   108  func writeTempResult(name string, fset *token.FileSet, f *ast.File) (string, error) {
   109  	temp := name + ".temp"
   110  	w, err := os.Create(temp)
   111  	if err != nil {
   112  		return "", err
   113  	}
   114  
   115  	// write changes to .temp file, and include proper formatting.
   116  	err = (&printer.Config{Mode: printer.TabIndent | printer.UseSpaces, Tabwidth: 8}).Fprint(w, fset, f)
   117  	if err != nil {
   118  		return "", err
   119  	}
   120  
   121  	// close the writer
   122  	err = w.Close()
   123  	if err != nil {
   124  		return "", err
   125  	}
   126  
   127  	return temp, nil
   128  }
   129  
   130  func buildASTFor(p string) (*token.FileSet, *ast.File, error) {
   131  	fset := token.NewFileSet()
   132  	f, err := parser.ParseFile(fset, p, nil, parser.ParseComments)
   133  	return fset, f, err
   134  }
   135  
   136  //onlyRelevantFiles processes only .go files excluding folders like node_modules and vendor.
   137  func onlyRelevantFiles(p string, fi os.FileInfo, err error, fn func(p string) error) error {
   138  	if err != nil {
   139  		return err
   140  	}
   141  
   142  	if fi.IsDir() {
   143  		base := filepath.Base(p)
   144  		if strings.HasPrefix(base, "_") {
   145  			return filepath.SkipDir
   146  		}
   147  		for _, n := range []string{"vendor", "node_modules", ".git"} {
   148  			if base == n {
   149  				return filepath.SkipDir
   150  			}
   151  		}
   152  		return nil
   153  	}
   154  
   155  	ext := filepath.Ext(p)
   156  	if ext != ".go" {
   157  		return nil
   158  	}
   159  
   160  	return fn(p)
   161  }