gorgonia.org/gorgonia@v0.9.17/vm.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"bytes"
     5  	"log"
     6  	"os"
     7  
     8  	"gorgonia.org/tensor"
     9  )
    10  
// VM represents a structure that can execute a graph or program. There are two VMs (both unexported):
//		- *tapeMachine
//		- *lispMachine
//
// The *tapeMachine pre-compiles a graph into a list of instructions, then executes the instructions linearly and sequentially.
// The main tradeoff is dynamism. Graphs cannot be dynamically created on the fly as a re-compilation process is required
// (and compilation is relatively expensive). However, graphs executed with the *tapeMachine run much faster as plenty of optimizations
// have been done in the code generation stage.
//
// The *lispMachine allows for graphs to be dynamically built and executed upon. The tradeoff is that executing a graph on *lispMachine
// is generally slower than on *tapeMachine, given the same static "image" of a graph.
type VM interface {
	// RunAll presumably executes the whole graph or program held by the VM and
	// reports any failure as an error - confirm against the concrete machines.
	RunAll() error
	// Reset presumably returns the VM to a state from which it can be run again -
	// confirm against the concrete machines.
	Reset()

	// Close closes all the machine resources (CUDA, if any, loggers if any).
	Close() error
}
    29  
// VM flag identifiers, numbered 0 (fwdOnly) through 7 (watchAll) by iota and typed as byte.
// The order is significant: moving or inserting an entry renumbers everything after it.
// NOTE(review): presumably each value selects one flag in a machine's flag set - confirm in the machine implementations.
const (
	fwdOnly byte = iota
	bwdOnly
	watchNaN
	watchInf
	allocVals
	spare2 // spare2 = trace in tapeVM
	spare3 // spare3 = bindDV in tapeVM, manualRootGrad in LispVM
	watchAll
)
    40  
// VMOpt is a VM creation option: a function that configures the VM it is handed.
// The option constructors in this file (WithLogger, WithWatchlist, ExecuteFwdOnly, ...) each return one.
type VMOpt func(m VM)
    43  
    44  // WithLogger creates a VM with the supplied logger. If the logger is nil, a default logger, writing to os.stderr will be created.
    45  func WithLogger(logger *log.Logger) VMOpt {
    46  	f := func(m VM) {
    47  		if logger == nil {
    48  			logger = log.New(os.Stderr, "", 0)
    49  		}
    50  		switch v := m.(type) {
    51  		case *lispMachine:
    52  			v.logger = logger
    53  			v.buf = new(bytes.Buffer)
    54  		case *tapeMachine:
    55  			v.logger = logger
    56  			v.buf = new(bytes.Buffer)
    57  		default:
    58  			panic(nyi("WithLogger", v))
    59  		}
    60  	}
    61  	return f
    62  }
    63  
    64  // WithValueFmt defines how the logger will output the values. It defaults to "%3.3f"
    65  func WithValueFmt(format string) VMOpt {
    66  	f := func(m VM) {
    67  		switch v := m.(type) {
    68  		case *lispMachine:
    69  			v.valueFmt = format
    70  		case *tapeMachine:
    71  			v.valueFmt = format
    72  		default:
    73  			panic(nyi("WithValueFmt", v))
    74  		}
    75  	}
    76  	return f
    77  }
    78  
    79  // WithWatchlist creates a VM with a watchlist. When the execution touches the things in the watchlist, the VM's logger will the log it.
    80  // This allows for watching and finetuning of the algorithm. When nothing is passed in, then the VM will default to watching and logging every single
    81  // execution object.
    82  //
    83  // The watchlist allows for different things to be watched, depending on VM type:
    84  //		*lispMachine will ONLY take *Node
    85  //		*tapeMachine will take int (for register IDs) or *Node.
    86  func WithWatchlist(list ...interface{}) VMOpt {
    87  	f := func(m VM) {
    88  		switch v := m.(type) {
    89  		case *lispMachine:
    90  			if len(list) == 0 {
    91  				v.doWatchAll()
    92  				return
    93  			}
    94  
    95  			for _, item := range list {
    96  				n := item.(*Node) // will panic if node is not passed in. This is expected behaviour.
    97  				v.watchlist = append(v.watchlist, n)
    98  			}
    99  		case *tapeMachine:
   100  			if len(list) == 0 {
   101  				v.doWatchAll()
   102  				return
   103  			}
   104  
   105  			for _, item := range list {
   106  				switch i := item.(type) {
   107  				case int:
   108  					v.watchRegs = append(v.watchRegs, register{id: i})
   109  				case *Node:
   110  					v.watchNodes = append(v.watchNodes, i)
   111  				default:
   112  					panic("WithWatchlist only works with register ids or nodes")
   113  				}
   114  			}
   115  		default:
   116  			panic(nyi("WithWatchlist", v))
   117  		}
   118  	}
   119  	return f
   120  }
   121  
   122  // WithNaNWatch creates a VM that will watch for NaNs when executing. This slows the execution down.
   123  func WithNaNWatch() VMOpt {
   124  	f := func(m VM) {
   125  		switch v := m.(type) {
   126  		case *lispMachine:
   127  			v.doWatchNaN()
   128  		case *tapeMachine:
   129  			v.doWatchNaN()
   130  		default:
   131  			panic(nyi("withNaNWatch", v))
   132  		}
   133  	}
   134  	return f
   135  }
   136  
   137  // WithInfWatch creates a VM that will watch for Infs when executing. It watches for +Inf, -Inf and Inf. No choice there. This slows the execution down.
   138  func WithInfWatch() VMOpt {
   139  	f := func(m VM) {
   140  		switch v := m.(type) {
   141  		case *lispMachine:
   142  			v.doWatchInf()
   143  		case *tapeMachine:
   144  			v.doWatchInf()
   145  		default:
   146  			panic(nyi("withInfWatch", v))
   147  		}
   148  	}
   149  	return f
   150  }
   151  
   152  // ExecuteFwdOnly creates a VM that will execute a graph forwards only - it will not do back propagation.
   153  // This option is only for *lispMachine. Try it on any other VMs and it will panic.
   154  func ExecuteFwdOnly() VMOpt {
   155  	f := func(m VM) {
   156  		switch v := m.(type) {
   157  		case *lispMachine:
   158  			v.doExecFwd()
   159  			v.dontExecBwd()
   160  		default:
   161  			panic(nyi("ExecuteFwdOnly", v))
   162  		}
   163  	}
   164  	return f
   165  }
   166  
   167  // ExecuteBwdOnly creates a VM that will execute a graph by doing back propagation only.
   168  // The assumption is of course, that the forward graph has already been executed, and there
   169  // are already values associated with the nodes.
   170  // This option is only for *lispMachine. Try it on any other VMs and it will panic.
   171  func ExecuteBwdOnly() VMOpt {
   172  	f := func(m VM) {
   173  		switch v := m.(type) {
   174  		case *lispMachine:
   175  			v.doExecBwd()
   176  			v.dontExecFwd()
   177  		default:
   178  			panic(nyi("ExecuteBwdOnly", v))
   179  		}
   180  	}
   181  	return f
   182  }
   183  
   184  // LogFwd logs the forward execution of a graph.
   185  // This option is only for *lispMachine. Try it on any other VMs and it will panic.
   186  func LogFwd() VMOpt {
   187  	f := func(m VM) {
   188  		switch v := m.(type) {
   189  		case *lispMachine:
   190  			v.doLogFwd()
   191  		default:
   192  			panic(nyi("LogFwdOnly", v))
   193  		}
   194  	}
   195  	return f
   196  }
   197  
   198  // LogBwd logs the backwards execution of a graph.
   199  // This option is only for *lispMachine. Try it on any other VMs and it will panic.
   200  func LogBwd() VMOpt {
   201  	f := func(m VM) {
   202  		switch v := m.(type) {
   203  		case *lispMachine:
   204  			v.doLogBwd()
   205  		default:
   206  			panic(nyi("LogBwdOnly", v))
   207  		}
   208  	}
   209  	return f
   210  }
   211  
   212  // LogBothDir logs both directions of the execution of the graph.
   213  // This option is only available for *lispMachine.
   214  func LogBothDir() VMOpt {
   215  	f := func(m VM) {
   216  		switch v := m.(type) {
   217  		case *lispMachine:
   218  			v.doLogFwd()
   219  			v.doLogBwd()
   220  		default:
   221  			panic(nyi("LogBothDir", v))
   222  		}
   223  	}
   224  	return f
   225  }
   226  
   227  // TraceExec is an option for *tapeMachine only.
   228  // It stores an immutable copy of the executed value into the node, instead of a mutable value, which may be clobbered
   229  func TraceExec() VMOpt {
   230  	f := func(m VM) {
   231  		switch v := m.(type) {
   232  		case *tapeMachine:
   233  			v.doTrace()
   234  		default:
   235  			panic(nyi("TraceExec", v))
   236  		}
   237  	}
   238  	return f
   239  }
   240  
   241  // BindDualValues is an option for *tapeMachine only.
   242  // This is useful to set when using a Solver
   243  func BindDualValues(nodes ...*Node) VMOpt {
   244  	f := func(m VM) {
   245  		switch v := m.(type) {
   246  		case *tapeMachine:
   247  			v.doBindDV()
   248  			v.bindNodesDV = append(v.bindNodesDV, nodes...)
   249  			v.bindNodesDV = v.bindNodesDV.Set()
   250  		default:
   251  			// on by default for LispMachine
   252  		}
   253  	}
   254  	return f
   255  }
   256  
   257  // WithPrecompiled is an option to pass in compiled programs.
   258  // This is useful for users who use the CompileFunction function
   259  func WithPrecompiled(prog *program, locMap map[*Node]register) VMOpt {
   260  	f := func(m VM) {
   261  		switch v := m.(type) {
   262  		case *tapeMachine:
   263  			v.p = prog
   264  			v.locMap = locMap
   265  			v.cpumem = make([]Value, prog.cpulocs)
   266  			v.gpumem = make([]Value, prog.gpulocs)
   267  		default:
   268  			// no op
   269  		}
   270  	}
   271  	return f
   272  }
   273  
   274  // WithManualGradient allows the user to set the gradient of the root, before backprop. The root gradients should be set using the SetDeriv method
   275  func WithManualGradient() VMOpt {
   276  	f := func(m VM) {
   277  		switch v := m.(type) {
   278  		case *lispMachine:
   279  			v.allowSetRootGrad()
   280  		default:
   281  			// noop
   282  		}
   283  	}
   284  	return f
   285  }
   286  
   287  // WithEngine sets the tensor engine for computation inside the VM.
   288  func WithEngine(e tensor.Engine) VMOpt {
   289  	f := func(m VM) {
   290  		switch v := m.(type) {
   291  		case *lispMachine:
   292  			v.setEngine(e)
   293  		case *tapeMachine:
   294  			v.setEngine(e)
   295  		}
   296  	}
   297  	return f
   298  }