github.com/segakazzz/buffalo@v0.16.22-0.20210119082501-1f52048d3feb/buffalo/cmd/fix/middleware.go (about) 1 package fix 2 3 import ( 4 "go/ast" 5 "go/parser" 6 "go/printer" 7 "go/token" 8 "io/ioutil" 9 "os" 10 "path/filepath" 11 "strings" 12 13 "golang.org/x/tools/go/ast/astutil" 14 ) 15 16 //MiddlewareTransformer moves from our old middleware package to new one 17 type MiddlewareTransformer struct { 18 PackagesReplacement map[string]string 19 Aliases map[string]string 20 } 21 22 func (mw MiddlewareTransformer) transformPackages(r *Runner) error { 23 return filepath.Walk(".", mw.processFile) 24 } 25 26 func (mw MiddlewareTransformer) processFile(p string, fi os.FileInfo, err error) error { 27 er := onlyRelevantFiles(p, fi, err, func(p string) error { 28 if err := mw.rewriteMiddlewareUses(p); err != nil { 29 return err 30 } 31 32 fset, f, err := buildASTFor(p) 33 if err != nil { 34 if e := err.Error(); strings.Contains(e, "expected 'package', found 'EOF'") { 35 return nil 36 } 37 38 return err 39 } 40 41 //Replacing mw packages 42 for old, new := range mw.PackagesReplacement { 43 deleted := astutil.DeleteImport(fset, f, old) 44 if deleted { 45 astutil.AddNamedImport(fset, f, mw.Aliases[new], new) 46 } 47 } 48 49 if err := mw.addMissingRootMiddlewareImports(fset, f, p); err != nil { 50 return err 51 } 52 53 ast.SortImports(fset, f) 54 55 temp, err := writeTempResult(p, fset, f) 56 if err != nil { 57 return err 58 } 59 60 // rename the .temp to .go 61 return os.Rename(temp, p) 62 }) 63 64 return er 65 } 66 67 func (mw MiddlewareTransformer) addMissingRootMiddlewareImports(fset *token.FileSet, f *ast.File, p string) error { 68 read, err := ioutil.ReadFile(p) 69 if err != nil { 70 return err 71 } 72 73 content := string(read) 74 75 astutil.DeleteImport(fset, f, "github.com/gobuffalo/buffalo/middleware") 76 if strings.Contains(content, "paramlogger.ParameterLogger") { 77 astutil.AddNamedImport(fset, f, "paramlogger", "github.com/gobuffalo/mw-paramlogger") 78 } 79 80 if 
strings.Contains(content, "popmw.Transaction") { 81 astutil.AddImport(fset, f, "github.com/gobuffalo/buffalo-pop/v2/pop/popmw") 82 } 83 84 if strings.Contains(content, "contenttype.Add") || strings.Contains(content, "contenttype.Set") { 85 astutil.AddNamedImport(fset, f, "contenttype", "github.com/gobuffalo/mw-contenttype") 86 } 87 88 return ioutil.WriteFile(p, []byte(content), 0) 89 } 90 91 func (mw MiddlewareTransformer) rewriteMiddlewareUses(p string) error { 92 read, err := ioutil.ReadFile(p) 93 if err != nil { 94 return err 95 } 96 97 newContents := string(read) 98 newContents = strings.Replace(newContents, "middleware.SetContentType", "contenttype.Set", -1) 99 newContents = strings.Replace(newContents, "middleware.AddContentType", "contenttype.Add", -1) 100 newContents = strings.Replace(newContents, "middleware.ParameterLogger", "paramlogger.ParameterLogger", -1) 101 newContents = strings.Replace(newContents, "middleware.PopTransaction", "popmw.Transaction", -1) 102 newContents = strings.Replace(newContents, "ssl.ForceSSL", "forcessl.Middleware", -1) 103 104 err = ioutil.WriteFile(p, []byte(newContents), 0) 105 return err 106 } 107 108 func writeTempResult(name string, fset *token.FileSet, f *ast.File) (string, error) { 109 temp := name + ".temp" 110 w, err := os.Create(temp) 111 if err != nil { 112 return "", err 113 } 114 115 // write changes to .temp file, and include proper formatting. 116 err = (&printer.Config{Mode: printer.TabIndent | printer.UseSpaces, Tabwidth: 8}).Fprint(w, fset, f) 117 if err != nil { 118 return "", err 119 } 120 121 // close the writer 122 err = w.Close() 123 if err != nil { 124 return "", err 125 } 126 127 return temp, nil 128 } 129 130 func buildASTFor(p string) (*token.FileSet, *ast.File, error) { 131 fset := token.NewFileSet() 132 f, err := parser.ParseFile(fset, p, nil, parser.ParseComments) 133 return fset, f, err 134 } 135 136 //onlyRelevantFiles processes only .go files excluding folders like node_modules and vendor. 
func onlyRelevantFiles(p string, fi os.FileInfo, err error, fn func(p string) error) error {
	// Hand any error filepath.Walk reported for p straight back to it.
	if err != nil {
		return err
	}

	if !fi.IsDir() {
		// Plain entries: only Go sources are passed on to fn.
		if filepath.Ext(p) == ".go" {
			return fn(p)
		}
		return nil
	}

	// Directories: prune underscore-prefixed, dependency and VCS folders so
	// nothing beneath them is visited; descend into everything else.
	switch base := filepath.Base(p); {
	case strings.HasPrefix(base, "_"),
		base == "vendor",
		base == "node_modules",
		base == ".git":
		return filepath.SkipDir
	default:
		return nil
	}
}