github.com/wzzhu/tensor@v0.9.24/genlib2/native_select.go (about) 1 package main 2 3 import ( 4 "fmt" 5 "io" 6 "text/template" 7 ) 8 9 const checkNativeSelectable = `func checkNativeSelectable(t *Dense, axis int, dt Dtype) error { 10 if !t.IsNativelyAccessible() { 11 return errors.New("Cannot select on non-natively accessible data") 12 } 13 if axis >= t.Shape().Dims() && !(t.IsScalar() && axis == 0) { 14 return errors.Errorf("Cannot select on axis %d. Shape is %v", axis, t.Shape()) 15 } 16 if t.F() || t.RequiresIterator() { 17 return errors.Errorf("Not yet implemented: native select for colmajor or unpacked matrices") 18 } 19 if t.Dtype() != dt { 20 return errors.Errorf("Native selection only works on %v. Got %v", dt, t.Dtype()) 21 } 22 return nil 23 } 24 ` 25 const nativeSelectRaw = `// Select{{short .}} creates a slice of flat data types. See Example of NativeSelectF64. 26 func Select{{short .}}(t *Dense, axis int) (retVal [][]{{asType .}}, err error) { 27 if err := checkNativeSelectable(t, axis, {{reflectKind .}}); err != nil { 28 return nil, err 29 } 30 31 switch t.Shape().Dims() { 32 case 0, 1: 33 retVal = make([][]{{asType .}}, 1) 34 retVal[0] = t.{{sliceOf .}} 35 case 2: 36 if axis == 0 { 37 return Matrix{{short .}}(t) 38 } 39 fallthrough 40 default: 41 // size := t.Shape()[axis] 42 data := t.{{sliceOf .}} 43 stride := t.Strides()[axis] 44 upper := ProdInts(t.Shape()[:axis+1]) 45 retVal = make([][]{{asType .}}, 0, upper) 46 for i, r := 0, 0; r < upper; i += stride { 47 s := make([]{{asType .}}, 0) 48 hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) 49 hdr.Data = uintptr(unsafe.Pointer(&data[i])) 50 hdr.Len = stride 51 hdr.Cap = stride 52 retVal = append(retVal, s) 53 r++ 54 } 55 return retVal, nil 56 57 } 58 return 59 } 60 ` 61 const nativeSelectTestRaw = `func TestSelect{{short .}}(t *testing.T) { 62 assert := assert.New(t) 63 var T *Dense 64 var err error 65 var x [][]{{asType .}} 66 T = New(Of({{reflectKind .}}), WithShape(2, 3, 4, 5), ) 67 if x, err = Select{{short .}}(T, 1); err != nil { 68 t.Fatal(err) 69 } 70 assert.Equal(6, len(x)) 71 assert.Equal(20, len(x[0])) 72 73 T = New(Of({{reflectKind .}}), WithShape(2, 3, 4, 5), ) 74 if x, err = Select{{short .}}(T, 0); err != nil { 75 t.Fatal(err) 76 } 77 assert.Equal(2, len(x)) 78 assert.Equal(60, len(x[0])) 79 80 T = New(Of({{reflectKind .}}), WithShape(2, 3, 4, 5), ) 81 if x, err = Select{{short .}}(T, 3); err != nil { 82 t.Fatal(err) 83 } 84 assert.Equal(120, len(x)) 85 assert.Equal(1, len(x[0])) 86 87 T = New(Of({{reflectKind .}}), WithShape(2, 3), ) 88 if x, err = Select{{short .}}(T, 0); err != nil { 89 t.Fatal(err) 90 } 91 assert.Equal(2, len(x)) 92 assert.Equal(3, len(x[0])) 93 94 T = New(Of({{reflectKind .}}), WithShape(2, 3), ) 95 if x, err = Select{{short .}}(T, 1); err != nil { 96 t.Fatal(err) 97 } 98 assert.Equal(6, len(x)) 99 assert.Equal(1, len(x[0])) 100 101 T = New(FromScalar({{if eq .String "bool" -}}false{{else if eq .String "string" -}}""{{else -}}{{asType .}}(0) {{end -}} )) 102 if x, err = Select{{short .}}(T, 0); err != nil { 103 t.Fatal(err) 104 } 105 assert.Equal(1, len(x)) 106 assert.Equal(1, len(x[0])) 107 108 if _, err = Select{{short .}}(T, 10); err == nil{ 109 t.Fatal("Expected errors") 110 } 111 } 112 ` 113 114 var ( 115 NativeSelect *template.Template 116 NativeSelectTest *template.Template 117 ) 118 119 func init() { 120 NativeSelect = template.Must(template.New("NativeSelect").Funcs(funcs).Parse(nativeSelectRaw)) 121 NativeSelectTest = template.Must(template.New("NativeSelectTest").Funcs(funcs).Parse(nativeSelectTestRaw)) 122 } 123 124 func generateNativeSelect(f io.Writer, ak Kinds) { 125 fmt.Fprintf(f, importUnqualifiedTensor) 126 fmt.Fprintf(f, "%v\n", checkNativeSelectable) 127 ks := filter(ak.Kinds, isSpecialized) 128 for _, k := range ks { 129 fmt.Fprintf(f, "/* Native Select for %v */\n\n", k) 130 NativeSelect.Execute(f, k) 131 fmt.Fprint(f, "\n\n") 132 } 133 } 134 135 func generateNativeSelectTests(f io.Writer, ak Kinds) { 136 fmt.Fprintf(f, importUnqualifiedTensor) 137 ks := filter(ak.Kinds, isSpecialized) 138 for _, k := range ks { 139 NativeSelectTest.Execute(f, k) 140 fmt.Fprint(f, "\n\n") 141 } 142 }