github.com/liquid-dev/text@v0.3.3-liquid/message/pipeline/rewrite.go (about)

     1  // Copyright 2017 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package pipeline
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"go/ast"
    11  	"go/constant"
    12  	"go/format"
    13  	"go/token"
    14  	"io"
    15  	"os"
    16  	"strings"
    17  
    18  	"github.com/liquid-dev/tools/go/loader"
    19  )
    20  
    21  const printerType = "github.com/liquid-dev/text/message.Printer"
    22  
    23  // Rewrite rewrites the Go files in a single package to use the localization
    24  // machinery and rewrites strings to adopt best practices when possible.
    25  // If w is not nil the generated files are written to it, each files with a
    26  // "--- <filename>" header. Otherwise the files are overwritten.
    27  func Rewrite(w io.Writer, args ...string) error {
    28  	conf := &loader.Config{
    29  		AllowErrors: true, // Allow unused instances of message.Printer.
    30  	}
    31  	prog, err := loadPackages(conf, args)
    32  	if err != nil {
    33  		return wrap(err, "")
    34  	}
    35  
    36  	for _, info := range prog.InitialPackages() {
    37  		for _, f := range info.Files {
    38  			// Associate comments with nodes.
    39  
    40  			// Pick up initialized Printers at the package level.
    41  			r := rewriter{info: info, conf: conf}
    42  			for _, n := range info.InitOrder {
    43  				if t := r.info.Types[n.Rhs].Type.String(); strings.HasSuffix(t, printerType) {
    44  					r.printerVar = n.Lhs[0].Name()
    45  				}
    46  			}
    47  
    48  			ast.Walk(&r, f)
    49  
    50  			w := w
    51  			if w == nil {
    52  				var err error
    53  				if w, err = os.Create(conf.Fset.File(f.Pos()).Name()); err != nil {
    54  					return wrap(err, "open failed")
    55  				}
    56  			} else {
    57  				fmt.Fprintln(w, "---", conf.Fset.File(f.Pos()).Name())
    58  			}
    59  
    60  			if err := format.Node(w, conf.Fset, f); err != nil {
    61  				return wrap(err, "go format failed")
    62  			}
    63  		}
    64  	}
    65  
    66  	return nil
    67  }
    68  
    69  type rewriter struct {
    70  	info       *loader.PackageInfo
    71  	conf       *loader.Config
    72  	printerVar string
    73  }
    74  
    75  // print returns Go syntax for the specified node.
    76  func (r *rewriter) print(n ast.Node) string {
    77  	var buf bytes.Buffer
    78  	format.Node(&buf, r.conf.Fset, n)
    79  	return buf.String()
    80  }
    81  
    82  func (r *rewriter) Visit(n ast.Node) ast.Visitor {
    83  	// Save the state by scope.
    84  	if _, ok := n.(*ast.BlockStmt); ok {
    85  		r := *r
    86  		return &r
    87  	}
    88  	// Find Printers created by assignment.
    89  	stmt, ok := n.(*ast.AssignStmt)
    90  	if ok {
    91  		for _, v := range stmt.Lhs {
    92  			if r.printerVar == r.print(v) {
    93  				r.printerVar = ""
    94  			}
    95  		}
    96  		for i, v := range stmt.Rhs {
    97  			if t := r.info.Types[v].Type.String(); strings.HasSuffix(t, printerType) {
    98  				r.printerVar = r.print(stmt.Lhs[i])
    99  				return r
   100  			}
   101  		}
   102  	}
   103  	// Find Printers created by variable declaration.
   104  	spec, ok := n.(*ast.ValueSpec)
   105  	if ok {
   106  		for _, v := range spec.Names {
   107  			if r.printerVar == r.print(v) {
   108  				r.printerVar = ""
   109  			}
   110  		}
   111  		for i, v := range spec.Values {
   112  			if t := r.info.Types[v].Type.String(); strings.HasSuffix(t, printerType) {
   113  				r.printerVar = r.print(spec.Names[i])
   114  				return r
   115  			}
   116  		}
   117  	}
   118  	if r.printerVar == "" {
   119  		return r
   120  	}
   121  	call, ok := n.(*ast.CallExpr)
   122  	if !ok {
   123  		return r
   124  	}
   125  
   126  	// TODO: Handle literal values?
   127  	sel, ok := call.Fun.(*ast.SelectorExpr)
   128  	if !ok {
   129  		return r
   130  	}
   131  	meth := r.info.Selections[sel]
   132  
   133  	source := r.print(sel.X)
   134  	fun := r.print(sel.Sel)
   135  	if meth != nil {
   136  		source = meth.Recv().String()
   137  		fun = meth.Obj().Name()
   138  	}
   139  
   140  	// TODO: remove cheap hack and check if the type either
   141  	// implements some interface or is specifically of type
   142  	// "github.com/liquid-dev/text/message".Printer.
   143  	m, ok := rewriteFuncs[source]
   144  	if !ok {
   145  		return r
   146  	}
   147  
   148  	rewriteType, ok := m[fun]
   149  	if !ok {
   150  		return r
   151  	}
   152  	ident := ast.NewIdent(r.printerVar)
   153  	ident.NamePos = sel.X.Pos()
   154  	sel.X = ident
   155  	if rewriteType.method != "" {
   156  		sel.Sel.Name = rewriteType.method
   157  	}
   158  
   159  	// Analyze arguments.
   160  	argn := rewriteType.arg
   161  	if rewriteType.format || argn >= len(call.Args) {
   162  		return r
   163  	}
   164  	hasConst := false
   165  	for _, a := range call.Args[argn:] {
   166  		if v := r.info.Types[a].Value; v != nil && v.Kind() == constant.String {
   167  			hasConst = true
   168  			break
   169  		}
   170  	}
   171  	if !hasConst {
   172  		return r
   173  	}
   174  	sel.Sel.Name = rewriteType.methodf
   175  
   176  	// We are done if there is only a single string that does not need to be
   177  	// escaped.
   178  	if len(call.Args) == 1 {
   179  		s, ok := constStr(r.info, call.Args[0])
   180  		if ok && !strings.Contains(s, "%") && !rewriteType.newLine {
   181  			return r
   182  		}
   183  	}
   184  
   185  	// Rewrite arguments as format string.
   186  	expr := &ast.BasicLit{
   187  		ValuePos: call.Lparen,
   188  		Kind:     token.STRING,
   189  	}
   190  	newArgs := append(call.Args[:argn:argn], expr)
   191  	newStr := []string{}
   192  	for i, a := range call.Args[argn:] {
   193  		if s, ok := constStr(r.info, a); ok {
   194  			newStr = append(newStr, strings.Replace(s, "%", "%%", -1))
   195  		} else {
   196  			newStr = append(newStr, "%v")
   197  			newArgs = append(newArgs, call.Args[argn+i])
   198  		}
   199  	}
   200  	s := strings.Join(newStr, rewriteType.sep)
   201  	if rewriteType.newLine {
   202  		s += "\n"
   203  	}
   204  	expr.Value = fmt.Sprintf("%q", s)
   205  
   206  	call.Args = newArgs
   207  
   208  	// TODO: consider creating an expression instead of a constant string and
   209  	// then wrapping it in an escape function or so:
   210  	// call.Args[argn+i] = &ast.CallExpr{
   211  	// 		Fun: &ast.SelectorExpr{
   212  	// 			X:   ast.NewIdent("message"),
   213  	// 			Sel: ast.NewIdent("Lookup"),
   214  	// 		},
   215  	// 		Args: []ast.Expr{a},
   216  	// 	}
   217  	// }
   218  
   219  	return r
   220  }
   221  
   222  type rewriteType struct {
   223  	// method is the name of the equivalent method on a printer, or "" if it is
   224  	// the same.
   225  	method string
   226  
   227  	// methodf is the method to use if the arguments can be rewritten as a
   228  	// arguments to a printf-style call.
   229  	methodf string
   230  
   231  	// format is true if the method takes a formatting string followed by
   232  	// substitution arguments.
   233  	format bool
   234  
   235  	// arg indicates the position of the argument to extract. If all is
   236  	// positive, all arguments from this argument onwards needs to be extracted.
   237  	arg int
   238  
   239  	sep     string
   240  	newLine bool
   241  }
   242  
   243  // rewriteFuncs list functions that can be directly mapped to the printer
   244  // functions of the message package.
   245  var rewriteFuncs = map[string]map[string]rewriteType{
   246  	// TODO: Printer -> *github.com/liquid-dev/text/message.Printer
   247  	"fmt": {
   248  		"Print":  rewriteType{methodf: "Printf"},
   249  		"Sprint": rewriteType{methodf: "Sprintf"},
   250  		"Fprint": rewriteType{methodf: "Fprintf"},
   251  
   252  		"Println":  rewriteType{methodf: "Printf", sep: " ", newLine: true},
   253  		"Sprintln": rewriteType{methodf: "Sprintf", sep: " ", newLine: true},
   254  		"Fprintln": rewriteType{methodf: "Fprintf", sep: " ", newLine: true},
   255  
   256  		"Printf":  rewriteType{method: "Printf", format: true},
   257  		"Sprintf": rewriteType{method: "Sprintf", format: true},
   258  		"Fprintf": rewriteType{method: "Fprintf", format: true},
   259  	},
   260  }
   261  
   262  func constStr(info *loader.PackageInfo, e ast.Expr) (s string, ok bool) {
   263  	v := info.Types[e].Value
   264  	if v == nil || v.Kind() != constant.String {
   265  		return "", false
   266  	}
   267  	return constant.StringVal(v), true
   268  }