github.com/blend/go-sdk@v1.20220411.3/sourceutil/go_import_rewrite.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package sourceutil
     9  
    10  import (
    11  	"context"
    12  	"go/ast"
    13  	"regexp"
    14  	"strings"
    15  )
    16  
    17  // GoImportRewrite visits and optionally mutates imports for go files.
    18  func GoImportRewrite(opts ...GoImportRewriteOption) GoImportVisitor {
    19  	var rewriteOpts GoImportRewriteOptions
    20  	for _, opt := range opts {
    21  		opt(&rewriteOpts)
    22  	}
    23  	return func(ctx context.Context, importSpec *ast.ImportSpec) error {
    24  		return rewriteOpts.Apply(ctx, importSpec)
    25  	}
    26  }
    27  
    28  // OptGoImportPathMatches returns a rewrite filter that returns if an import path matches a given expression.
    29  func OptGoImportPathMatches(expr string) GoImportRewriteOption {
    30  	return func(opts *GoImportRewriteOptions) {
    31  		opts.Filter = func(ctx context.Context, importSpec *ast.ImportSpec) (bool, error) {
    32  			compiledExpr, err := regexp.Compile(expr)
    33  			if err != nil {
    34  				return false, err
    35  			}
    36  			return compiledExpr.MatchString(RemoveQuotes(importSpec.Path.Value)), nil
    37  		}
    38  	}
    39  }
    40  
    41  // OptGoImportNameMatches returns a rewrite filter that returns if an import name matches a given expression.
    42  func OptGoImportNameMatches(expr string) GoImportRewriteOption {
    43  	return func(opts *GoImportRewriteOptions) {
    44  		opts.Filter = func(ctx context.Context, importSpec *ast.ImportSpec) (bool, error) {
    45  			compiledExpr, err := regexp.Compile(expr)
    46  			if err != nil {
    47  				return false, err
    48  			}
    49  			return compiledExpr.MatchString(importSpec.Name.Name), nil
    50  		}
    51  	}
    52  }
    53  
    54  // OptGoImportAddName adds a name if one is not already specified.
    55  func OptGoImportAddName(name string) GoImportRewriteOption {
    56  	return func(opts *GoImportRewriteOptions) {
    57  		opts.NameVisitor = func(ctx context.Context, nameNode *ast.Ident) error {
    58  			if nameNode.Name == "" {
    59  				nameNode.Name = name
    60  			}
    61  			return nil
    62  		}
    63  	}
    64  }
    65  
    66  // OptGoImportSetAlias sets the import alias to a given value.
    67  //
    68  // Setting to "" will remove the alias.
    69  func OptGoImportSetAlias(name string) GoImportRewriteOption {
    70  	return func(opts *GoImportRewriteOptions) {
    71  		opts.NameVisitor = func(ctx context.Context, nameNode *ast.Ident) error {
    72  			nameNode.Name = name
    73  			return nil
    74  		}
    75  	}
    76  }
    77  
    78  // OptGoImportSetPath sets an import path to a given value.
    79  func OptGoImportSetPath(path string) GoImportRewriteOption {
    80  	return func(opts *GoImportRewriteOptions) {
    81  		opts.PathVisitor = func(ctx context.Context, pathNode *ast.BasicLit) error {
    82  			pathNode.Value = path
    83  			if !strings.HasPrefix(pathNode.Value, "\"") {
    84  				pathNode.Value = "\"" + pathNode.Value
    85  			}
    86  			if !strings.HasSuffix(pathNode.Value, "\"") {
    87  				pathNode.Value = pathNode.Value + "\""
    88  			}
    89  			return nil
    90  		}
    91  	}
    92  }
    93  
    94  // OptGoImportPathRewrite returns a path filter and rewrite expression.
    95  func OptGoImportPathRewrite(matchExpr, outputExpr string) GoImportRewriteOption {
    96  	return func(opts *GoImportRewriteOptions) {
    97  		compiledMatch, compileErr := regexp.Compile(matchExpr)
    98  		opts.Filter = func(ctx context.Context, importSpec *ast.ImportSpec) (output bool, err error) {
    99  			if compileErr != nil {
   100  				err = compileErr
   101  				return
   102  			}
   103  			importPath := RemoveQuotes(importSpec.Path.Value)
   104  			output = compiledMatch.MatchString(importPath)
   105  			return
   106  		}
   107  		opts.PathVisitor = func(ctx context.Context, path *ast.BasicLit) error {
   108  			output := []byte{}
   109  			for _, submatches := range compiledMatch.FindAllStringSubmatchIndex(RemoveQuotes(path.Value), -1) {
   110  				output = compiledMatch.ExpandString(output, outputExpr, RemoveQuotes(path.Value), submatches)
   111  			}
   112  			path.Value = string(output)
   113  			if !strings.HasPrefix(path.Value, "\"") {
   114  				path.Value = "\"" + path.Value
   115  			}
   116  			if !strings.HasSuffix(path.Value, "\"") {
   117  				path.Value = path.Value + "\""
   118  			}
   119  			return nil
   120  		}
   121  	}
   122  }
   123  
   124  // GoImportVisitor mutates an ast import.
   125  type GoImportVisitor func(context.Context, *ast.ImportSpec) error
   126  
   127  // GoImportRewriteOption mutates the import rewrite options
   128  type GoImportRewriteOption func(*GoImportRewriteOptions)
   129  
   130  // GoImportRewriteOptions breaks the mutator out into field specific mutators.
   131  type GoImportRewriteOptions struct {
   132  	Filter         func(context.Context, *ast.ImportSpec) (bool, error)
   133  	CommentVisitor func(context.Context, *ast.CommentGroup) error
   134  	DocVisitor     func(context.Context, *ast.CommentGroup) error
   135  	NameVisitor    func(context.Context, *ast.Ident) error
   136  	PathVisitor    func(context.Context, *ast.BasicLit) error
   137  }
   138  
   139  // Apply applies the options to the import.
   140  func (opts GoImportRewriteOptions) Apply(ctx context.Context, importSpec *ast.ImportSpec) error {
   141  	if opts.Filter != nil {
   142  		if ok, err := opts.Filter(ctx, importSpec); err != nil {
   143  			return err
   144  		} else if !ok {
   145  			return nil
   146  		}
   147  	}
   148  	if opts.CommentVisitor != nil {
   149  		if importSpec.Comment == nil {
   150  			importSpec.Comment = &ast.CommentGroup{}
   151  		}
   152  		if err := opts.CommentVisitor(ctx, importSpec.Comment); err != nil {
   153  			return err
   154  		}
   155  	}
   156  	if opts.DocVisitor != nil {
   157  		if importSpec.Doc == nil {
   158  			importSpec.Doc = &ast.CommentGroup{}
   159  		}
   160  		if err := opts.DocVisitor(ctx, importSpec.Doc); err != nil {
   161  			return err
   162  		}
   163  	}
   164  	if opts.NameVisitor != nil {
   165  		if importSpec.Name == nil {
   166  			importSpec.Name = &ast.Ident{}
   167  		}
   168  		if err := opts.NameVisitor(ctx, importSpec.Name); err != nil {
   169  			return err
   170  		}
   171  	}
   172  	if opts.PathVisitor != nil {
   173  		if err := opts.PathVisitor(ctx, importSpec.Path); err != nil {
   174  			return err
   175  		}
   176  	}
   177  	return nil
   178  }