github.com/wzzhu/tensor@v0.9.24/genlib2/testutils.go (about) 1 package main 2 3 import ( 4 "fmt" 5 "io" 6 "text/template" 7 ) 8 9 const anyToF64sRaw = `func anyToFloat64s(x interface{}) (retVal []float64) { 10 switch xt := x.(type) { 11 {{range .Kinds -}} 12 {{if isNumber . -}} 13 case []{{asType .}}: 14 {{if eq .String "float64" -}} 15 {{else if eq .String "float32" -}} 16 retVal = make([]float64, len(xt)) 17 for i, v := range xt { 18 switch { 19 case math32.IsNaN(v): 20 retVal[i] = math.NaN() 21 case math32.IsInf(v, 1): 22 retVal[i] = math.Inf(1) 23 case math32.IsInf(v, -1): 24 retVal[i] = math.Inf(-1) 25 default: 26 retVal[i] = float64(v) 27 } 28 } 29 {{else if eq .String "complex64" -}} 30 retVal = make([]float64, len(xt)) 31 for i, v := range xt { 32 switch { 33 case cmplx.IsNaN(complex128(v)): 34 retVal[i] = math.NaN() 35 case cmplx.IsInf(complex128(v)): 36 retVal[i] = math.Inf(1) 37 default: 38 retVal[i] = float64(real(v)) 39 } 40 } 41 {{else if eq .String "complex128" -}} 42 retVal = make([]float64, len(xt)) 43 for i, v := range xt { 44 switch { 45 case cmplx.IsNaN(v): 46 retVal[i] = math.NaN() 47 case cmplx.IsInf(v): 48 retVal[i] = math.Inf(1) 49 default: 50 retVal[i] = real(v) 51 } 52 } 53 {{else -}} 54 retVal = make([]float64, len(xt)) 55 for i, v := range xt { 56 retVal[i]= float64(v) 57 } 58 {{end -}} 59 return {{if eq .String "float64"}}xt{{end}} 60 {{end -}} 61 {{end -}} 62 } 63 panic("Unreachable") 64 } 65 ` 66 67 const qcGenraw = `func randomQC(a Tensor, r *rand.Rand) { 68 switch a.Dtype() { 69 {{range .Kinds -}} 70 {{if isParameterized . -}} 71 {{else -}} 72 case {{reflectKind . -}}: 73 s := a.Data().([]{{asType .}}) 74 for i := range s { 75 {{if hasPrefix .String "uint" -}} 76 s[i] = {{asType .}}(r.Uint32()) 77 {{else if hasPrefix .String "int" -}} 78 s[i] = {{asType .}}(r.Int()) 79 {{else if eq .String "float64" -}} 80 s[i] = r.Float64() 81 {{else if eq .String "float32" -}} 82 s[i] = r.Float32() 83 {{else if eq .String "complex64" -}} 84 s[i] = complex(r.Float32(), r.Float32()) 85 {{else if eq .String "complex128" -}} 86 s[i] = complex(r.Float64(), r.Float64()) 87 {{else if eq .String "bool" -}} 88 s[i] = randomBool() 89 {{else if eq .String "string" -}} 90 s[i] = randomString() 91 {{else if eq .String "unsafe.Pointer" -}} 92 s[i] = nil 93 {{end -}} 94 } 95 {{end -}} 96 {{end -}} 97 } 98 } 99 ` 100 101 const testQCRaw = `type QCDense{{short .}} struct { 102 *Dense 103 } 104 func (*QCDense{{short .}}) Generate(r *rand.Rand, size int) reflect.Value { 105 s := make([]{{asType .}}, size) 106 for i := range s { 107 {{if hasPrefix .String "uint" -}} 108 s[i] = {{asType .}}(r.Uint32()) 109 {{else if hasPrefix .String "int" -}} 110 s[i] = {{asType .}}(r.Int()) 111 {{else if eq .String "float64" -}} 112 s[i] = r.Float64() 113 {{else if eq .String "float32" -}} 114 s[i] = r.Float32() 115 {{else if eq .String "complex64" -}} 116 s[i] = complex(r.Float32(), r.Float32()) 117 {{else if eq .String "complex128" -}} 118 s[i] = complex(r.Float64(), r.Float64()) 119 {{else if eq .String "bool" -}} 120 s[i] = randomBool() 121 {{else if eq .String "string" -}} 122 s[i] = randomString() 123 {{else if eq .String "unsafe.Pointer" -}} 124 s[i] = nil 125 {{end -}} 126 } 127 d := recycledDense({{asType . | title | strip}}, Shape{size}, WithBacking(s)) 128 q := new(QCDense{{short .}}) 129 q.Dense = d 130 return reflect.ValueOf(q) 131 } 132 ` 133 134 const identityFnsRaw = `func identity{{short .}}(a {{asType .}}) {{asType .}}{return a} 135 ` 136 const mutateFnsRaw = `func mutate{{short .}}(a {{asType . }}){{asType .}} { {{if isNumber . -}}return 1} 137 {{else if eq .String "bool" -}}return true } 138 {{else if eq .String "string" -}}return "Hello World"} 139 {{else if eq .String "uintptr" -}}return 0xdeadbeef} 140 {{else if eq .String "unsafe.Pointer" -}}return unsafe.Pointer(uintptr(0xdeadbeef))} 141 {{end -}} 142 ` 143 144 const identityValsRaw = `func identityVal(x int, dt Dtype) interface{} { 145 switch dt { 146 {{range .Kinds -}} 147 case {{reflectKind .}}: 148 return {{asType .}}(x) 149 {{end -}} 150 case Complex64: 151 var c complex64 152 if x == 0 { 153 return c 154 } 155 c = 1 156 return c 157 case Complex128: 158 var c complex128 159 if x == 0 { 160 return c 161 } 162 c = 1 163 return c 164 case Bool: 165 if x == 0 { 166 return false 167 } 168 return true 169 case String: 170 if x == 0 { 171 return "" 172 } 173 return fmt.Sprintf("%v", x) 174 default: 175 return x 176 } 177 }` 178 179 const threewayEqualityRaw = `func threewayEq(a, b, c interface{}) bool { 180 switch at := a.(type){ 181 {{range .Kinds -}} 182 case []{{asType .}}: 183 bt := b.([]{{asType .}}) 184 ct := c.([]{{asType .}}) 185 186 for i, va := range at { 187 if va == 1 && bt[i] == 1 { 188 if ct[i] != 1 { 189 return false 190 } 191 } 192 } 193 return true 194 {{end -}} 195 {{range .Kinds -}} 196 case {{asType .}}: 197 bt := b.({{asType .}}) 198 ct := c.({{asType .}}) 199 if (at == 1 && bt == 1) && ct != 1 { 200 return false 201 } 202 return true 203 {{end -}} 204 } 205 206 return false 207 } 208 ` 209 210 var ( 211 anyToF64s *template.Template 212 qcGen *template.Template 213 testQC *template.Template 214 identityFns *template.Template 215 mutateFns *template.Template 216 identityVals *template.Template 217 threewayEquality *template.Template 218 ) 219 220 func init() { 221 qcGen = template.Must(template.New("QCGen").Funcs(funcs).Parse(qcGenraw)) 222 testQC = template.Must(template.New("testQCs").Funcs(funcs).Parse(testQCRaw)) 223 anyToF64s = template.Must(template.New("anyToF64s").Funcs(funcs).Parse(anyToF64sRaw)) 224 identityFns = template.Must(template.New("identityFn").Funcs(funcs).Parse(identityFnsRaw)) 225 mutateFns = template.Must(template.New("mutateFns").Funcs(funcs).Parse(mutateFnsRaw)) 226 identityVals = template.Must(template.New("identityVal").Funcs(funcs).Parse(identityValsRaw)) 227 threewayEquality = template.Must(template.New("threeway").Funcs(funcs).Parse(threewayEqualityRaw)) 228 } 229 230 func generateTestUtils(f io.Writer, ak Kinds) { 231 anyToF64s.Execute(f, ak) 232 fmt.Fprintf(f, "\n") 233 ak2 := Kinds{Kinds: filter(ak.Kinds, isNonComplexNumber)} 234 identityVals.Execute(f, ak2) 235 fmt.Fprintf(f, "\n") 236 ak3 := Kinds{Kinds: filter(ak.Kinds, isNumber)} 237 threewayEquality.Execute(f, ak3) 238 fmt.Fprintf(f, "\n") 239 for _, k := range ak.Kinds { 240 if !isParameterized(k) { 241 identityFns.Execute(f, k) 242 } 243 } 244 for _, k := range ak.Kinds { 245 if !isParameterized(k) { 246 mutateFns.Execute(f, k) 247 } 248 } 249 fmt.Fprintf(f, "\n") 250 // for _, k := range ak.Kinds { 251 // if !isParameterized(k) { 252 // testQC.Execute(f, k) 253 // fmt.Fprint(f, "\n") 254 // } 255 // } 256 257 }