github.com/wzzhu/tensor@v0.9.24/genlib2/testutils.go (about)

     1  package main
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"text/template"
     7  )
     8  
     // anyToF64sRaw is the text/template source for the generated anyToFloat64s
     // helper. It ranges over the .Kinds list of the value it is executed with
     // and emits one type-switch case per kind accepted by isNumber:
     //   - []float64 is returned as-is (no copy is made);
     //   - []float32 maps NaN, +Inf and -Inf explicitly and converts the rest;
     //   - []complex64 / []complex128 map NaN to NaN, any infinity to +Inf, and
     //     otherwise keep only the real part;
     //   - every other numeric kind is converted element-wise with float64(v).
     // The generated function panics with "Unreachable" when no case matches.
     9  const anyToF64sRaw = `func anyToFloat64s(x interface{}) (retVal []float64) {
    10  	switch xt := x.(type) {
    11  	{{range  .Kinds -}}
    12  	{{if isNumber . -}}
    13  	case []{{asType .}}:
    14  		{{if eq .String "float64" -}}
    15  		{{else if eq .String "float32" -}}
    16  			retVal = make([]float64, len(xt))
    17  			for i, v := range xt {
    18  				switch {
    19  				case math32.IsNaN(v):
    20  					retVal[i] = math.NaN()
    21  				case math32.IsInf(v, 1):
    22  					retVal[i] = math.Inf(1)
    23  				case math32.IsInf(v, -1):
    24  					retVal[i] = math.Inf(-1)
    25  				default:
    26  					retVal[i] = float64(v)
    27  				}
    28  			}
    29  		{{else if eq .String "complex64" -}}
    30  			retVal = make([]float64, len(xt))
    31  			for i, v := range xt {
    32  				switch {
    33  				case cmplx.IsNaN(complex128(v)):
    34  					retVal[i] = math.NaN()
    35  				case cmplx.IsInf(complex128(v)):
    36  					retVal[i] = math.Inf(1)
    37  				default:
    38  					retVal[i] = float64(real(v))
    39  				}
    40  			}
    41  		{{else if eq .String "complex128" -}}
    42  			retVal = make([]float64, len(xt))
    43  			for i, v := range xt {
    44  				switch {
    45  				case cmplx.IsNaN(v):
    46  					retVal[i] = math.NaN()
    47  				case cmplx.IsInf(v):
    48  					retVal[i] = math.Inf(1)
    49  				default:
    50  					retVal[i] = real(v)
    51  				}
    52  			}
    53  		{{else -}}
    54  			retVal = make([]float64, len(xt))
    55  			for i, v := range xt {
    56  				retVal[i]=  float64(v)
    57  			}
    58  		{{end -}}
    59  		return {{if eq .String "float64"}}xt{{end}}
    60  	{{end -}}
    61  	{{end -}}
    62  	}
    63  	panic("Unreachable")
    64  }
    65  `
    66  
    // qcGenraw is the text/template source for the generated randomQC function,
    // which overwrites every element of a.Data() with a value chosen by dtype.
    // Kinds for which isParameterized is true get no case. Per-kind fill rules:
    //   - names starting with "uint": converted from r.Uint32();
    //   - names starting with "int":  converted from r.Int();
    //   - float32 / float64:          r.Float32() / r.Float64();
    //   - complex64 / complex128:     complex() of two such floats;
    //   - bool / string:              randomBool() / randomString(), which the
    //                                 generated code expects to be defined elsewhere;
    //   - unsafe.Pointer:             nil.
    // The "uint" test comes first, so it is the branch taken for every kind
    // whose name begins with "uint".
    67  const qcGenraw = `func randomQC(a Tensor, r *rand.Rand) {
    68  	switch a.Dtype() {
    69  	{{range .Kinds -}}
    70  	{{if isParameterized . -}}
    71  	{{else -}}
    72  	case {{reflectKind . -}}:
    73  		s := a.Data().([]{{asType .}})
    74  		for i := range s {
    75  			{{if hasPrefix .String "uint" -}}
    76  				s[i] = {{asType .}}(r.Uint32())
    77  			{{else if hasPrefix .String "int" -}}
    78  				s[i] = {{asType .}}(r.Int())
    79  			{{else if eq .String "float64" -}}
    80  				s[i] = r.Float64()
    81  			{{else if eq .String "float32" -}}
    82  				s[i] = r.Float32()
    83  			{{else if eq .String "complex64" -}}
    84  				s[i] = complex(r.Float32(), r.Float32())
    85  			{{else if eq .String "complex128" -}}
    86  				s[i] = complex(r.Float64(), r.Float64())
    87  			{{else if eq .String "bool" -}}
    88  				s[i] = randomBool()
    89  			{{else if eq .String "string" -}}
    90  				s[i] = randomString()
    91  			{{else if eq .String "unsafe.Pointer" -}}
    92  				s[i] = nil
    93  			{{end -}}	
    94  		}
    95  	{{end -}}
    96  	{{end -}}
    97  	}
    98  }
    99  `
   100  
   // testQCRaw is the text/template source for a per-kind QCDense<short> type
   // that embeds *Dense, plus its Generate method. Generate fills a slice of
   // `size` elements using the same per-kind random rules as the randomQC
   // template, wraps it with recycledDense(..., Shape{size}, WithBacking(s)) and
   // returns the new wrapper as a reflect.Value.
   // NOTE(review): the Generate signature looks like testing/quick's Generator
   // interface - confirm against the generated package.
   // NOTE: the only execution site of this template visible in this file (in
   // generateTestUtils) is commented out.
   101  const testQCRaw = `type QCDense{{short .}} struct {
   102  	*Dense 
   103  }
   104  func (*QCDense{{short .}}) Generate(r *rand.Rand, size int) reflect.Value {
   105  	s := make([]{{asType .}}, size)
   106  	for i := range s {
   107  		{{if hasPrefix .String "uint" -}}
   108  			s[i] = {{asType .}}(r.Uint32())
   109  		{{else if hasPrefix .String "int" -}}
   110  			s[i] = {{asType .}}(r.Int())
   111  		{{else if eq .String "float64" -}}
   112  			s[i] = r.Float64()
   113  		{{else if eq .String "float32" -}}
   114  			s[i] = r.Float32()
   115  		{{else if eq .String "complex64" -}}
   116  			s[i] = complex(r.Float32(), r.Float32())
   117  		{{else if eq .String "complex128" -}}
   118  			s[i] = complex(r.Float64(), r.Float64())
   119  		{{else if eq .String "bool" -}}
   120  			s[i] = randomBool()
   121  		{{else if eq .String "string" -}}
   122  			s[i] = randomString()
   123  		{{else if eq .String "unsafe.Pointer" -}}
   124  			s[i] = nil
   125  		{{end -}}
   126  	}
   127  	d := recycledDense({{asType . | title | strip}}, Shape{size}, WithBacking(s))
   128  	q := new(QCDense{{short .}})
   129  	q.Dense = d
   130  	return reflect.ValueOf(q)
   131  }
   132  `
   133  
   // identityFnsRaw is the text/template source for one identity<short>
   // function per kind: it takes a value of that kind's type and returns it
   // unchanged. The template is executed with a single kind as its data.
   134  const identityFnsRaw = `func identity{{short .}}(a {{asType .}}) {{asType .}}{return a}
   135  `
   // mutateFnsRaw is the text/template source for one mutate<short> function
   // per kind. The generated function ignores its argument and returns a fixed
   // value: 1 for kinds accepted by isNumber, true for bool, "Hello World" for
   // string, 0xdeadbeef for uintptr and unsafe.Pointer(uintptr(0xdeadbeef)) for
   // unsafe.Pointer.
   // Each branch supplies both the return statement and the closing brace, so
   // a kind that matches none of the branches yields an unterminated function.
   136  const mutateFnsRaw = `func mutate{{short .}}(a {{asType . }}){{asType .}} { {{if isNumber . -}}return 1}
   137  {{else if eq .String "bool" -}}return true }
   138  {{else if eq .String "string" -}}return "Hello World"}
   139  {{else if eq .String "uintptr" -}}return 0xdeadbeef}
   140  {{else if eq .String "unsafe.Pointer" -}}return unsafe.Pointer(uintptr(0xdeadbeef))} 
   141  {{end -}} 
   142  `
   143  
   // identityValsRaw is the text/template source for the generated
   // identityVal(x int, dt Dtype) helper, which returns x expressed as a value
   // of dtype dt:
   //   - one case per kind in .Kinds performs a plain conversion T(x);
   //   - Complex64 / Complex128 return the zero value when x == 0, else 1;
   //   - Bool returns false when x == 0, else true;
   //   - String returns "" when x == 0, else x formatted with "%v";
   //   - any other dtype returns x itself, still typed as int.
   // The complex, bool and string cases are written out by hand rather than
   // ranged over; generateTestUtils feeds this template only the kinds
   // accepted by isNonComplexNumber.
   // The template has no trailing newline.
   144  const identityValsRaw = `func identityVal(x int, dt Dtype) interface{} {
   145  	switch dt {
   146  		{{range .Kinds -}}
   147  	case {{reflectKind .}}:
   148  		return {{asType .}}(x)
   149  		{{end -}}
   150  	case Complex64:
   151  		var c complex64
   152  		if x == 0 {
   153  			return c
   154  		}
   155  		c = 1
   156  		return c
   157  	case Complex128:
   158  		var c complex128
   159  		if x == 0 {
   160  			return c
   161  		}
   162  		c = 1
   163  		return c
   164  	case Bool:
   165  		if x == 0 {
   166  			return false
   167  		}
   168  		return true
   169  	case String:
   170  		if x == 0 {
   171  			return ""
   172  		}
   173  		return fmt.Sprintf("%v", x)
   174  	default:
   175  		return x
   176  	}
   177  }`
   178  
   // threewayEqualityRaw is the text/template source for the generated
   // threewayEq(a, b, c interface{}) bool helper. For every kind in .Kinds it
   // emits a slice case followed (in a second pass) by a scalar case.
   //
   // The check is an implication rather than full equality: wherever a and b
   // are both 1, c must be 1 as well; positions where a or b is not 1 are not
   // constrained. The slice form walks the indices of a and indexes b and c
   // with them.
   //
   // b and c are type-asserted to a's dynamic type without the comma-ok form,
   // so a mismatch panics in the generated code. When a matches no case the
   // generated function returns false.
   179  const threewayEqualityRaw = `func threewayEq(a, b, c interface{}) bool {
   180  	switch at := a.(type){
   181  		{{range .Kinds -}}
   182  	case []{{asType .}}:
   183  		bt := b.([]{{asType .}})
   184  		ct := c.([]{{asType .}})
   185  
   186  		for i, va := range at {
   187  			if va == 1 && bt[i] == 1 {
   188  				if ct[i] != 1 {
   189  					return false
   190  				}
   191  			}
   192  		}
   193  		return true
   194  		{{end -}}
   195  		{{range .Kinds -}}
   196  	case {{asType .}}:
   197  		bt := b.({{asType .}})
   198  		ct := c.({{asType .}})
   199  		if (at == 1 && bt == 1) && ct != 1 {
   200  			return false
   201  		}
   202  		return true
   203  		{{end -}}
   204  	}
   205  
   206  	return false
   207  }
   208  `
   209  
   // Parsed template handles, one per raw template constant in this file.
   // They are nil until init assigns them.
   210  var (
   211  	anyToF64s        *template.Template
   212  	qcGen            *template.Template
   213  	testQC           *template.Template
   214  	identityFns      *template.Template
   215  	mutateFns        *template.Template
   216  	identityVals     *template.Template
   217  	threewayEquality *template.Template
   218  )
   219  
   220  func init() {
   221  	qcGen = template.Must(template.New("QCGen").Funcs(funcs).Parse(qcGenraw))
   222  	testQC = template.Must(template.New("testQCs").Funcs(funcs).Parse(testQCRaw))
   223  	anyToF64s = template.Must(template.New("anyToF64s").Funcs(funcs).Parse(anyToF64sRaw))
   224  	identityFns = template.Must(template.New("identityFn").Funcs(funcs).Parse(identityFnsRaw))
   225  	mutateFns = template.Must(template.New("mutateFns").Funcs(funcs).Parse(mutateFnsRaw))
   226  	identityVals = template.Must(template.New("identityVal").Funcs(funcs).Parse(identityValsRaw))
   227  	threewayEquality = template.Must(template.New("threeway").Funcs(funcs).Parse(threewayEqualityRaw))
   228  }
   229  
   230  func generateTestUtils(f io.Writer, ak Kinds) {
   231  	anyToF64s.Execute(f, ak)
   232  	fmt.Fprintf(f, "\n")
   233  	ak2 := Kinds{Kinds: filter(ak.Kinds, isNonComplexNumber)}
   234  	identityVals.Execute(f, ak2)
   235  	fmt.Fprintf(f, "\n")
   236  	ak3 := Kinds{Kinds: filter(ak.Kinds, isNumber)}
   237  	threewayEquality.Execute(f, ak3)
   238  	fmt.Fprintf(f, "\n")
   239  	for _, k := range ak.Kinds {
   240  		if !isParameterized(k) {
   241  			identityFns.Execute(f, k)
   242  		}
   243  	}
   244  	for _, k := range ak.Kinds {
   245  		if !isParameterized(k) {
   246  			mutateFns.Execute(f, k)
   247  		}
   248  	}
   249  	fmt.Fprintf(f, "\n")
   250  	// for _, k := range ak.Kinds {
   251  	// 	if !isParameterized(k) {
   252  	// 		testQC.Execute(f, k)
   253  	// 		fmt.Fprint(f, "\n")
   254  	// 	}
   255  	// }
   256  
   257  }