gorgonia.org/gorgonia@v0.9.17/utils.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"fmt"
     5  	"hash/fnv"
     6  	"math"
     7  
     8  	"github.com/chewxy/math32"
     9  	"github.com/pkg/errors"
    10  	"gonum.org/v1/gonum/graph"
    11  	"gonum.org/v1/gonum/graph/iterator"
    12  	"gorgonia.org/tensor"
    13  )
    14  
    15  const (
    16  	maxFloat32 = math32.MaxFloat32
    17  	maxFloat64 = math.MaxFloat64
    18  )
    19  
    20  // NodesToValueGrads is a utility function that converts a Nodes to a slice of ValueGrad for the solvers
    21  func NodesToValueGrads(in Nodes) (out []ValueGrad) {
    22  	out = make([]ValueGrad, len(in))
    23  	for i := range in {
    24  		out[i] = in[i]
    25  	}
    26  	return out
    27  }
    28  
    29  func graphNodeToNode(in graph.Nodes) (out Nodes) {
    30  	out = make(Nodes, in.Len())
    31  	for i := 0; in.Next(); i++ {
    32  		out[i] = in.Node().(*Node)
    33  	}
    34  
    35  	return
    36  }
    37  
    38  func sliceNodesToNodes(in []graph.Node) (out Nodes) {
    39  	out = make(Nodes, len(in))
    40  	for i := range in {
    41  		out[i] = in[i].(*Node)
    42  	}
    43  	return
    44  }
    45  
    46  func nodeToGraphNode(in []*Node) graph.Nodes {
    47  	nodes := make([]graph.Node, len(in))
    48  	for i, n := range in {
    49  		nodes[i] = n
    50  	}
    51  	return iterator.NewOrderedNodes(nodes)
    52  }
    53  
    54  func tensorInfo(t tensor.Tensor) (dt tensor.Dtype, dim int) {
    55  	dt = t.Dtype()
    56  	dim = t.Dims()
    57  	return
    58  }
    59  
    60  func valueToInt(v Value) (int, error) {
    61  	var intV int
    62  	switch sv := v.(type) {
    63  	case *F64:
    64  		intV = int(float64(*sv))
    65  	case *F32:
    66  		intV = int(float32(*sv))
    67  	case *I:
    68  		intV = int(*sv)
    69  	case *I32:
    70  		intV = int(int32(*sv))
    71  	case *I64:
    72  		intV = int(int64(*sv))
    73  	case *U8:
    74  		intV = int(byte(*sv))
    75  	default:
    76  		return -1, errors.Errorf("Expected values to be all Scalar Value. Got %v of %T instead", v, v)
    77  	}
    78  	return intV, nil
    79  }
    80  
    81  // valuesToInts will FORCIBLY cast floats to ints.
    82  func valuesToInts(values []Value) (retVal []int, err error) {
    83  	retVal = tensor.BorrowInts(len(values))
    84  	for i, v := range values {
    85  		var intV int
    86  		switch sv := v.(type) {
    87  		case *F64:
    88  			intV = int(float64(*sv))
    89  		case *F32:
    90  			intV = int(float32(*sv))
    91  		case *I:
    92  			intV = int(*sv)
    93  		case *I32:
    94  			intV = int(int32(*sv))
    95  		case *I64:
    96  			intV = int(int64(*sv))
    97  		case *U8:
    98  			intV = int(byte(*sv))
    99  		case Scalar:
   100  			return nil, errors.Errorf(nyiTypeFail, "valueToInts", v)
   101  		default:
   102  			return nil, errors.Errorf("Expected values to be all Scalar Value. Got %v of %T instead", v, v)
   103  
   104  		}
   105  		retVal[i] = intV
   106  	}
   107  	return
   108  }
   109  
   110  func valuesToTensors(values []Value) (retVal []tensor.Tensor, err error) {
   111  	retVal = make([]tensor.Tensor, len(values))
   112  	for i, v := range values {
   113  		if vt, ok := v.(tensor.Tensor); ok {
   114  			retVal[i] = vt
   115  			continue
   116  		}
   117  		return nil, errors.Errorf("Expected values to all be tensor.Tensor. Got %v of %T in %dth index of the slice", v, v, i)
   118  	}
   119  	return
   120  }
   121  
   122  func intRange(start, end int) []int {
   123  	size := end - start
   124  	incr := true
   125  	if start > end {
   126  		incr = false
   127  		size = start - end
   128  	}
   129  
   130  	if size < 0 {
   131  		panic("Cannot create an int range that is somehow negative in size")
   132  	}
   133  
   134  	retVal := make([]int, size)
   135  
   136  	for i, v := 0, start; i < size; i++ {
   137  		retVal[i] = v
   138  		if incr {
   139  			v++
   140  		} else {
   141  			v--
   142  		}
   143  	}
   144  	return retVal
   145  }
   146  
   147  func ones(dt tensor.Dtype, sizes ...int) (retVal Value) {
   148  	if len(sizes) == 0 {
   149  		return one(dt)
   150  	}
   151  	return tensor.Ones(dt, sizes...)
   152  }
   153  
   154  func hasInf(v Value, dev Device) bool {
   155  	switch vt := v.(type) {
   156  	case *F64:
   157  		return math.IsInf(float64(*vt), 0)
   158  	case *F32:
   159  		return math32.IsInf(float32(*vt), 0)
   160  	case tensor.Tensor:
   161  		if e, ok := vt.Engine().(tensor.InfChecker); ok {
   162  			ok, _ := e.HasInf(vt) // BUG: errors not checked
   163  			return ok
   164  		}
   165  
   166  		dt := vt.Dtype()
   167  		if dt != tensor.Float64 && dt != tensor.Float32 {
   168  			return false
   169  		}
   170  		switch dt {
   171  		case tensor.Float32:
   172  			data := vt.Data().([]float32)
   173  			for _, datum := range data {
   174  				if math32.IsInf(datum, 0) {
   175  					return true
   176  				}
   177  			}
   178  		case tensor.Float64:
   179  			data := vt.Data().([]float64)
   180  			for _, datum := range data {
   181  				if math.IsInf(datum, 0) {
   182  					return true
   183  				}
   184  			}
   185  		}
   186  		return false
   187  	case *dualValue:
   188  		return hasInf(vt.Value, dev) || hasInf(vt.d, dev)
   189  	default:
   190  		err := nyi("hasInf", v)
   191  		panic(err)
   192  	}
   193  }
   194  
   195  func hasNaN(v Value, dev Device) bool {
   196  	switch vt := v.(type) {
   197  	case *F64:
   198  		return math.IsNaN(float64(*vt))
   199  	case *F32:
   200  		return math32.IsNaN(float32(*vt))
   201  	case tensor.Tensor:
   202  		if e, ok := vt.Engine().(tensor.NaNChecker); ok {
   203  			ok, _ := e.HasNaN(vt) // BUG: errors not checked
   204  			return ok
   205  		}
   206  
   207  		dt := vt.Dtype()
   208  		if dt != tensor.Float64 && dt != tensor.Float32 {
   209  			return false
   210  		}
   211  
   212  		switch dt {
   213  		case tensor.Float32:
   214  			data := vt.Data().([]float32)
   215  			for _, datum := range data {
   216  				if math32.IsNaN(datum) {
   217  					return true
   218  				}
   219  			}
   220  		case tensor.Float64:
   221  			data := vt.Data().([]float64)
   222  			for _, datum := range data {
   223  				if math.IsNaN(datum) {
   224  					return true
   225  				}
   226  			}
   227  		}
   228  		return false
   229  	case *dualValue:
   230  		return hasNaN(vt.Value, dev) || hasNaN(vt.d, dev)
   231  	default:
   232  		err := nyi("hasNaN", vt)
   233  		panic(err)
   234  	}
   235  }
   236  
   237  func setZero(val Value) (retVal Value) {
   238  	switch v := val.(type) {
   239  	case Zeroer:
   240  		v.Zero()
   241  		return v
   242  	case Scalar:
   243  		return zero(v.Dtype())
   244  	default:
   245  		panic(fmt.Sprintf("setZero not implemented yet for %T", v))
   246  	}
   247  }
   248  
   249  func checkArity(op arityer, inputs int) error {
   250  	if inputs != op.Arity() && op.Arity() >= 0 {
   251  		return errors.Errorf("%v has an arity of %d. Got %d instead", op, op.Arity(), inputs)
   252  	}
   253  	return nil
   254  }
   255  
   256  func maxInt(a, b int) int {
   257  	if a > b {
   258  		return a
   259  	}
   260  	return b
   261  }
   262  
   263  func minInt(a, b int) int {
   264  	if a < b {
   265  		return a
   266  	}
   267  	return b
   268  }
   269  
   270  func ceilDivInt(a, b int) int {
   271  	return (a + b - 1) / b
   272  }
   273  
   274  func simpleHash(op hashWriter) uint32 {
   275  	h := fnv.New32a()
   276  	op.WriteHash(h)
   277  	return h.Sum32()
   278  }
   279  
   280  func getDV(x, y *Node) (xdv, ydv *dualValue) {
   281  	return x.boundTo.(*dualValue), y.boundTo.(*dualValue)
   282  }
   283  
   284  func getDV3(x, y, z *Node) (xdv, ydv, zdv *dualValue) {
   285  	return x.boundTo.(*dualValue), y.boundTo.(*dualValue), z.boundTo.(*dualValue)
   286  }
   287  
   288  func getConst(x *Node, constant string) (retVal *Node, err error) {
   289  	var dt tensor.Dtype
   290  	if dt, err = dtypeOf(x.t); err != nil {
   291  		return nil, errors.Wrap(err, dtypeOfFail)
   292  	}
   293  
   294  	if m, ok := constmap[constant]; ok {
   295  		if n, ok := m[dt]; ok {
   296  			return n, nil
   297  		}
   298  	}
   299  	return nil, errors.Errorf("constant %v not provided for %v", constant, dt)
   300  }
   301  
   302  func scalarEquiv(s tensor.Shape) bool {
   303  	if len(s) == 0 {
   304  		return true
   305  	}
   306  	prod := 1
   307  	for _, v := range s {
   308  		prod *= v
   309  	}
   310  
   311  	return prod == 1
   312  }