github.com/wzzhu/tensor@v0.9.24/native/generic.go (about)

     1  package native
     2  
     3  import (
     4  	"reflect"
     5  	"unsafe"
     6  
     7  	. "github.com/wzzhu/tensor"
     8  )
     9  
    10  func Vector(t *Dense) (interface{}, error) {
    11  	if err := checkNativeIterable(t, 1, t.Dtype()); err != nil {
    12  		return nil, err
    13  	}
    14  	return t.Data(), nil
    15  }
    16  
    17  func Matrix(t *Dense) (interface{}, error) {
    18  	if err := checkNativeIterable(t, 2, t.Dtype()); err != nil {
    19  		return nil, err
    20  	}
    21  
    22  	shape := t.Shape()
    23  	strides := t.Strides()
    24  	typ := t.Dtype().Type
    25  	rows := shape[0]
    26  	cols := shape[1]
    27  	rowStride := strides[0]
    28  
    29  	retVal := reflect.MakeSlice(reflect.SliceOf(reflect.SliceOf(typ)), rows, rows)
    30  	ptr := t.Uintptr()
    31  	for i := 0; i < rows; i++ {
    32  		e := retVal.Index(i)
    33  		sh := (*reflect.SliceHeader)(unsafe.Pointer(e.Addr().Pointer()))
    34  		sh.Data = uintptr(i*rowStride)*typ.Size() + ptr
    35  		sh.Len = cols
    36  		sh.Cap = cols
    37  	}
    38  	return retVal.Interface(), nil
    39  }
    40  
    41  func Tensor3(t *Dense) (interface{}, error) {
    42  	if err := checkNativeIterable(t, 3, t.Dtype()); err != nil {
    43  		return nil, err
    44  	}
    45  	shape := t.Shape()
    46  	strides := t.Strides()
    47  	typ := t.Dtype().Type
    48  
    49  	layers := shape[0]
    50  	rows := shape[1]
    51  	cols := shape[2]
    52  	layerStride := strides[0]
    53  	rowStride := strides[1]
    54  	retVal := reflect.MakeSlice(reflect.SliceOf(reflect.SliceOf(reflect.SliceOf(typ))), layers, layers)
    55  	ptr := t.Uintptr()
    56  	for i := 0; i < layers; i++ {
    57  		el := retVal.Index(i)
    58  		inner := reflect.MakeSlice(reflect.SliceOf(reflect.SliceOf(typ)), rows, rows)
    59  		for j := 0; j < rows; j++ {
    60  			e := inner.Index(j)
    61  			sh := (*reflect.SliceHeader)(unsafe.Pointer(e.Addr().Pointer()))
    62  			sh.Data = uintptr(i*layerStride+j*rowStride)*typ.Size() + ptr
    63  			sh.Len = cols
    64  			sh.Cap = cols
    65  		}
    66  		sh := (*reflect.SliceHeader)(unsafe.Pointer(el.Addr().Pointer()))
    67  		sh.Data = inner.Index(0).Addr().Pointer()
    68  		sh.Len = rows
    69  		sh.Cap = rows
    70  	}
    71  	return retVal.Interface(), nil
    72  }