github.com/tilt-dev/wat@v0.0.2-0.20180626175338-9349b638e250/cli/wat/train.go (about) 1 package wat 2 3 import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "fmt" 8 "io/ioutil" 9 "os" 10 "path/filepath" 11 "regexp" 12 "sort" 13 "strings" 14 "time" 15 16 pb "gopkg.in/cheggaaa/pb.v1" 17 18 isatty "github.com/mattn/go-isatty" 19 "github.com/spf13/cobra" 20 ) 21 22 const trainRecencyCutoff = time.Hour 23 const trainTTL = 48 * time.Hour 24 25 // Only fuzz files that match this suffix. 26 // TODO(nick): Will we need to make this configurable? 27 var fuzzSuffixes = []string{ 28 // TODO(nick): Right now, we add comments to the file that 29 // will only work in JS and Go. If we add other languages, we will 30 // need to make the fuzz step more configurable. 31 ".go", 32 ".js", 33 } 34 35 var matchFalse = regexp.MustCompile("\\bfalse\\b") 36 var matchZero = regexp.MustCompile("\\b0\\b") 37 38 var trainCmd = &cobra.Command{ 39 Use: "train", 40 Short: "Train a model to make decisions on what to test", 41 Run: train, 42 } 43 44 func train(cmd *cobra.Command, args []string) { 45 ctx, cancel := context.WithTimeout(context.Background(), CmdTimeout) 46 defer cancel() 47 48 ws, err := GetOrInitWatWorkspace() 49 if err != nil { 50 ws.Fatal("GetWatWorkspace", err) 51 } 52 53 cmds, err := populateAt(ctx, ws) 54 if err != nil { 55 ws.Fatal("List", err) 56 } 57 58 logs, err := Train(ctx, ws, cmds, 0 /* always fresh */) 59 if err != nil { 60 ws.Fatal("Train", err) 61 } 62 63 encoder := json.NewEncoder(os.Stdout) 64 encoder.SetIndent("", " ") 65 err = encoder.Encode(logs) 66 if err != nil { 67 ws.Fatal("Encode", err) 68 } 69 } 70 71 // Gets training data. 72 // 73 // If sufficiently fresh training data lives on disk, return that data. 74 // Otherwise, generate new training data and write it to disk. 75 func Train(ctx context.Context, ws WatWorkspace, cmds []WatCommand, ttl time.Duration) ([]CommandLogGroup, error) { 76 if ttl > 0 { 77 info, err := ws.Stat(fnameCmdLog) 78 if err != nil && !os.IsNotExist(err) { 79 return nil, err 80 } 81 82 // TODO(nick): This will do training if the user hasn't run wat for a while. 83 // It might make sense to be more aggressive about this, e.g., run training 84 // if the user hasn't explicitly trained for a while. 85 exists := err == nil 86 if exists && time.Since(info.ModTime()) < ttl { 87 logs, err := ReadCmdLogGroups(ws) 88 if err != nil { 89 return nil, err 90 } 91 return logs, nil 92 } 93 } 94 95 result, err := trainAt(ctx, ws, cmds) 96 if err != nil { 97 return nil, err 98 } 99 100 err = CmdLogGroupsToFile(ws, result) 101 if err != nil { 102 return nil, err 103 } 104 return result, nil 105 } 106 107 type LogSource int 108 109 const ( 110 _ = iota 111 112 // An edit made by the user 113 LogSourceUser LogSource = iota 114 115 // An made-up command-log used to bootstrap training, 116 // so that we have interesting data to work with before the 117 // user runs any commands. 118 LogSourceBootstrap 119 120 // An edit automatically generated by a fuzzer 121 LogSourceFuzz 122 123 // Logs generated when the trainer runs the commands 124 // in the workspace for the first time. 125 LogSourceTrainInit 126 ) 127 128 // All the commands that ran at a particular state of the workspace, grouped together. 129 type CommandLogGroup struct { 130 Logs []CommandLog 131 Context LogContext 132 } 133 134 func newCommandLogGroup(ctx LogContext) *CommandLogGroup { 135 return &CommandLogGroup{Context: ctx} 136 } 137 138 func (g *CommandLogGroup) Add(l CommandLog) { 139 g.Logs = append(g.Logs, l) 140 } 141 142 type LogContext struct { 143 // watRoot-relative paths of files that have been recently edited. 144 // The definition of "recent" is deliberately fuzzy and might change. 145 RecentEdits []string 146 147 StartTime time.Time 148 Source LogSource 149 } 150 151 type CommandLog struct { 152 // The Command field of WatCommand 153 Command string 154 155 Success bool 156 Duration time.Duration 157 } 158 159 func trainAt(ctx context.Context, ws WatWorkspace, cmds []WatCommand) ([]CommandLogGroup, error) { 160 if isatty.IsTerminal(os.Stdout.Fd()) { 161 fmt.Fprintln(os.Stderr, "Beginning training...type <Enter> or <Esc> to interrupt") 162 163 var cancel func() 164 ctx, cancel = context.WithCancel(ctx) 165 defer cancel() 166 167 go func() { 168 waitOnInterruptChar(ctx, []rune{AsciiEnter, AsciiLineFeed, AsciiEsc}) 169 ws.a.Incr(statTrainingInterrupted, nil) 170 cancel() 171 }() 172 } 173 174 files, err := ws.WalkRoot() 175 if err != nil { 176 return nil, err 177 } 178 sort.Sort(sort.Reverse(fileInfos(files))) 179 180 result := make([]CommandLogGroup, 0, len(cmds)) 181 182 // Run all commands in the current workspace. 183 recentEdit := "" 184 if len(files) > 0 && time.Since(files[0].modTime) < trainRecencyCutoff { 185 recentEdit = files[0].name 186 } 187 g, err := runInitGroup(ctx, cmds, ws.Root(), recentEdit) 188 if err != nil { 189 return nil, err 190 } 191 if len(g.Logs) != 0 { 192 result = append(result, g) 193 } 194 195 // Fuzz each file and run all commands. This may take a long time. We expect 196 // the user to cancel or time to run out before we finish, so we fuzz the files 197 // in order of recent edits, and handle timeout/cancel gracefully. 198 for _, f := range files { 199 if ctx.Err() != nil { 200 break 201 } 202 203 if !shouldFuzzFile(f.name) { 204 continue 205 } 206 207 g, err := fuzzAndRun(ctx, cmds, ws.Root(), f.name) 208 if err != nil { 209 return nil, err 210 } 211 212 if len(g.Logs) != 0 { 213 result = append(result, g) 214 } 215 } 216 217 return result, nil 218 } 219 220 // Create an "init" group that runs all the commands in the current workspace. 221 func runInitGroup(ctx context.Context, cmds []WatCommand, root string, recentEdit string) (CommandLogGroup, error) { 222 fmt.Fprintln(os.Stderr, "Running all tests in the current workspace") 223 return runCmdsWithProgress(ctx, cmds, root, LogContext{ 224 StartTime: time.Now(), 225 Source: LogSourceTrainInit, 226 RecentEdits: []string{recentEdit}, 227 }) 228 } 229 230 func runCmdsWithProgress(ctx context.Context, cmds []WatCommand, root string, logCtx LogContext) (CommandLogGroup, error) { 231 g := CommandLogGroup{ 232 Context: logCtx, 233 } 234 bar := pb.New(len(cmds)) 235 bar.Output = os.Stderr 236 bar.Start() 237 defer bar.FinishPrint("") 238 239 for i, cmd := range cmds { 240 l, err := runCmdAndLog(ctx, root, cmd, ioutil.Discard, ioutil.Discard) 241 if err != nil { 242 if err == context.DeadlineExceeded || err == context.Canceled { 243 break 244 } 245 return CommandLogGroup{}, err 246 } 247 g.Logs = append(g.Logs, l) 248 bar.Set(i + 1) 249 } 250 251 return g, nil 252 } 253 254 func shouldFuzzFile(fileToFuzz string) bool { 255 for _, suffix := range fuzzSuffixes { 256 if strings.HasSuffix(fileToFuzz, suffix) { 257 return true 258 } 259 } 260 return false 261 } 262 263 // A dumb mutation: replace false with true and 0 with 1. 264 func fuzz(contents []byte) []byte { 265 contents = matchFalse.ReplaceAll(contents, []byte("true")) 266 contents = matchZero.ReplaceAll(contents, []byte("1")) 267 return contents 268 } 269 270 // Make a random edit to a file and run all tests in the workspace. 271 func fuzzAndRun(ctx context.Context, cmds []WatCommand, root, fileToFuzz string) (CommandLogGroup, error) { 272 absPath := filepath.Join(root, fileToFuzz) 273 oldContents, err := ioutil.ReadFile(absPath) 274 if err != nil { 275 return CommandLogGroup{}, err 276 } 277 278 newContents := fuzz(oldContents) 279 if bytes.Equal(newContents, oldContents) { 280 // if fuzzing does nothing, don't bother. 281 return CommandLogGroup{}, nil 282 } 283 284 // TODO(nick): right now this only works in JS and Go 285 newContents = append(newContents, 286 []byte("\n// Modified by WAT fuzzer (https://github.com/windmilleng/wat)")...) 287 288 // We know the file exists, so we expect that this file mode will be ignored 289 mode := permFile 290 291 // It's super important that we clean up the file, even if the user 292 // tries to kill the process. 293 tearDown := createCleanup(func() { 294 ioutil.WriteFile(absPath, oldContents, mode) 295 }) 296 defer tearDown() 297 298 err = ioutil.WriteFile(absPath, newContents, mode) 299 if err != nil { 300 return CommandLogGroup{}, err 301 } 302 303 _, _ = fmt.Fprintf(os.Stderr, "Fuzzing %q and running all tests\n", fileToFuzz) 304 return runCmdsWithProgress(ctx, cmds, root, LogContext{ 305 StartTime: time.Now(), 306 Source: LogSourceFuzz, 307 RecentEdits: []string{fileToFuzz}, 308 }) 309 }