gorgonia.org/gorgonia@v0.9.17/vm.go (about) 1 package gorgonia 2 3 import ( 4 "bytes" 5 "log" 6 "os" 7 8 "gorgonia.org/tensor" 9 ) 10 11 // VM represents a structure that can execute a graph or program. There are two VMs (both unexported): 12 // - *tapeMachine 13 // - *lispMachine 14 // 15 // The *tapeMachine pre-compiles a graph into a list of instructions, then executes the instructions linearly and sequentially. 16 // The main tradeoff is dynamism. Graphs cannot be dynamically created on the fly as a re-compilation process is required 17 // (and compilation is relatively expensive). However, graphs executed with the *tapeMachine run much faster as plenty of optimizations 18 // has been done in the code generation stage. 19 // 20 // The *lispMachine allows for graphs to be dynamically built and executed upon. The tradeoff is that executing a graph on *lispMachine 21 // is generally slower than on *tapeMachine, given the same static "image" of a graph. 22 type VM interface { 23 RunAll() error 24 Reset() 25 26 // Close closes all the machine resources (CUDA, if any, loggers if any) 27 Close() error 28 } 29 30 const ( 31 fwdOnly byte = iota 32 bwdOnly 33 watchNaN 34 watchInf 35 allocVals 36 spare2 // spare2 = trace in tapeVM, 37 spare3 // spare3 = bindDV in tapeVM, manualRootGrad in LispVM 38 watchAll 39 ) 40 41 // VMOpt is a VM creation option 42 type VMOpt func(m VM) 43 44 // WithLogger creates a VM with the supplied logger. If the logger is nil, a default logger, writing to os.stderr will be created. 45 func WithLogger(logger *log.Logger) VMOpt { 46 f := func(m VM) { 47 if logger == nil { 48 logger = log.New(os.Stderr, "", 0) 49 } 50 switch v := m.(type) { 51 case *lispMachine: 52 v.logger = logger 53 v.buf = new(bytes.Buffer) 54 case *tapeMachine: 55 v.logger = logger 56 v.buf = new(bytes.Buffer) 57 default: 58 panic(nyi("WithLogger", v)) 59 } 60 } 61 return f 62 } 63 64 // WithValueFmt defines how the logger will output the values. It defaults to "%3.3f" 65 func WithValueFmt(format string) VMOpt { 66 f := func(m VM) { 67 switch v := m.(type) { 68 case *lispMachine: 69 v.valueFmt = format 70 case *tapeMachine: 71 v.valueFmt = format 72 default: 73 panic(nyi("WithValueFmt", v)) 74 } 75 } 76 return f 77 } 78 79 // WithWatchlist creates a VM with a watchlist. When the execution touches the things in the watchlist, the VM's logger will the log it. 80 // This allows for watching and finetuning of the algorithm. When nothing is passed in, then the VM will default to watching and logging every single 81 // execution object. 82 // 83 // The watchlist allows for different things to be watched, depending on VM type: 84 // *lispMachine will ONLY take *Node 85 // *tapeMachine will take int (for register IDs) or *Node. 86 func WithWatchlist(list ...interface{}) VMOpt { 87 f := func(m VM) { 88 switch v := m.(type) { 89 case *lispMachine: 90 if len(list) == 0 { 91 v.doWatchAll() 92 return 93 } 94 95 for _, item := range list { 96 n := item.(*Node) // will panic if node is not passed in. This is expected behaviour. 97 v.watchlist = append(v.watchlist, n) 98 } 99 case *tapeMachine: 100 if len(list) == 0 { 101 v.doWatchAll() 102 return 103 } 104 105 for _, item := range list { 106 switch i := item.(type) { 107 case int: 108 v.watchRegs = append(v.watchRegs, register{id: i}) 109 case *Node: 110 v.watchNodes = append(v.watchNodes, i) 111 default: 112 panic("WithWatchlist only works with register ids or nodes") 113 } 114 } 115 default: 116 panic(nyi("WithWatchlist", v)) 117 } 118 } 119 return f 120 } 121 122 // WithNaNWatch creates a VM that will watch for NaNs when executing. This slows the execution down. 123 func WithNaNWatch() VMOpt { 124 f := func(m VM) { 125 switch v := m.(type) { 126 case *lispMachine: 127 v.doWatchNaN() 128 case *tapeMachine: 129 v.doWatchNaN() 130 default: 131 panic(nyi("withNaNWatch", v)) 132 } 133 } 134 return f 135 } 136 137 // WithInfWatch creates a VM that will watch for Infs when executing. It watches for +Inf, -Inf and Inf. No choice there. This slows the execution down. 138 func WithInfWatch() VMOpt { 139 f := func(m VM) { 140 switch v := m.(type) { 141 case *lispMachine: 142 v.doWatchInf() 143 case *tapeMachine: 144 v.doWatchInf() 145 default: 146 panic(nyi("withInfWatch", v)) 147 } 148 } 149 return f 150 } 151 152 // ExecuteFwdOnly creates a VM that will execute a graph forwards only - it will not do back propagation. 153 // This option is only for *lispMachine. Try it on any other VMs and it will panic. 154 func ExecuteFwdOnly() VMOpt { 155 f := func(m VM) { 156 switch v := m.(type) { 157 case *lispMachine: 158 v.doExecFwd() 159 v.dontExecBwd() 160 default: 161 panic(nyi("ExecuteFwdOnly", v)) 162 } 163 } 164 return f 165 } 166 167 // ExecuteBwdOnly creates a VM that will execute a graph by doing back propagation only. 168 // The assumption is of course, that the forward graph has already been executed, and there 169 // are already values associated with the nodes. 170 // This option is only for *lispMachine. Try it on any other VMs and it will panic. 171 func ExecuteBwdOnly() VMOpt { 172 f := func(m VM) { 173 switch v := m.(type) { 174 case *lispMachine: 175 v.doExecBwd() 176 v.dontExecFwd() 177 default: 178 panic(nyi("ExecuteBwdOnly", v)) 179 } 180 } 181 return f 182 } 183 184 // LogFwd logs the forward execution of a graph. 185 // This option is only for *lispMachine. Try it on any other VMs and it will panic. 186 func LogFwd() VMOpt { 187 f := func(m VM) { 188 switch v := m.(type) { 189 case *lispMachine: 190 v.doLogFwd() 191 default: 192 panic(nyi("LogFwdOnly", v)) 193 } 194 } 195 return f 196 } 197 198 // LogBwd logs the backwards execution of a graph. 199 // This option is only for *lispMachine. Try it on any other VMs and it will panic. 200 func LogBwd() VMOpt { 201 f := func(m VM) { 202 switch v := m.(type) { 203 case *lispMachine: 204 v.doLogBwd() 205 default: 206 panic(nyi("LogBwdOnly", v)) 207 } 208 } 209 return f 210 } 211 212 // LogBothDir logs both directions of the execution of the graph. 213 // This option is only available for *lispMachine. 214 func LogBothDir() VMOpt { 215 f := func(m VM) { 216 switch v := m.(type) { 217 case *lispMachine: 218 v.doLogFwd() 219 v.doLogBwd() 220 default: 221 panic(nyi("LogBothDir", v)) 222 } 223 } 224 return f 225 } 226 227 // TraceExec is an option for *tapeMachine only. 228 // It stores an immutable copy of the executed value into the node, instead of a mutable value, which may be clobbered 229 func TraceExec() VMOpt { 230 f := func(m VM) { 231 switch v := m.(type) { 232 case *tapeMachine: 233 v.doTrace() 234 default: 235 panic(nyi("TraceExec", v)) 236 } 237 } 238 return f 239 } 240 241 // BindDualValues is an option for *tapeMachine only. 242 // This is useful to set when using a Solver 243 func BindDualValues(nodes ...*Node) VMOpt { 244 f := func(m VM) { 245 switch v := m.(type) { 246 case *tapeMachine: 247 v.doBindDV() 248 v.bindNodesDV = append(v.bindNodesDV, nodes...) 249 v.bindNodesDV = v.bindNodesDV.Set() 250 default: 251 // on by default for LispMachine 252 } 253 } 254 return f 255 } 256 257 // WithPrecompiled is an option to pass in compiled programs. 258 // This is useful for users who use the CompileFunction function 259 func WithPrecompiled(prog *program, locMap map[*Node]register) VMOpt { 260 f := func(m VM) { 261 switch v := m.(type) { 262 case *tapeMachine: 263 v.p = prog 264 v.locMap = locMap 265 v.cpumem = make([]Value, prog.cpulocs) 266 v.gpumem = make([]Value, prog.gpulocs) 267 default: 268 // no op 269 } 270 } 271 return f 272 } 273 274 // WithManualGradient allows the user to set the gradient of the root, before backprop. The root gradients should be set using the SetDeriv method 275 func WithManualGradient() VMOpt { 276 f := func(m VM) { 277 switch v := m.(type) { 278 case *lispMachine: 279 v.allowSetRootGrad() 280 default: 281 // noop 282 } 283 } 284 return f 285 } 286 287 // WithEngine sets the tensor engine for computation inside the VM. 288 func WithEngine(e tensor.Engine) VMOpt { 289 f := func(m VM) { 290 switch v := m.(type) { 291 case *lispMachine: 292 v.setEngine(e) 293 case *tapeMachine: 294 v.setEngine(e) 295 } 296 } 297 return f 298 }