gorgonia.org/gorgonia@v0.9.17/values_utils.go (about) 1 package gorgonia 2 3 import ( 4 "fmt" 5 6 "github.com/chewxy/hm" 7 "github.com/pkg/errors" 8 "gorgonia.org/tensor" 9 ) 10 11 // TypeOf returns the Type of the value 12 func TypeOf(v Value) hm.Type { 13 switch t := v.(type) { 14 case tensor.Tensor: 15 dt, dim := tensorInfo(t) 16 return makeTensorType(dim, dt) 17 case Scalar: 18 return t.Dtype() 19 case Typer: 20 return t.Type() 21 22 default: 23 panic(fmt.Sprintf("TypeOf Not yet implemented for %v %T", v, v)) 24 } 25 } 26 27 func typeCheckTypeOf(v Value) hm.Type { 28 switch t := v.(type) { 29 case tensor.Tensor: 30 dt, dim := tensorInfo(t) 31 return newTensorType(dim, dt) 32 case Scalar: 33 return t.Dtype() 34 case Typer: 35 return t.Type() 36 37 default: 38 panic(fmt.Sprintf("TypeOf Not yet implemented for %v %T", v, v)) 39 } 40 } 41 42 // ValueEq is the equality function for values 43 func ValueEq(a, b Value) bool { 44 if a == nil && b == nil { 45 return true 46 } 47 switch at := a.(type) { 48 case Scalar: 49 if bt, ok := b.(Scalar); ok { 50 return scalarEq(at, bt) 51 } 52 return false 53 case tensor.Tensor: 54 if bt, ok := b.(tensor.Tensor); ok { 55 return at.Eq(bt) 56 //log.Printf("at.info %#v, bt.info %#v", a.(*tensor.Dense).Info(), b.(*tensor.Dense).Info()) 57 } 58 return false 59 case ValueEqualer: 60 return at.ValueEq(b) 61 default: 62 panic(fmt.Sprintf("Not implemented yet, %T", a)) 63 } 64 } 65 66 // ValueClose checks whether two values are close to one another. It's predominantly used as an alternative equality test for floats 67 func ValueClose(a, b Value) bool { 68 if a == nil && b == nil { 69 return true 70 } 71 72 switch at := a.(type) { 73 case Scalar: 74 if bt, ok := b.(Scalar); ok { 75 return scalarClose(at, bt) 76 } 77 return false 78 case tensor.Tensor: 79 if bt, ok := b.(tensor.Tensor); ok { 80 return tensorClose(at, bt) 81 } 82 return false 83 case ValueCloser: 84 return at.ValueClose(b) 85 default: 86 panic("Not implemented yet") 87 } 88 } 89 90 // CloneValue clones a value. For scalars, since Go copies scalars, it returns itself 91 func CloneValue(v Value) (Value, error) { 92 switch vt := v.(type) { 93 case *F64: 94 retVal := *vt 95 return &retVal, nil 96 case *F32: 97 retVal := *vt 98 return &retVal, nil 99 case *I: 100 retVal := *vt 101 return &retVal, nil 102 case *I32: 103 retVal := *vt 104 return &retVal, nil 105 case *I64: 106 retVal := *vt 107 return &retVal, nil 108 case *U8: 109 retVal := *vt 110 return &retVal, nil 111 case *B: 112 retVal := *vt 113 return &retVal, nil 114 case tensor.Tensor: 115 return vt.Clone().(*tensor.Dense), nil 116 case CloneErrorer: 117 ret, err := vt.Clone() 118 if err != nil { 119 return nil, err 120 } 121 retVal, ok := ret.(Value) 122 if !ok { 123 return nil, errors.Errorf("Cloner is not a value: %v %T", v, v) 124 } 125 return retVal, nil 126 case Cloner: 127 return vt.Clone().(Value), nil 128 default: 129 return nil, errors.Errorf("Unable to clone value of type %T", v) 130 } 131 } 132 133 // ZeroValue returns the zero value of a type 134 func ZeroValue(v Value) Value { 135 switch vt := v.(type) { 136 case *F64: 137 *vt = 0 138 return vt 139 case *F32: 140 *vt = 0 141 return vt 142 case *I: 143 *vt = 0 144 return vt 145 case *I32: 146 *vt = 0 147 return vt 148 case *I64: 149 *vt = 0 150 return vt 151 case *U8: 152 *vt = 0 153 return vt 154 case *B: 155 *vt = false 156 return vt 157 case tensor.Tensor: 158 vt.Zero() 159 return vt 160 case ZeroValuer: 161 return vt.ZeroValue() 162 default: 163 panic(fmt.Sprintf("Cannot return zero value of %T", v)) 164 } 165 } 166 167 // Copy copies the src values into dest values. For scalars, it just returns itself 168 func Copy(dest, src Value) (Value, error) { 169 var ok bool 170 switch srcT := src.(type) { 171 case *F64: 172 var destS *F64 173 if destS, ok = dest.(*F64); !ok { 174 return nil, errors.Errorf("Expected dest to be *F64. Got %T instead", dest) 175 } 176 *destS = *srcT 177 return destS, nil 178 case *F32: 179 var destS *F32 180 if destS, ok = dest.(*F32); !ok { 181 return nil, errors.Errorf("Expected dest to be *F32. Got %T instead", dest) 182 } 183 *destS = *srcT 184 return destS, nil 185 case *I: 186 var destS *I 187 if destS, ok = dest.(*I); !ok { 188 return nil, errors.Errorf("Expected dest to be *I) . Got %T instead", dest) 189 } 190 *destS = *srcT 191 return destS, nil 192 case *I64: 193 var destS *I64 194 if destS, ok = dest.(*I64); !ok { 195 return nil, errors.Errorf("Expected dest to be *I64. Got %T instead", dest) 196 } 197 *destS = *srcT 198 return destS, nil 199 case *I32: 200 var destS *I32 201 if destS, ok = dest.(*I32); !ok { 202 return nil, errors.Errorf("Expected dest to be *I32. Got %T instead", dest) 203 } 204 *destS = *srcT 205 return destS, nil 206 case *U8: 207 var destS *U8 208 if destS, ok = dest.(*U8); !ok { 209 return nil, errors.Errorf("Expected dest to be *U8). Got %T instead", dest) 210 } 211 *destS = *srcT 212 return destS, nil 213 case *B: 214 var destS *B 215 if destS, ok = dest.(*B); !ok { 216 return nil, errors.Errorf("Expected dest to be *B) . Got %T instead", dest) 217 } 218 *destS = *srcT 219 return destS, nil 220 case tensor.Tensor: 221 var destT tensor.Tensor 222 if destT, ok = dest.(tensor.Tensor); !ok { 223 return nil, errors.Errorf("Expected dest to be a tensor.Tensor. Got %T instead", dest) 224 } 225 err := tensor.Copy(destT, srcT) 226 return dest, err 227 case CopierTo: 228 err := srcT.CopyTo(dest) 229 return dest, err 230 default: 231 var copyFrom CopierFrom 232 if copyFrom, ok = dest.(CopierFrom); ok { 233 err := copyFrom.CopyFrom(src) 234 return dest, err 235 } 236 return nil, errors.Errorf("Unable to copy value of type %T into value of type %T", src, dest) 237 } 238 } 239 240 func setEngine(v Value, e tensor.Engine) { 241 switch vv := v.(type) { 242 case *dualValue: 243 setEngine(vv.Value, e) 244 setEngine(vv.d, e) 245 case tensor.Tensor: 246 tensor.WithEngine(e)(vv) 247 } 248 }