github.com/wzzhu/tensor@v0.9.24/native/generic.go (about) 1 package native 2 3 import ( 4 "reflect" 5 "unsafe" 6 7 . "github.com/wzzhu/tensor" 8 ) 9 10 func Vector(t *Dense) (interface{}, error) { 11 if err := checkNativeIterable(t, 1, t.Dtype()); err != nil { 12 return nil, err 13 } 14 return t.Data(), nil 15 } 16 17 func Matrix(t *Dense) (interface{}, error) { 18 if err := checkNativeIterable(t, 2, t.Dtype()); err != nil { 19 return nil, err 20 } 21 22 shape := t.Shape() 23 strides := t.Strides() 24 typ := t.Dtype().Type 25 rows := shape[0] 26 cols := shape[1] 27 rowStride := strides[0] 28 29 retVal := reflect.MakeSlice(reflect.SliceOf(reflect.SliceOf(typ)), rows, rows) 30 ptr := t.Uintptr() 31 for i := 0; i < rows; i++ { 32 e := retVal.Index(i) 33 sh := (*reflect.SliceHeader)(unsafe.Pointer(e.Addr().Pointer())) 34 sh.Data = uintptr(i*rowStride)*typ.Size() + ptr 35 sh.Len = cols 36 sh.Cap = cols 37 } 38 return retVal.Interface(), nil 39 } 40 41 func Tensor3(t *Dense) (interface{}, error) { 42 if err := checkNativeIterable(t, 3, t.Dtype()); err != nil { 43 return nil, err 44 } 45 shape := t.Shape() 46 strides := t.Strides() 47 typ := t.Dtype().Type 48 49 layers := shape[0] 50 rows := shape[1] 51 cols := shape[2] 52 layerStride := strides[0] 53 rowStride := strides[1] 54 retVal := reflect.MakeSlice(reflect.SliceOf(reflect.SliceOf(reflect.SliceOf(typ))), layers, layers) 55 ptr := t.Uintptr() 56 for i := 0; i < layers; i++ { 57 el := retVal.Index(i) 58 inner := reflect.MakeSlice(reflect.SliceOf(reflect.SliceOf(typ)), rows, rows) 59 for j := 0; j < rows; j++ { 60 e := inner.Index(j) 61 sh := (*reflect.SliceHeader)(unsafe.Pointer(e.Addr().Pointer())) 62 sh.Data = uintptr(i*layerStride+j*rowStride)*typ.Size() + ptr 63 sh.Len = cols 64 sh.Cap = cols 65 } 66 sh := (*reflect.SliceHeader)(unsafe.Pointer(el.Addr().Pointer())) 67 sh.Data = inner.Index(0).Addr().Pointer() 68 sh.Len = rows 69 sh.Cap = rows 70 } 71 return retVal.Interface(), nil 72 }