gorgonia.org/gorgonia@v0.9.17/const.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"math"
     5  
     6  	"gorgonia.org/tensor"
     7  )
     8  
     9  const (
    10  	// graphviz name for a full graph
    11  	fullGraphName = "fullGraph"
    12  
    13  	// group names
    14  	exprgraphClust = "expressionGraph"
    15  	constantsClust = "constants"
    16  	inputsClust    = "inputs"
    17  	gradClust      = "gradients"
    18  	strayClust     = "undifferentiated nodes"
    19  
    20  	// subgraphs to rank the same
    21  	outsideSubG = "outsides"
    22  	inputConsts = "inputConsts"
    23  
    24  	// special nodes for graphviz hacking
    25  	outsideRoot   = "outsideRoot"
    26  	outsideInputs = "outsideInputs"
    27  	insideInputs  = "insideInputs"
    28  	outsideConsts = "outsideConsts"
    29  	insideConsts  = "insideConsts"
    30  	outsideExprG  = "outsideExprG"
    31  	insideExprG   = "insideExprG"
    32  	outsideGrads  = "outsideGrads"
    33  	insideGrads   = "insideGrads"
    34  
    35  	// error messages
    36  	sortFail            = "Failed to sort"
    37  	cloneFail           = "Failed to carry clone(%v)"
    38  	clone0Fail          = "Failed to carry clone0()"
    39  	nyiTypeFail         = "%s not yet implemented for %T"
    40  	nyiFail             = "%s not yet implemented for %v"
    41  	dtypeOfFail         = "Failed to carry dtypeOf()"
    42  	mulFail             = "Failed to carry Mul()"
    43  	applyOpFail         = "Failed to carryApplyOp()"
    44  	opDoFail            = "Failed to carry op.Do()"
    45  	binOpDoFail         = "Failed to carry binOp.Do()"
    46  	binOpNodeFail       = "Failed to carry binary operation %T"
    47  	applyFail           = "Failed to carry Apply()"
    48  	binOpFail           = "Binary operator received %d arguments"
    49  	hadamardProdFail    = "Failed to carry hadamardProd()"
    50  	hadamardDivFail     = "Failed to carry hadamardDiv()"
    51  	cubeFail            = "Failed to carry cube()"
    52  	negFail             = "Failed to carry Neg()"
    53  	invFail             = "Failed to carry Inv()"
    54  	pointWiseMulFail    = "Failed to carry PointWiseMul()"
    55  	pointWiseSquareFail = "Failed to carry PointWiseSquare()"
    56  	clampFail           = "Failed to carry Clamp()"
    57  	invSqrtFail         = "Failed to carry InvSqrt()"
    58  	subFail             = "Failed to carry Sub()"
    59  	addFail             = "Failed to carry Add()"
    60  	signFail            = "Failed to carry Sign()"
    61  	softplusFail        = "Failed to carry Softplus()"
    62  	incrErr             = "increment couldn't be done. Safe op was performed instead"
    63  	bindFail            = "Failed to bind"
    64  	anyToValueFail      = "Failed to convert %v(%T) into a Value"
    65  	dtypeExtractionFail = "Failed to extract dtype from %v"
    66  	operationError      = "Operation failed"
    67  	doFail              = "Doing %v failed"
    68  	unsafeDoFail        = "UnsafeDoing %v failed."
    69  	tFail               = "Failed to transpose Tensor"
    70  	repFail             = "Failed to repeat Tensor along %d %d times"
    71  	reshapeFail         = "Failed to reshape Tensor into %v. DataSize was: %d"
    72  	sliceFail           = "Failed to slice Tensor with %v"
    73  	execFail            = "Failed to execute %v in node %v"
    74  	autodiffFail        = "Failed to differentiate %v"
    75  	undefinedOnShape    = "%v undefined on shape %v"
    76  	unsupportedDtype    = "dtype %v is not yet supported"
    77  	gradOnDeviceFail    = "Cannot get gradient of %v on %v"
    78  	makeValueFail       = "Unable to make value of %v with shape %v"
    79  	allocFail           = "Unable to allocate %v bytes on %v"
    80  
    81  	shapeMismatchErr = "Shape Mismatch. Expected %v. Got %v instead."
    82  )
    83  
    84  var empty struct{}
    85  
    86  var (
    87  	onef32   = NewConstant(float32(1.0))
    88  	onef64   = NewConstant(float64(1.0))
    89  	zerof32  = NewConstant(float32(0.0))
    90  	zerof64  = NewConstant(float64(0.0))
    91  	twof64   = NewConstant(float64(2.0))
    92  	twof32   = NewConstant(float32(2.0))
    93  	threef64 = NewConstant(float64(3.0))
    94  	threef32 = NewConstant(float32(3.0))
    95  	ln2f64   = NewConstant(math.Ln2)
    96  	ln2f32   = NewConstant(float32(math.Ln2))
    97  
    98  	onef32ConstOp  = onef32.op.(constant)
    99  	onef64ConstOp  = onef64.op.(constant)
   100  	zerof32ConstOp = zerof32.op.(constant)
   101  	zerof64ConstOp = zerof64.op.(constant)
   102  
   103  	constmap map[string]map[tensor.Dtype]*Node
   104  )
   105  
   106  var oneone = tensor.Shape{1, 1}
   107  
   108  func init() {
   109  	constmap = map[string]map[tensor.Dtype]*Node{
   110  		"zero": {
   111  			Float32: zerof32,
   112  			Float64: zerof64,
   113  		},
   114  		"one": {
   115  			Float32: onef32,
   116  			Float64: onef64,
   117  		},
   118  		"two": {
   119  			Float32: twof32,
   120  			Float64: twof64,
   121  		},
   122  		"three": {
   123  			Float32: threef32,
   124  			Float64: threef64,
   125  		},
   126  		"log2": {
   127  			Float32: ln2f32,
   128  			Float64: ln2f64,
   129  		},
   130  	}
   131  
   132  }