github.com/wzzhu/tensor@v0.9.24/dense_apply_test.go (about) 1 package tensor 2 3 import ( 4 "math/rand" 5 "testing" 6 "testing/quick" 7 "time" 8 "unsafe" 9 ) 10 11 func getMutateVal(dt Dtype) interface{} { 12 switch dt { 13 case Int: 14 return int(1) 15 case Int8: 16 return int8(1) 17 case Int16: 18 return int16(1) 19 case Int32: 20 return int32(1) 21 case Int64: 22 return int64(1) 23 case Uint: 24 return uint(1) 25 case Uint8: 26 return uint8(1) 27 case Uint16: 28 return uint16(1) 29 case Uint32: 30 return uint32(1) 31 case Uint64: 32 return uint64(1) 33 case Float32: 34 return float32(1) 35 case Float64: 36 return float64(1) 37 case Complex64: 38 var c complex64 = 1 39 return c 40 case Complex128: 41 var c complex128 = 1 42 return c 43 case Bool: 44 return true 45 case String: 46 return "Hello World" 47 case Uintptr: 48 return uintptr(0xdeadbeef) 49 case UnsafePointer: 50 return unsafe.Pointer(uintptr(0xdeadbeef)) 51 } 52 return nil 53 } 54 55 func getMutateFn(dt Dtype) interface{} { 56 switch dt { 57 case Int: 58 return mutateI 59 case Int8: 60 return mutateI8 61 case Int16: 62 return mutateI16 63 case Int32: 64 return mutateI32 65 case Int64: 66 return mutateI64 67 case Uint: 68 return mutateU 69 case Uint8: 70 return mutateU8 71 case Uint16: 72 return mutateU16 73 case Uint32: 74 return mutateU32 75 case Uint64: 76 return mutateU64 77 case Float32: 78 return mutateF32 79 case Float64: 80 return mutateF64 81 case Complex64: 82 return mutateC64 83 case Complex128: 84 return mutateC128 85 case Bool: 86 return mutateB 87 case String: 88 return mutateStr 89 case Uintptr: 90 return mutateUintptr 91 case UnsafePointer: 92 return mutateUnsafePointer 93 } 94 return nil 95 } 96 97 func TestDense_Apply(t *testing.T) { 98 var r *rand.Rand 99 mut := func(q *Dense) bool { 100 var mutVal interface{} 101 if mutVal = getMutateVal(q.Dtype()); mutVal == nil { 102 return true // we'll temporarily skip those we cannot mutate/get a mutation value 103 } 104 var fn interface{} 105 if fn = getMutateFn(q.Dtype()); fn == nil { 106 return true // we'll skip those that we cannot mutate 107 } 108 109 we, eqFail := willerr(q, nil, nil) 110 _, ok := q.Engine().(Mapper) 111 we = we || !ok 112 113 a := q.Clone().(*Dense) 114 correct := q.Clone().(*Dense) 115 correct.Memset(mutVal) 116 ret, err := a.Apply(fn) 117 if err, retEarly := qcErrCheck(t, "Apply", a, nil, we, err); retEarly { 118 if err != nil { 119 return false 120 } 121 return true 122 } 123 if !qcEqCheck(t, a.Dtype(), eqFail, correct.Data(), ret.Data()) { 124 return false 125 } 126 127 // wrong fn type/illogical values 128 if _, err = a.Apply(getMutateFn); err == nil { 129 t.Error("Expected an error") 130 return false 131 } 132 return true 133 } 134 r = rand.New(rand.NewSource(time.Now().UnixNano())) 135 if err := quick.Check(mut, &quick.Config{Rand: r}); err != nil { 136 t.Errorf("Applying mutation function failed %v", err) 137 } 138 } 139 140 func TestDense_Apply_unsafe(t *testing.T) { 141 var r *rand.Rand 142 mut := func(q *Dense) bool { 143 var mutVal interface{} 144 if mutVal = getMutateVal(q.Dtype()); mutVal == nil { 145 return true // we'll temporarily skip those we cannot mutate/get a mutation value 146 } 147 var fn interface{} 148 if fn = getMutateFn(q.Dtype()); fn == nil { 149 return true // we'll skip those that we cannot mutate 150 } 151 152 we, eqFail := willerr(q, nil, nil) 153 _, ok := q.Engine().(Mapper) 154 we = we || !ok 155 156 a := q.Clone().(*Dense) 157 correct := q.Clone().(*Dense) 158 correct.Memset(mutVal) 159 ret, err := a.Apply(fn, UseUnsafe()) 160 if err, retEarly := qcErrCheck(t, "Apply", a, nil, we, err); retEarly { 161 if err != nil { 162 return false 163 } 164 return true 165 } 166 if !qcEqCheck(t, a.Dtype(), eqFail, correct.Data(), ret.Data()) { 167 return false 168 } 169 if ret != a { 170 t.Error("Expected ret == correct (Unsafe option was used)") 171 return false 172 } 173 return true 174 } 175 r = rand.New(rand.NewSource(time.Now().UnixNano())) 176 if err := quick.Check(mut, &quick.Config{Rand: r}); err != nil { 177 t.Errorf("Applying mutation function failed %v", err) 178 } 179 } 180 181 func TestDense_Apply_reuse(t *testing.T) { 182 var r *rand.Rand 183 mut := func(q *Dense) bool { 184 var mutVal interface{} 185 if mutVal = getMutateVal(q.Dtype()); mutVal == nil { 186 return true // we'll temporarily skip those we cannot mutate/get a mutation value 187 } 188 var fn interface{} 189 if fn = getMutateFn(q.Dtype()); fn == nil { 190 return true // we'll skip those that we cannot mutate 191 } 192 193 we, eqFail := willerr(q, nil, nil) 194 _, ok := q.Engine().(Mapper) 195 we = we || !ok 196 197 a := q.Clone().(*Dense) 198 reuse := q.Clone().(*Dense) 199 reuse.Zero() 200 correct := q.Clone().(*Dense) 201 correct.Memset(mutVal) 202 ret, err := a.Apply(fn, WithReuse(reuse)) 203 if err, retEarly := qcErrCheck(t, "Apply", a, nil, we, err); retEarly { 204 if err != nil { 205 return false 206 } 207 return true 208 } 209 if !qcEqCheck(t, a.Dtype(), eqFail, correct.Data(), ret.Data()) { 210 return false 211 } 212 if ret != reuse { 213 t.Error("Expected ret == correct (Unsafe option was used)") 214 return false 215 } 216 return true 217 } 218 r = rand.New(rand.NewSource(time.Now().UnixNano())) 219 if err := quick.Check(mut, &quick.Config{Rand: r}); err != nil { 220 t.Errorf("Applying mutation function failed %v", err) 221 } 222 }