gorgonia.org/gorgonia@v0.9.17/errors.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"github.com/pkg/errors"
     7  )
     8  
     9  // NoOpError is an error returned when an operation does nothing.
    10  type NoOpError interface {
    11  	NoOp() bool
    12  }
    13  
    14  type noopError struct{}
    15  
    16  func (e noopError) NoOp() bool    { return true }
    17  func (e noopError) Error() string { return "NoOp" }
    18  
    19  // errNoStabilization is an error used internally for when there is no stabilization mechanism is found.
    20  type errNoStabilization interface {
    21  	error
    22  	noStabilization() bool
    23  }
    24  
    25  // nostabilizationErr is used internally to communicate that there isn't any stabilization possible
    26  type noStabilizationErr struct{}
    27  
    28  func (noStabilizationErr) Error() string         { return "No stabilization mechanism found" }
    29  func (noStabilizationErr) noStabilization() bool { return true }
    30  
    31  // noIncrErr is an error used internally when a Value cannot be incremented
    32  type noIncrErr struct {
    33  	v Value
    34  }
    35  
    36  func (noIncrErr) Error() string  { return incrErr }
    37  func (e noIncrErr) Value() Value { return e.v }
    38  
    39  // oomError represents an Out of tensor.Memory error. It is typically used for CUDA related machine work
    40  type oomError struct {
    41  	res       int64
    42  	allocated int64
    43  }
    44  
    45  func (e oomError) Reserved() int64  { return e.res }
    46  func (e oomError) Allocated() int64 { return e.allocated }
    47  func (e oomError) Error() string    { return fmt.Sprintf("allocated/reserved: %v/%v", e.allocated, e.res) }
    48  
    49  // AutoDiffError is an error which should be passed if the function is not differentiable. This is useful for Op implementations
    50  type AutoDiffError struct{}
    51  
    52  func (err AutoDiffError) Error() string { return "AutoDiffError" }
    53  
    54  // vmContextualError is an error that is used to wrap errors that arise from the VM
    55  type vmContextualError struct {
    56  	error
    57  	node  *Node // which node was it processing
    58  	instr int   // what instruction ID it was
    59  }
    60  
    61  func (err vmContextualError) Node() *Node        { return err.node }
    62  func (err vmContextualError) Value() Value       { return err.node.Value() }
    63  func (err vmContextualError) InstructionID() int { return err.instr }
    64  func (err vmContextualError) Err() error         { return err.error }
    65  
    66  func nyi(what string, implFor interface{}) error {
    67  	return errors.Errorf(nyiFail, what, implFor)
    68  }
    69  
    70  func nondiffErr(op Op) error {
    71  	return errors.Errorf("%s is a non-differentiable function", op)
    72  }
    73  
    74  // checkErrSetDeriv sets the deriv if the error is a Valuer. Helper function for linalg operations
    75  func checkErrSetDeriv(err error, dv *dualValue) error {
    76  	if ver, ok := err.(Valuer); ok {
    77  		return dv.SetDeriv(ver.Value())
    78  	}
    79  	return err
    80  }
    81  
    82  // SymDiffError provides the context at which an error occurred
    83  type SymDiffError struct {
    84  	nodes   Nodes
    85  	single  *Node
    86  	grad    *Node
    87  	gradMap map[*Node]Nodes
    88  	err     error
    89  }
    90  
    91  func (err SymDiffError) Error() string { return err.err.Error() }
    92  
    93  // Nodes returns the nodes involved in the error
    94  func (err SymDiffError) Nodes() Nodes { return err.nodes }
    95  
    96  // Node returns a specific node involved in the error
    97  func (err SymDiffError) Node() *Node { return err.single }
    98  
    99  // Grads returns the grads involved in the error
   100  func (err SymDiffError) Grads() map[*Node]Nodes { return err.gradMap }
   101  
   102  // Grad returns a specific grad involved in the error
   103  func (err SymDiffError) Grad() *Node { return err.grad }