github.com/wzzhu/tensor@v0.9.24/genlib2/native_iterator.go (about) 1 package main 2 3 import ( 4 "fmt" 5 "io" 6 "text/template" 7 ) 8 9 const checkNativeiterable = `func checkNativeIterable(t *Dense, dims int, dt Dtype) error { 10 // checks: 11 if !t.IsNativelyAccessible() { 12 return errors.Errorf("Cannot convert *Dense to *mat.Dense. Data is inaccessible") 13 } 14 15 if t.Shape().Dims() != dims { 16 return errors.Errorf("Cannot convert *Dense to native iterator. Expected number of dimension: %d, T has got %d dimensions (Shape: %v)", dims, t.Dims(), t.Shape()) 17 } 18 19 if t.F() || t.RequiresIterator() { 20 return errors.Errorf("Not yet implemented: native matrix for colmajor or unpacked matrices") 21 } 22 23 if t.Dtype() != dt { 24 return errors.Errorf("Conversion to native iterable only works on %v. Got %v", dt, t.Dtype()) 25 } 26 27 return nil 28 } 29 ` 30 31 const nativeIterRaw = `// Vector{{short .}} converts a *Dense into a []{{asType .}} 32 // If the *Dense does not represent a vector of the wanted type, it will return 33 // an error. 34 func Vector{{short .}}(t *Dense) (retVal []{{asType .}}, err error) { 35 if err = checkNativeIterable(t, 1, {{reflectKind .}}); err != nil { 36 return nil, err 37 } 38 return t.{{sliceOf .}}, nil 39 } 40 41 // Matrix{{short .}} converts a *Dense into a [][]{{asType .}} 42 // If the *Dense does not represent a matrix of the wanted type, it 43 // will return an error. 44 func Matrix{{short .}}(t *Dense) (retVal [][]{{asType .}}, err error) { 45 if err = checkNativeIterable(t, 2, {{reflectKind .}}); err != nil { 46 return nil, err 47 } 48 49 data := t.{{sliceOf .}} 50 shape := t.Shape() 51 strides := t.Strides() 52 53 rows := shape[0] 54 cols := shape[1] 55 rowStride := strides[0] 56 retVal = make([][]{{asType .}}, rows) 57 for i := range retVal { 58 start := i * rowStride 59 retVal[i] = make([]{{asType .}}, 0) 60 hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i])) 61 hdr.Data = uintptr(unsafe.Pointer(&data[start])) 62 hdr.Cap = cols 63 hdr.Len = cols 64 } 65 return 66 } 67 68 // Tensor3{{short .}} converts a *Dense into a [][][]{{asType .}}. 69 // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error. 70 func Tensor3{{short .}}(t *Dense) (retVal [][][]{{asType .}}, err error) { 71 if err = checkNativeIterable(t, 3, {{reflectKind .}}); err != nil { 72 return nil, err 73 } 74 75 data := t.{{sliceOf .}} 76 shape := t.Shape() 77 strides := t.Strides() 78 79 layers := shape[0] 80 rows := shape[1] 81 cols := shape[2] 82 layerStride := strides[0] 83 rowStride := strides[1] 84 retVal = make([][][]{{asType .}}, layers) 85 for i := range retVal { 86 retVal[i] = make([][]{{asType .}}, rows) 87 for j := range retVal[i] { 88 retVal[i][j] = make([]{{asType .}}, 0) 89 start := i*layerStride + j*rowStride 90 hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j])) 91 hdr.Data = uintptr(unsafe.Pointer(&data[start])) 92 hdr.Cap = cols 93 hdr.Len = cols 94 } 95 } 96 return 97 } 98 ` 99 100 const nativeIterTestRaw = `func Test_Vector{{short .}}(t *testing.T) { 101 assert := assert.New(t) 102 var T *Dense 103 {{if isRangeable . -}} 104 T = New(WithBacking(Range({{reflectKind .}}, 0, 6)), WithShape(6)) 105 {{else -}} 106 T = New(Of({{reflectKind .}}), WithShape(6)) 107 {{end -}} 108 it, err := Vector{{short .}}(T) 109 if err != nil { 110 t.Fatal(err) 111 } 112 113 assert.Equal(6, len(it)) 114 } 115 116 func Test_Matrix{{short .}}(t *testing.T) { 117 assert := assert.New(t) 118 var T *Dense 119 {{if isRangeable . -}} 120 T = New(WithBacking(Range({{reflectKind .}}, 0, 6)), WithShape(2, 3)) 121 {{else -}} 122 T = New(Of({{reflectKind .}}), WithShape(2, 3)) 123 {{end -}} 124 it, err := Matrix{{short .}}(T) 125 if err != nil { 126 t.Fatal(err) 127 } 128 129 assert.Equal(2, len(it)) 130 assert.Equal(3, len(it[0])) 131 } 132 133 func Test_Tensor3{{short .}}(t *testing.T) { 134 assert := assert.New(t) 135 var T *Dense 136 {{if isRangeable . -}} 137 T = New(WithBacking(Range({{reflectKind .}}, 0, 24)), WithShape(2, 3, 4)) 138 {{else -}} 139 T = New(Of({{reflectKind .}}), WithShape(2, 3, 4)) 140 {{end -}} 141 it, err := Tensor3{{short .}}(T) 142 if err != nil { 143 t.Fatal(err) 144 } 145 146 assert.Equal(2, len(it)) 147 assert.Equal(3, len(it[0])) 148 assert.Equal(4, len(it[0][0])) 149 } 150 ` 151 152 var ( 153 NativeIter *template.Template 154 NativeIterTest *template.Template 155 ) 156 157 func init() { 158 NativeIter = template.Must(template.New("NativeIter").Funcs(funcs).Parse(nativeIterRaw)) 159 NativeIterTest = template.Must(template.New("NativeIterTest").Funcs(funcs).Parse(nativeIterTestRaw)) 160 } 161 162 func generateNativeIterators(f io.Writer, ak Kinds) { 163 fmt.Fprintf(f, importUnqualifiedTensor) 164 fmt.Fprintf(f, "%v\n", checkNativeiterable) 165 ks := filter(ak.Kinds, isSpecialized) 166 for _, k := range ks { 167 fmt.Fprintf(f, "/* Native Iterables for %v */\n\n", k) 168 NativeIter.Execute(f, k) 169 fmt.Fprint(f, "\n\n") 170 } 171 } 172 173 func generateNativeIteratorTests(f io.Writer, ak Kinds) { 174 fmt.Fprintf(f, importUnqualifiedTensor) 175 ks := filter(ak.Kinds, isSpecialized) 176 for _, k := range ks { 177 NativeIterTest.Execute(f, k) 178 fmt.Fprint(f, "\n\n") 179 } 180 }