gorgonia.org/gorgonia@v0.9.17/equalities.go (about) 1 package gorgonia 2 3 import ( 4 "gorgonia.org/dawson" 5 "gorgonia.org/tensor" 6 ) 7 8 func scalarEq(a, b Scalar) bool { 9 switch at := a.(type) { 10 case *F64: 11 if bt, ok := b.(*F64); ok { 12 if at == bt { 13 return true 14 } 15 return *at == *bt 16 } 17 return false 18 case *F32: 19 if bt, ok := b.(*F32); ok { 20 if at == bt { 21 return true 22 } 23 return *at == *bt 24 } 25 return false 26 case *I: 27 if bt, ok := b.(*I); ok { 28 if at == bt { 29 return true 30 } 31 return *at == *bt 32 } 33 return false 34 case *I32: 35 if bt, ok := b.(*I32); ok { 36 if at == bt { 37 return true 38 } 39 return *at == *bt 40 } 41 return false 42 case *I64: 43 if bt, ok := b.(*I64); ok { 44 if at == bt { 45 return true 46 } 47 return *at == *bt 48 } 49 return false 50 case *U8: 51 if bt, ok := b.(*U8); ok { 52 if at == bt { 53 return true 54 } 55 return *at == *bt 56 } 57 return false 58 case *B: 59 if bt, ok := b.(*B); ok { 60 if at == bt { 61 return true 62 } 63 return *at == *bt 64 } 65 return false 66 } 67 return false 68 } 69 70 func scalarClose(a, b Scalar) bool { 71 switch at := a.(type) { 72 case *F64: 73 if bt, ok := b.(*F64); ok { 74 return dawson.CloseF64(float64(*at), float64(*bt)) 75 } 76 return false 77 case *F32: 78 if bt, ok := b.(*F32); ok { 79 return dawson.CloseF32(float32(*at), float32(*bt)) 80 } 81 return false 82 default: 83 return scalarEq(a, b) 84 } 85 } 86 87 func tensorClose(a, b tensor.Tensor) bool { 88 aDt := a.Dtype() 89 bDt := b.Dtype() 90 if aDt != bDt { 91 return false 92 } 93 94 switch aDt { 95 case tensor.Float64: 96 aFs := a.Data().([]float64) 97 bFs := b.Data().([]float64) 98 if len(aFs) != len(bFs) { 99 return false 100 } 101 aFs = aFs[:] 102 bFs = bFs[:len(aFs)] 103 for i, v := range aFs { 104 if !dawson.CloseF64(v, bFs[i]) { 105 return false 106 } 107 } 108 return true 109 case tensor.Float32: 110 aFs := a.Data().([]float32) 111 bFs := b.Data().([]float32) 112 if len(aFs) != len(bFs) { 113 return false 114 } 115 aFs = aFs[:] 116 bFs = bFs[:len(aFs)] 117 for i, v := range aFs { 118 if !dawson.CloseF32(v, bFs[i]) { 119 return false 120 } 121 } 122 return true 123 default: 124 return a.Eq(b) 125 } 126 127 } 128 129 /* 130 func axesEq(a, b axes) bool { 131 if len(a) != len(b) { 132 return false 133 } 134 135 for i, s := range a { 136 if b[i] != s { 137 return false 138 } 139 } 140 return true 141 } 142 143 // yes it's exactly the same as axesEq 144 func coordEq(a, b coordinates) bool { 145 if len(a) != len(b) { 146 return false 147 } 148 149 for i, s := range a { 150 if b[i] != s { 151 return false 152 } 153 } 154 return true 155 } 156 */ 157 158 func constEq(a, b constant) (ok bool) { 159 switch at := a.(type) { 160 case constantScalar: 161 var bt constantScalar 162 if bt, ok = b.(constantScalar); !ok { 163 return 164 } 165 166 return bt == at 167 case constantTensor: 168 var bt constantTensor 169 if bt, ok = b.(constantTensor); !ok { 170 return 171 } 172 return at.v.Eq(bt.v) 173 default: 174 panic("Not yet implemented") 175 } 176 } 177 178 // fastest comparisons to least fastest 179 func nodeEq(a, b *Node) bool { 180 if a == b { 181 return true 182 } 183 184 if a.isInput() { 185 if !b.isInput() { 186 return false 187 } 188 return a.name == b.name 189 } 190 191 if b.isInput() { 192 return false 193 } 194 195 // hashcode is good for comparing Op (TODO: benchmark this vs reflect.DeepEq) 196 if a.op.Hashcode() != b.op.Hashcode() { 197 return false 198 } 199 200 if len(a.children) != len(b.children) { 201 return false 202 } 203 204 if a.t != b.t { 205 return false 206 } 207 208 if !a.shape.Eq(b.shape) { 209 return false 210 } 211 212 return true 213 }