github.com/tenntenn/testtime@v0.2.3-0.20221118081726-55bcd1f05226/cmd/testtime/overlay.go (about)

     1  package main
     2  
     3  import (
     4  	"bytes"
     5  	_ "embed"
     6  	"encoding/json"
     7  	"errors"
     8  	"fmt"
     9  	"go/ast"
    10  	"go/build"
    11  	"go/format"
    12  	"go/parser"
    13  	"go/token"
    14  	"io"
    15  	"os"
    16  	"path/filepath"
    17  	"runtime"
    18  
    19  	"golang.org/x/tools/go/ast/astutil"
    20  )
    21  
// testtime is the Go source of _partials/testtime.go, embedded at build time.
// createOverlay appends it verbatim to the rewritten time package file, in
// which replaceTimeNow has renamed the original Now to _Now. Presumably the
// partial declares the replacement Now (and uses the "sync"/"runtime" imports
// replaceTimeNow adds) — confirm against _partials/testtime.go.
//
//go:embed _partials/testtime.go
var testtime string
    24  
    25  func createOverlay(update bool, dir string) (string, error) {
    26  
    27  	ver := build.Default.ReleaseTags[len(build.Default.ReleaseTags)-1]
    28  
    29  	overlay := filepath.Join(dir, fmt.Sprintf("overlay_%s.json", ver))
    30  	_, err := os.Stat(overlay)
    31  	switch {
    32  	case err == nil:
    33  		if !update {
    34  			return overlay, nil
    35  		}
    36  	case !errors.Is(err, os.ErrNotExist):
    37  		return "", err
    38  	}
    39  
    40  	if err := os.MkdirAll(dir, 0o700); err != nil {
    41  		return "", err
    42  	}
    43  
    44  	var buf bytes.Buffer
    45  	old, err := replaceTimeNow(&buf)
    46  	if err != nil {
    47  		return "", err
    48  	}
    49  
    50  	fmt.Fprint(&buf, testtime)
    51  
    52  	src, err := format.Source(buf.Bytes())
    53  	if err != nil {
    54  		return "", err
    55  	}
    56  
    57  	new := filepath.Join(dir, fmt.Sprintf("time_%s.go", ver))
    58  	if err := os.WriteFile(new, src, 0o600); err != nil {
    59  		return "", err
    60  	}
    61  
    62  	v := struct {
    63  		Replace map[string]string
    64  	}{map[string]string{old: new}}
    65  	jsonBytes, err := json.Marshal(v)
    66  	if err != nil {
    67  		return "", err
    68  	}
    69  	if err := os.WriteFile(overlay, jsonBytes, 0o600); err != nil {
    70  		return "", err
    71  	}
    72  
    73  	return overlay, nil
    74  }
    75  
    76  func replaceTimeNow(w io.Writer) (string, error) {
    77  	srcDir := filepath.Join(runtime.GOROOT(), "src")
    78  	pkg, err := build.Default.Import("time", srcDir, 0)
    79  	if err != nil {
    80  		return "", err
    81  	}
    82  
    83  	fset := token.NewFileSet()
    84  	pkgs, err := parser.ParseDir(fset, pkg.Dir, nil, parser.ParseComments)
    85  	if err != nil {
    86  		return "", err
    87  	}
    88  
    89  	if pkgs["time"] == nil {
    90  		return "", errors.New("cannot find time package")
    91  	}
    92  
    93  	var (
    94  		path   string
    95  		syntax *ast.File
    96  	)
    97  LOOP:
    98  	for name, file := range pkgs["time"].Files {
    99  		for _, decl := range file.Decls {
   100  			decl, _ := decl.(*ast.FuncDecl)
   101  			if decl == nil {
   102  				continue
   103  			}
   104  
   105  			if decl.Name.Name == "Now" {
   106  				decl.Name.Name = "_Now"
   107  				path = name
   108  				syntax = file
   109  				break LOOP
   110  			}
   111  		}
   112  	}
   113  
   114  	if path == "" || syntax == nil {
   115  		return "", errors.New("cannot find time.Now")
   116  	}
   117  
   118  	astutil.AddImport(fset, syntax, "sync")
   119  	astutil.AddImport(fset, syntax, "runtime")
   120  
   121  	if err := format.Node(w, fset, syntax); err != nil {
   122  		return "", err
   123  	}
   124  
   125  	return path, nil
   126  }