github.com/windmilleng/wat@v0.0.2-0.20180626175338-9349b638e250/cli/wat/train.go (about)

     1  package wat
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"fmt"
     8  	"io/ioutil"
     9  	"os"
    10  	"path/filepath"
    11  	"regexp"
    12  	"sort"
    13  	"strings"
    14  	"time"
    15  
    16  	pb "gopkg.in/cheggaaa/pb.v1"
    17  
    18  	isatty "github.com/mattn/go-isatty"
    19  	"github.com/spf13/cobra"
    20  )
    21  
    22  const trainRecencyCutoff = time.Hour
    23  const trainTTL = 48 * time.Hour
    24  
    25  // Only fuzz files that match this suffix.
    26  // TODO(nick): Will we need to make this configurable?
    27  var fuzzSuffixes = []string{
    28  	// TODO(nick): Right now, we add comments to the file that
    29  	// will only work in JS and Go. If we add other languages, we will
    30  	// need to make the fuzz step more configurable.
    31  	".go",
    32  	".js",
    33  }
    34  
    35  var matchFalse = regexp.MustCompile("\\bfalse\\b")
    36  var matchZero = regexp.MustCompile("\\b0\\b")
    37  
// trainCmd is the cobra command definition for the "train" subcommand.
// Running it invokes the train function below.
var trainCmd = &cobra.Command{
	Use:   "train",
	Short: "Train a model to make decisions on what to test",
	Run:   train,
}
    43  
    44  func train(cmd *cobra.Command, args []string) {
    45  	ctx, cancel := context.WithTimeout(context.Background(), CmdTimeout)
    46  	defer cancel()
    47  
    48  	ws, err := GetOrInitWatWorkspace()
    49  	if err != nil {
    50  		ws.Fatal("GetWatWorkspace", err)
    51  	}
    52  
    53  	cmds, err := populateAt(ctx, ws)
    54  	if err != nil {
    55  		ws.Fatal("List", err)
    56  	}
    57  
    58  	logs, err := Train(ctx, ws, cmds, 0 /* always fresh */)
    59  	if err != nil {
    60  		ws.Fatal("Train", err)
    61  	}
    62  
    63  	encoder := json.NewEncoder(os.Stdout)
    64  	encoder.SetIndent("", "  ")
    65  	err = encoder.Encode(logs)
    66  	if err != nil {
    67  		ws.Fatal("Encode", err)
    68  	}
    69  }
    70  
    71  // Gets training data.
    72  //
    73  // If sufficiently fresh training data lives on disk, return that data.
    74  // Otherwise, generate new training data and write it to disk.
    75  func Train(ctx context.Context, ws WatWorkspace, cmds []WatCommand, ttl time.Duration) ([]CommandLogGroup, error) {
    76  	if ttl > 0 {
    77  		info, err := ws.Stat(fnameCmdLog)
    78  		if err != nil && !os.IsNotExist(err) {
    79  			return nil, err
    80  		}
    81  
    82  		// TODO(nick): This will do training if the user hasn't run wat for a while.
    83  		// It might make sense to be more aggressive about this, e.g., run training
    84  		// if the user hasn't explicitly trained for a while.
    85  		exists := err == nil
    86  		if exists && time.Since(info.ModTime()) < ttl {
    87  			logs, err := ReadCmdLogGroups(ws)
    88  			if err != nil {
    89  				return nil, err
    90  			}
    91  			return logs, nil
    92  		}
    93  	}
    94  
    95  	result, err := trainAt(ctx, ws, cmds)
    96  	if err != nil {
    97  		return nil, err
    98  	}
    99  
   100  	err = CmdLogGroupsToFile(ws, result)
   101  	if err != nil {
   102  		return nil, err
   103  	}
   104  	return result, nil
   105  }
   106  
   107  type LogSource int
   108  
   109  const (
   110  	_ = iota
   111  
   112  	// An edit made by the user
   113  	LogSourceUser LogSource = iota
   114  
   115  	// An made-up command-log used to bootstrap training,
   116  	// so that we have interesting data to work with before the
   117  	// user runs any commands.
   118  	LogSourceBootstrap
   119  
   120  	// An edit automatically generated by a fuzzer
   121  	LogSourceFuzz
   122  
   123  	// Logs generated when the trainer runs the commands
   124  	// in the workspace for the first time.
   125  	LogSourceTrainInit
   126  )
   127  
   128  // All the commands that ran at a particular state of the workspace, grouped together.
   129  type CommandLogGroup struct {
   130  	Logs    []CommandLog
   131  	Context LogContext
   132  }
   133  
   134  func newCommandLogGroup(ctx LogContext) *CommandLogGroup {
   135  	return &CommandLogGroup{Context: ctx}
   136  }
   137  
   138  func (g *CommandLogGroup) Add(l CommandLog) {
   139  	g.Logs = append(g.Logs, l)
   140  }
   141  
// LogContext describes the circumstances under which a group of commands
// was run: which files had been edited, when the run started, and what
// triggered it.
type LogContext struct {
	// watRoot-relative paths of files that have been recently edited.
	// The definition of "recent" is deliberately fuzzy and might change.
	RecentEdits []string

	// StartTime is when the group's run began (the trainer sets it to
	// time.Now() just before running the commands); Source says what
	// produced the group (see LogSource).
	StartTime time.Time
	Source    LogSource
}
   150  
// CommandLog records the outcome of running a single command.
type CommandLog struct {
	// The Command field of WatCommand
	Command string

	// Whether the command succeeded, and how long it ran.
	// NOTE(review): how success is determined is not visible in this
	// file — see whatever populates these in runCmdAndLog.
	Success  bool
	Duration time.Duration
}
   158  
   159  func trainAt(ctx context.Context, ws WatWorkspace, cmds []WatCommand) ([]CommandLogGroup, error) {
   160  	if isatty.IsTerminal(os.Stdout.Fd()) {
   161  		fmt.Fprintln(os.Stderr, "Beginning training...type <Enter> or <Esc> to interrupt")
   162  
   163  		var cancel func()
   164  		ctx, cancel = context.WithCancel(ctx)
   165  		defer cancel()
   166  
   167  		go func() {
   168  			waitOnInterruptChar(ctx, []rune{AsciiEnter, AsciiLineFeed, AsciiEsc})
   169  			ws.a.Incr(statTrainingInterrupted, nil)
   170  			cancel()
   171  		}()
   172  	}
   173  
   174  	files, err := ws.WalkRoot()
   175  	if err != nil {
   176  		return nil, err
   177  	}
   178  	sort.Sort(sort.Reverse(fileInfos(files)))
   179  
   180  	result := make([]CommandLogGroup, 0, len(cmds))
   181  
   182  	// Run all commands in the current workspace.
   183  	recentEdit := ""
   184  	if len(files) > 0 && time.Since(files[0].modTime) < trainRecencyCutoff {
   185  		recentEdit = files[0].name
   186  	}
   187  	g, err := runInitGroup(ctx, cmds, ws.Root(), recentEdit)
   188  	if err != nil {
   189  		return nil, err
   190  	}
   191  	if len(g.Logs) != 0 {
   192  		result = append(result, g)
   193  	}
   194  
   195  	// Fuzz each file and run all commands. This may take a long time. We expect
   196  	// the user to cancel or time to run out before we finish, so we fuzz the files
   197  	// in order of recent edits, and handle timeout/cancel gracefully.
   198  	for _, f := range files {
   199  		if ctx.Err() != nil {
   200  			break
   201  		}
   202  
   203  		if !shouldFuzzFile(f.name) {
   204  			continue
   205  		}
   206  
   207  		g, err := fuzzAndRun(ctx, cmds, ws.Root(), f.name)
   208  		if err != nil {
   209  			return nil, err
   210  		}
   211  
   212  		if len(g.Logs) != 0 {
   213  			result = append(result, g)
   214  		}
   215  	}
   216  
   217  	return result, nil
   218  }
   219  
   220  // Create an "init" group that runs all the commands in the current workspace.
   221  func runInitGroup(ctx context.Context, cmds []WatCommand, root string, recentEdit string) (CommandLogGroup, error) {
   222  	fmt.Fprintln(os.Stderr, "Running all tests in the current workspace")
   223  	return runCmdsWithProgress(ctx, cmds, root, LogContext{
   224  		StartTime:   time.Now(),
   225  		Source:      LogSourceTrainInit,
   226  		RecentEdits: []string{recentEdit},
   227  	})
   228  }
   229  
   230  func runCmdsWithProgress(ctx context.Context, cmds []WatCommand, root string, logCtx LogContext) (CommandLogGroup, error) {
   231  	g := CommandLogGroup{
   232  		Context: logCtx,
   233  	}
   234  	bar := pb.New(len(cmds))
   235  	bar.Output = os.Stderr
   236  	bar.Start()
   237  	defer bar.FinishPrint("")
   238  
   239  	for i, cmd := range cmds {
   240  		l, err := runCmdAndLog(ctx, root, cmd, ioutil.Discard, ioutil.Discard)
   241  		if err != nil {
   242  			if err == context.DeadlineExceeded || err == context.Canceled {
   243  				break
   244  			}
   245  			return CommandLogGroup{}, err
   246  		}
   247  		g.Logs = append(g.Logs, l)
   248  		bar.Set(i + 1)
   249  	}
   250  
   251  	return g, nil
   252  }
   253  
   254  func shouldFuzzFile(fileToFuzz string) bool {
   255  	for _, suffix := range fuzzSuffixes {
   256  		if strings.HasSuffix(fileToFuzz, suffix) {
   257  			return true
   258  		}
   259  	}
   260  	return false
   261  }
   262  
   263  // A dumb mutation: replace false with true and 0 with 1.
   264  func fuzz(contents []byte) []byte {
   265  	contents = matchFalse.ReplaceAll(contents, []byte("true"))
   266  	contents = matchZero.ReplaceAll(contents, []byte("1"))
   267  	return contents
   268  }
   269  
   270  // Make a random edit to a file and run all tests in the workspace.
   271  func fuzzAndRun(ctx context.Context, cmds []WatCommand, root, fileToFuzz string) (CommandLogGroup, error) {
   272  	absPath := filepath.Join(root, fileToFuzz)
   273  	oldContents, err := ioutil.ReadFile(absPath)
   274  	if err != nil {
   275  		return CommandLogGroup{}, err
   276  	}
   277  
   278  	newContents := fuzz(oldContents)
   279  	if bytes.Equal(newContents, oldContents) {
   280  		// if fuzzing does nothing, don't bother.
   281  		return CommandLogGroup{}, nil
   282  	}
   283  
   284  	// TODO(nick): right now this only works in JS and Go
   285  	newContents = append(newContents,
   286  		[]byte("\n// Modified by WAT fuzzer (https://github.com/windmilleng/wat)")...)
   287  
   288  	// We know the file exists, so we expect that this file mode will be ignored
   289  	mode := permFile
   290  
   291  	// It's super important that we clean up the file, even if the user
   292  	// tries to kill the process.
   293  	tearDown := createCleanup(func() {
   294  		ioutil.WriteFile(absPath, oldContents, mode)
   295  	})
   296  	defer tearDown()
   297  
   298  	err = ioutil.WriteFile(absPath, newContents, mode)
   299  	if err != nil {
   300  		return CommandLogGroup{}, err
   301  	}
   302  
   303  	_, _ = fmt.Fprintf(os.Stderr, "Fuzzing %q and running all tests\n", fileToFuzz)
   304  	return runCmdsWithProgress(ctx, cmds, root, LogContext{
   305  		StartTime:   time.Now(),
   306  		Source:      LogSourceFuzz,
   307  		RecentEdits: []string{fileToFuzz},
   308  	})
   309  }