github.com/blend/go-sdk@v1.20220411.3/sourceutil/go_import_rewrite.go (about) 1 /* 2 3 Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file. 5 6 */ 7 8 package sourceutil 9 10 import ( 11 "context" 12 "go/ast" 13 "regexp" 14 "strings" 15 ) 16 17 // GoImportRewrite visits and optionally mutates imports for go files. 18 func GoImportRewrite(opts ...GoImportRewriteOption) GoImportVisitor { 19 var rewriteOpts GoImportRewriteOptions 20 for _, opt := range opts { 21 opt(&rewriteOpts) 22 } 23 return func(ctx context.Context, importSpec *ast.ImportSpec) error { 24 return rewriteOpts.Apply(ctx, importSpec) 25 } 26 } 27 28 // OptGoImportPathMatches returns a rewrite filter that returns if an import path matches a given expression. 29 func OptGoImportPathMatches(expr string) GoImportRewriteOption { 30 return func(opts *GoImportRewriteOptions) { 31 opts.Filter = func(ctx context.Context, importSpec *ast.ImportSpec) (bool, error) { 32 compiledExpr, err := regexp.Compile(expr) 33 if err != nil { 34 return false, err 35 } 36 return compiledExpr.MatchString(RemoveQuotes(importSpec.Path.Value)), nil 37 } 38 } 39 } 40 41 // OptGoImportNameMatches returns a rewrite filter that returns if an import name matches a given expression. 42 func OptGoImportNameMatches(expr string) GoImportRewriteOption { 43 return func(opts *GoImportRewriteOptions) { 44 opts.Filter = func(ctx context.Context, importSpec *ast.ImportSpec) (bool, error) { 45 compiledExpr, err := regexp.Compile(expr) 46 if err != nil { 47 return false, err 48 } 49 return compiledExpr.MatchString(importSpec.Name.Name), nil 50 } 51 } 52 } 53 54 // OptGoImportAddName adds a name if one is not already specified. 
func OptGoImportAddName(name string) GoImportRewriteOption {
	return func(opts *GoImportRewriteOptions) {
		opts.NameVisitor = func(ctx context.Context, nameNode *ast.Ident) error {
			// Only fill in a name when the import does not already carry one.
			if nameNode.Name == "" {
				nameNode.Name = name
			}
			return nil
		}
	}
}

// OptGoImportSetAlias sets the import alias to a given value.
//
// Setting to "" will remove the alias.
func OptGoImportSetAlias(name string) GoImportRewriteOption {
	return func(opts *GoImportRewriteOptions) {
		opts.NameVisitor = func(ctx context.Context, nameNode *ast.Ident) error {
			nameNode.Name = name
			return nil
		}
	}
}

// OptGoImportSetPath sets an import path to a given value.
//
// The path may be given with or without surrounding double quotes; the
// resulting literal is always wrapped in exactly one pair of double quotes.
// An empty path therefore yields `""` rather than a lone quote character.
func OptGoImportSetPath(path string) GoImportRewriteOption {
	return func(opts *GoImportRewriteOptions) {
		opts.PathVisitor = func(ctx context.Context, pathNode *ast.BasicLit) error {
			// Strip at most one quote from each end, then re-wrap. Checking the
			// prefix and suffix separately (as before) turned "" into `"`,
			// because the single added quote satisfied both checks.
			unquoted := strings.TrimSuffix(strings.TrimPrefix(path, "\""), "\"")
			pathNode.Value = "\"" + unquoted + "\""
			return nil
		}
	}
}

// OptGoImportPathRewrite returns a path filter and rewrite expression.
95 func OptGoImportPathRewrite(matchExpr, outputExpr string) GoImportRewriteOption { 96 return func(opts *GoImportRewriteOptions) { 97 compiledMatch, compileErr := regexp.Compile(matchExpr) 98 opts.Filter = func(ctx context.Context, importSpec *ast.ImportSpec) (output bool, err error) { 99 if compileErr != nil { 100 err = compileErr 101 return 102 } 103 importPath := RemoveQuotes(importSpec.Path.Value) 104 output = compiledMatch.MatchString(importPath) 105 return 106 } 107 opts.PathVisitor = func(ctx context.Context, path *ast.BasicLit) error { 108 output := []byte{} 109 for _, submatches := range compiledMatch.FindAllStringSubmatchIndex(RemoveQuotes(path.Value), -1) { 110 output = compiledMatch.ExpandString(output, outputExpr, RemoveQuotes(path.Value), submatches) 111 } 112 path.Value = string(output) 113 if !strings.HasPrefix(path.Value, "\"") { 114 path.Value = "\"" + path.Value 115 } 116 if !strings.HasSuffix(path.Value, "\"") { 117 path.Value = path.Value + "\"" 118 } 119 return nil 120 } 121 } 122 } 123 124 // GoImportVisitor mutates an ast import. 125 type GoImportVisitor func(context.Context, *ast.ImportSpec) error 126 127 // GoImportRewriteOption mutates the import rewrite options 128 type GoImportRewriteOption func(*GoImportRewriteOptions) 129 130 // GoImportRewriteOptions breaks the mutator out into field specific mutators. 131 type GoImportRewriteOptions struct { 132 Filter func(context.Context, *ast.ImportSpec) (bool, error) 133 CommentVisitor func(context.Context, *ast.CommentGroup) error 134 DocVisitor func(context.Context, *ast.CommentGroup) error 135 NameVisitor func(context.Context, *ast.Ident) error 136 PathVisitor func(context.Context, *ast.BasicLit) error 137 } 138 139 // Apply applies the options to the import. 
func (opts GoImportRewriteOptions) Apply(ctx context.Context, importSpec *ast.ImportSpec) error {
	if opts.Filter != nil {
		ok, err := opts.Filter(ctx, importSpec)
		if err != nil {
			return err
		}
		if !ok {
			return nil
		}
	}
	if opts.CommentVisitor != nil {
		if importSpec.Comment == nil {
			importSpec.Comment = &ast.CommentGroup{}
		}
		err := opts.CommentVisitor(ctx, importSpec.Comment)
		// An ast.CommentGroup must hold at least one comment (its Pos/End
		// index List[0]), so drop a placeholder the visitor left empty.
		// This runs before the error check so the AST stays well-formed on failure.
		if len(importSpec.Comment.List) == 0 {
			importSpec.Comment = nil
		}
		if err != nil {
			return err
		}
	}
	if opts.DocVisitor != nil {
		if importSpec.Doc == nil {
			importSpec.Doc = &ast.CommentGroup{}
		}
		err := opts.DocVisitor(ctx, importSpec.Doc)
		// Same normalization as for Comment: never leave an empty group behind.
		if len(importSpec.Doc.List) == 0 {
			importSpec.Doc = nil
		}
		if err != nil {
			return err
		}
	}
	if opts.NameVisitor != nil {
		if importSpec.Name == nil {
			importSpec.Name = &ast.Ident{}
		}
		err := opts.NameVisitor(ctx, importSpec.Name)
		// An empty identifier is not a valid import name, and go/printer
		// would emit a stray blank for it. Treat it as "no name", which is
		// what OptGoImportSetAlias("") promises.
		if importSpec.Name.Name == "" {
			importSpec.Name = nil
		}
		if err != nil {
			return err
		}
	}
	if opts.PathVisitor != nil {
		if err := opts.PathVisitor(ctx, importSpec.Path); err != nil {
			return err
		}
	}
	return nil
}