gorgonia.org/gorgonia@v0.9.17/operatorPointwise_unary_const.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"fmt"
     5  	"math"
     6  
     7  	"github.com/chewxy/math32"
     8  )
     9  
    10  var (
    11  	/* float64 */
    12  
    13  	// non differentiable
    14  	absf64   = sf64UnaryOperator(math.Abs)
    15  	signf64  = sf64UnaryOperator(_signf64)
    16  	ceilf64  = sf64UnaryOperator(math.Ceil)
    17  	floorf64 = sf64UnaryOperator(math.Floor)
    18  
    19  	// differentiable
    20  	sinf64         = sf64UnaryOperator(math.Sin)
    21  	cosf64         = sf64UnaryOperator(math.Cos)
    22  	expf64         = sf64UnaryOperator(math.Exp)
    23  	lnf64          = sf64UnaryOperator(math.Log)
    24  	log2f64        = sf64UnaryOperator(math.Log2)
    25  	negf64         = sf64UnaryOperator(_negf64)
    26  	squaref64      = sf64UnaryOperator(_squaref64)
    27  	sqrtf64        = sf64UnaryOperator(math.Sqrt)
    28  	inversef64     = sf64UnaryOperator(_inversef64)
    29  	inverseSqrtf64 = sf64UnaryOperator(_inverseSqrtf64)
    30  
    31  	// activation functions
    32  	cubef64    = sf64UnaryOperator(_cubef64)
    33  	tanhf64    = sf64UnaryOperator(_tanhf64)
    34  	sigmoidf64 = sf64UnaryOperator(_sigmoidf64)
    35  
    36  	// numerical stabilization optimization
    37  	log1pf64    = sf64UnaryOperator(math.Log1p)
    38  	expm1f64    = sf64UnaryOperator(math.Expm1)
    39  	softplusf64 = sf64UnaryOperator(_softplusf64)
    40  	// softplus isn't necessarily only a numerical stabilization op
    41  	// (you can use it elsewhere), but I included it under numerical optimization
    42  
    43  	/* Float32 */
    44  
    45  	// non differentiable
    46  	absf32   = sf32UnaryOperator(math32.Abs)
    47  	signf32  = sf32UnaryOperator(_signf32)
    48  	ceilf32  = sf32UnaryOperator(math32.Ceil)
    49  	floorf32 = sf32UnaryOperator(math32.Floor)
    50  
    51  	// start differentiable
    52  	sinf32         = sf32UnaryOperator(math32.Sin)
    53  	cosf32         = sf32UnaryOperator(math32.Cos)
    54  	expf32         = sf32UnaryOperator(math32.Exp)
    55  	lnf32          = sf32UnaryOperator(math32.Log)
    56  	log2f32        = sf32UnaryOperator(math32.Log2)
    57  	negf32         = sf32UnaryOperator(_negf32)
    58  	squaref32      = sf32UnaryOperator(_squaref32)
    59  	sqrtf32        = sf32UnaryOperator(math32.Sqrt)
    60  	inversef32     = sf32UnaryOperator(_inversef32)
    61  	inverseSqrtf32 = sf32UnaryOperator(_inverseSqrtf32)
    62  
    63  	// typically used in activation functions
    64  	cubef32    = sf32UnaryOperator(_cubef32)
    65  	tanhf32    = sf32UnaryOperator(_tanhf32)
    66  	sigmoidf32 = sf32UnaryOperator(_sigmoidf32)
    67  
    68  	// numerical stabilization optimization
    69  	log1pf32    = sf32UnaryOperator(math32.Log1p)
    70  	expm1f32    = sf32UnaryOperator(math32.Expm1)
    71  	softplusf32 = sf32UnaryOperator(_softplusf32)
    72  )
    73  
    74  type ʘUnaryOperatorType byte
    75  
    76  const (
    77  	absOpType ʘUnaryOperatorType = iota
    78  	signOpType
    79  	ceilOpType
    80  	floorOpType
    81  
    82  	// start differentiable
    83  	sinOpType
    84  	cosOpType
    85  	expOpType
    86  	lnOpType
    87  	log2OpType
    88  	negOpType
    89  	squareOpType
    90  	sqrtOpType
    91  	inverseOpType     // multiplicative inverse
    92  	inverseSqrtOpType // 1/sqrt(x)
    93  
    94  	// typically used in activation functions
    95  	cubeOpType
    96  	tanhOpType
    97  	sigmoidOpType
    98  
    99  	// optimization related
   100  	log1pOpType
   101  	expm1OpType
   102  	softplusOpType
   103  
   104  	maxʘUnaryOperator // delimits end of all possible unary ops
   105  )
   106  
   107  func (u ʘUnaryOperatorType) String() string {
   108  	if u >= maxʘUnaryOperator {
   109  		return fmt.Sprintf("UNSUPPORTED UNARY OPERATOR (%d); max: %d", u, maxʘUnaryOperator)
   110  	}
   111  
   112  	return ʘUnaryOpStrs[u]
   113  }
   114  
   115  // ʘUnaryOpStrs is the string representation for a unaryOpType
   116  // It should be held constant.
   117  var ʘUnaryOpStrs = [maxʘUnaryOperator]string{
   118  	"abs", "sign", "ceil", "floor",
   119  	"sin", "cos", "exp",
   120  	"ln", "log2", "neg", "square", "sqrt",
   121  	"inv", "invSqrt",
   122  	"cube", "tanh", "sigmoid",
   123  
   124  	"log1p", "expm1", "softplus",
   125  }
   126  
   127  // ʘUnaryOpDifferentiable is the array of whether a unary operator is differentiable
   128  // It should be held constant
   129  var ʘUnaryOpDifferentiable = [maxʘUnaryOperator]bool{
   130  	true, false, false, false,
   131  	true, true, true,
   132  	true, true, true, true, true,
   133  	true, true,
   134  	true, true, true,
   135  
   136  	true, true, true,
   137  }
   138  
   139  var ʘUnaryOpDiffExprs = [maxʘUnaryOperator]func(x, y, gradY *Node) (*Node, error){
   140  	absDiffExpr, nondiffUnaryOpExpr, nondiffUnaryOpExpr, nondiffUnaryOpExpr,
   141  	sinDiffExpr, cosDiffExpr, expDiffExpr,
   142  	lnDiffExpr, log2DiffExpr, negDiffExpr, squareDiffExpr, sqrtDiffExpr,
   143  	inverseDiffExpr, inverseSqrtDiffExpr, cubeDiffExpr, tanhDiffExpr, sigmoidDiffExpr,
   144  
   145  	log1pDiffExpr, expm1DiffExpr, softplusDiffExpr,
   146  }
   147  
   148  var ʘUnaryOpDiffFns = [maxʘUnaryOperator]func(x, y *Node) error{
   149  	absDiff, nondiffUnaryOp, nondiffUnaryOp, nondiffUnaryOp,
   150  	sinDiff, cosDiff, expDiff,
   151  	lnDiff, log2Diff, negDiff, squareDiff, sqrtDiff,
   152  	inverseDiff, inverseSqrtDiff, cubeDiff, tanhDiff, sigmoidDiff,
   153  
   154  	log1pDiff, expm1Diff, softplusDiff,
   155  }
   156  
   157  var sf64UnaryOperators = [maxʘUnaryOperator]*sf64UnaryOperator{
   158  	&absf64,
   159  	&signf64,
   160  	&ceilf64,
   161  	&floorf64,
   162  	&sinf64,
   163  	&cosf64,
   164  	&expf64,
   165  	&lnf64,
   166  	&log2f64,
   167  	&negf64,
   168  	&squaref64,
   169  	&sqrtf64,
   170  	&inversef64,
   171  	&inverseSqrtf64,
   172  	&cubef64,
   173  	&tanhf64,
   174  	&sigmoidf64,
   175  
   176  	&log1pf64,
   177  	&expm1f64,
   178  	&softplusf64,
   179  }
   180  
   181  var sf32UnaryOperators = [maxʘUnaryOperator]*sf32UnaryOperator{
   182  	&absf32,
   183  	&signf32,
   184  	&ceilf32,
   185  	&floorf32,
   186  	&sinf32,
   187  	&cosf32,
   188  	&expf32,
   189  	&lnf32,
   190  	&log2f32,
   191  	&negf32,
   192  	&squaref32,
   193  	&sqrtf32,
   194  	&inversef32,
   195  	&inverseSqrtf32,
   196  	&cubef32,
   197  	&tanhf32,
   198  	&sigmoidf32,
   199  
   200  	&log1pf32,
   201  	&expm1f32,
   202  	&softplusf32,
   203  }