gorgonia.org/gorgonia@v0.9.17/utils.go (about) 1 package gorgonia 2 3 import ( 4 "fmt" 5 "hash/fnv" 6 "math" 7 8 "github.com/chewxy/math32" 9 "github.com/pkg/errors" 10 "gonum.org/v1/gonum/graph" 11 "gonum.org/v1/gonum/graph/iterator" 12 "gorgonia.org/tensor" 13 ) 14 15 const ( 16 maxFloat32 = math32.MaxFloat32 17 maxFloat64 = math.MaxFloat64 18 ) 19 20 // NodesToValueGrads is a utility function that converts a Nodes to a slice of ValueGrad for the solvers 21 func NodesToValueGrads(in Nodes) (out []ValueGrad) { 22 out = make([]ValueGrad, len(in)) 23 for i := range in { 24 out[i] = in[i] 25 } 26 return out 27 } 28 29 func graphNodeToNode(in graph.Nodes) (out Nodes) { 30 out = make(Nodes, in.Len()) 31 for i := 0; in.Next(); i++ { 32 out[i] = in.Node().(*Node) 33 } 34 35 return 36 } 37 38 func sliceNodesToNodes(in []graph.Node) (out Nodes) { 39 out = make(Nodes, len(in)) 40 for i := range in { 41 out[i] = in[i].(*Node) 42 } 43 return 44 } 45 46 func nodeToGraphNode(in []*Node) graph.Nodes { 47 nodes := make([]graph.Node, len(in)) 48 for i, n := range in { 49 nodes[i] = n 50 } 51 return iterator.NewOrderedNodes(nodes) 52 } 53 54 func tensorInfo(t tensor.Tensor) (dt tensor.Dtype, dim int) { 55 dt = t.Dtype() 56 dim = t.Dims() 57 return 58 } 59 60 func valueToInt(v Value) (int, error) { 61 var intV int 62 switch sv := v.(type) { 63 case *F64: 64 intV = int(float64(*sv)) 65 case *F32: 66 intV = int(float32(*sv)) 67 case *I: 68 intV = int(*sv) 69 case *I32: 70 intV = int(int32(*sv)) 71 case *I64: 72 intV = int(int64(*sv)) 73 case *U8: 74 intV = int(byte(*sv)) 75 default: 76 return -1, errors.Errorf("Expected values to be all Scalar Value. Got %v of %T instead", v, v) 77 } 78 return intV, nil 79 } 80 81 // valuesToInts will FORCIBLY cast floats to ints. 82 func valuesToInts(values []Value) (retVal []int, err error) { 83 retVal = tensor.BorrowInts(len(values)) 84 for i, v := range values { 85 var intV int 86 switch sv := v.(type) { 87 case *F64: 88 intV = int(float64(*sv)) 89 case *F32: 90 intV = int(float32(*sv)) 91 case *I: 92 intV = int(*sv) 93 case *I32: 94 intV = int(int32(*sv)) 95 case *I64: 96 intV = int(int64(*sv)) 97 case *U8: 98 intV = int(byte(*sv)) 99 case Scalar: 100 return nil, errors.Errorf(nyiTypeFail, "valueToInts", v) 101 default: 102 return nil, errors.Errorf("Expected values to be all Scalar Value. Got %v of %T instead", v, v) 103 104 } 105 retVal[i] = intV 106 } 107 return 108 } 109 110 func valuesToTensors(values []Value) (retVal []tensor.Tensor, err error) { 111 retVal = make([]tensor.Tensor, len(values)) 112 for i, v := range values { 113 if vt, ok := v.(tensor.Tensor); ok { 114 retVal[i] = vt 115 continue 116 } 117 return nil, errors.Errorf("Expected values to all be tensor.Tensor. Got %v of %T in %dth index of the slice", v, v, i) 118 } 119 return 120 } 121 122 func intRange(start, end int) []int { 123 size := end - start 124 incr := true 125 if start > end { 126 incr = false 127 size = start - end 128 } 129 130 if size < 0 { 131 panic("Cannot create an int range that is somehow negative in size") 132 } 133 134 retVal := make([]int, size) 135 136 for i, v := 0, start; i < size; i++ { 137 retVal[i] = v 138 if incr { 139 v++ 140 } else { 141 v-- 142 } 143 } 144 return retVal 145 } 146 147 func ones(dt tensor.Dtype, sizes ...int) (retVal Value) { 148 if len(sizes) == 0 { 149 return one(dt) 150 } 151 return tensor.Ones(dt, sizes...) 152 } 153 154 func hasInf(v Value, dev Device) bool { 155 switch vt := v.(type) { 156 case *F64: 157 return math.IsInf(float64(*vt), 0) 158 case *F32: 159 return math32.IsInf(float32(*vt), 0) 160 case tensor.Tensor: 161 if e, ok := vt.Engine().(tensor.InfChecker); ok { 162 ok, _ := e.HasInf(vt) // BUG: errors not checked 163 return ok 164 } 165 166 dt := vt.Dtype() 167 if dt != tensor.Float64 && dt != tensor.Float32 { 168 return false 169 } 170 switch dt { 171 case tensor.Float32: 172 data := vt.Data().([]float32) 173 for _, datum := range data { 174 if math32.IsInf(datum, 0) { 175 return true 176 } 177 } 178 case tensor.Float64: 179 data := vt.Data().([]float64) 180 for _, datum := range data { 181 if math.IsInf(datum, 0) { 182 return true 183 } 184 } 185 } 186 return false 187 case *dualValue: 188 return hasInf(vt.Value, dev) || hasInf(vt.d, dev) 189 default: 190 err := nyi("hasInf", v) 191 panic(err) 192 } 193 } 194 195 func hasNaN(v Value, dev Device) bool { 196 switch vt := v.(type) { 197 case *F64: 198 return math.IsNaN(float64(*vt)) 199 case *F32: 200 return math32.IsNaN(float32(*vt)) 201 case tensor.Tensor: 202 if e, ok := vt.Engine().(tensor.NaNChecker); ok { 203 ok, _ := e.HasNaN(vt) // BUG: errors not checked 204 return ok 205 } 206 207 dt := vt.Dtype() 208 if dt != tensor.Float64 && dt != tensor.Float32 { 209 return false 210 } 211 212 switch dt { 213 case tensor.Float32: 214 data := vt.Data().([]float32) 215 for _, datum := range data { 216 if math32.IsNaN(datum) { 217 return true 218 } 219 } 220 case tensor.Float64: 221 data := vt.Data().([]float64) 222 for _, datum := range data { 223 if math.IsNaN(datum) { 224 return true 225 } 226 } 227 } 228 return false 229 case *dualValue: 230 return hasNaN(vt.Value, dev) || hasNaN(vt.d, dev) 231 default: 232 err := nyi("hasNaN", vt) 233 panic(err) 234 } 235 } 236 237 func setZero(val Value) (retVal Value) { 238 switch v := val.(type) { 239 case Zeroer: 240 v.Zero() 241 return v 242 case Scalar: 243 return zero(v.Dtype()) 244 default: 245 panic(fmt.Sprintf("setZero not implemented yet for %T", v)) 246 } 247 } 248 249 func checkArity(op arityer, inputs int) error { 250 if inputs != op.Arity() && op.Arity() >= 0 { 251 return errors.Errorf("%v has an arity of %d. Got %d instead", op, op.Arity(), inputs) 252 } 253 return nil 254 } 255 256 func maxInt(a, b int) int { 257 if a > b { 258 return a 259 } 260 return b 261 } 262 263 func minInt(a, b int) int { 264 if a < b { 265 return a 266 } 267 return b 268 } 269 270 func ceilDivInt(a, b int) int { 271 return (a + b - 1) / b 272 } 273 274 func simpleHash(op hashWriter) uint32 { 275 h := fnv.New32a() 276 op.WriteHash(h) 277 return h.Sum32() 278 } 279 280 func getDV(x, y *Node) (xdv, ydv *dualValue) { 281 return x.boundTo.(*dualValue), y.boundTo.(*dualValue) 282 } 283 284 func getDV3(x, y, z *Node) (xdv, ydv, zdv *dualValue) { 285 return x.boundTo.(*dualValue), y.boundTo.(*dualValue), z.boundTo.(*dualValue) 286 } 287 288 func getConst(x *Node, constant string) (retVal *Node, err error) { 289 var dt tensor.Dtype 290 if dt, err = dtypeOf(x.t); err != nil { 291 return nil, errors.Wrap(err, dtypeOfFail) 292 } 293 294 if m, ok := constmap[constant]; ok { 295 if n, ok := m[dt]; ok { 296 return n, nil 297 } 298 } 299 return nil, errors.Errorf("constant %v not provided for %v", constant, dt) 300 } 301 302 func scalarEquiv(s tensor.Shape) bool { 303 if len(s) == 0 { 304 return true 305 } 306 prod := 1 307 for _, v := range s { 308 prod *= v 309 } 310 311 return prod == 1 312 }