github.com/wzzhu/tensor@v0.9.24/genlib2/native_iterator.go (about)

     1  package main
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"text/template"
     7  )
     8  
     9  const checkNativeiterable = `func checkNativeIterable(t *Dense, dims int, dt Dtype) error {
    10  	// checks:
    11  	if !t.IsNativelyAccessible() {
    12  		return errors.Errorf("Cannot convert *Dense to *mat.Dense. Data is inaccessible")
    13  	}
    14  
    15  	if t.Shape().Dims() != dims {
    16  		return errors.Errorf("Cannot convert *Dense to native iterator. Expected number of dimension: %d, T has got %d dimensions (Shape: %v)", dims, t.Dims(), t.Shape())
    17  	}
    18  
    19  	if t.F() || t.RequiresIterator() {
    20  		return errors.Errorf("Not yet implemented: native matrix for colmajor or unpacked matrices")
    21  	}
    22  
    23  	if t.Dtype() != dt {
    24  		return errors.Errorf("Conversion to native iterable only works on %v. Got %v", dt, t.Dtype())
    25  	}
    26  
    27  	return nil
    28  }
    29  `
    30  
    31  const nativeIterRaw = `// Vector{{short .}} converts a *Dense into a []{{asType .}}
    32  // If the *Dense does not represent a vector of the wanted type, it will return
    33  // an error.
    34  func Vector{{short .}}(t *Dense) (retVal []{{asType .}}, err error) {
    35  	if err = checkNativeIterable(t, 1, {{reflectKind .}}); err != nil {
    36  		return nil, err
    37  	}
    38  	return t.{{sliceOf .}}, nil
    39  }
    40  
    41  // Matrix{{short .}} converts a  *Dense into a [][]{{asType .}}
    42  // If the *Dense does not represent a matrix of the wanted type, it
    43  // will return an error.
    44  func Matrix{{short .}}(t *Dense) (retVal [][]{{asType .}}, err error) {
    45  	if err = checkNativeIterable(t, 2, {{reflectKind .}}); err != nil {
    46  		return nil, err
    47  	}
    48  
    49  	data := t.{{sliceOf .}}
    50  	shape := t.Shape()
    51  	strides := t.Strides()
    52  
    53  	rows := shape[0]
    54  	cols := shape[1]
    55  	rowStride := strides[0]
    56  	retVal = make([][]{{asType .}}, rows)
    57  	for i := range retVal {
    58  		start := i * rowStride
    59  		retVal[i] = make([]{{asType .}}, 0)
    60  		hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i]))
    61  		hdr.Data = uintptr(unsafe.Pointer(&data[start]))
    62  		hdr.Cap = cols
    63  		hdr.Len = cols
    64  	}
    65  	return
    66  }
    67  
    68  // Tensor3{{short .}} converts a *Dense into a  [][][]{{asType .}}.
    69  // If the *Dense does not represent a 3-tensor of the wanted type, it will return an error.
    70  func Tensor3{{short .}}(t *Dense) (retVal [][][]{{asType .}}, err error) {
    71  	if err = checkNativeIterable(t, 3, {{reflectKind .}}); err != nil {
    72  		return nil, err
    73  	}
    74  
    75  	data := t.{{sliceOf .}}
    76  	shape := t.Shape()
    77  	strides := t.Strides()
    78  
    79  	layers := shape[0]
    80  	rows := shape[1]
    81  	cols := shape[2]
    82  	layerStride := strides[0]
    83  	rowStride := strides[1]
    84  	retVal = make([][][]{{asType .}}, layers)
    85  	for i := range retVal {
    86  		retVal[i] = make([][]{{asType .}}, rows)
    87  		for j := range retVal[i] {
    88  			retVal[i][j] = make([]{{asType .}}, 0)
    89  			start := i*layerStride + j*rowStride
    90  			hdr := (*reflect.SliceHeader)(unsafe.Pointer(&retVal[i][j]))
    91  			hdr.Data = uintptr(unsafe.Pointer(&data[start]))
    92  			hdr.Cap = cols
    93  			hdr.Len = cols
    94  		}
    95  	}
    96  	return
    97  }
    98  `
    99  
   100  const nativeIterTestRaw = `func Test_Vector{{short .}}(t *testing.T) {
   101  	assert := assert.New(t)
   102  	var T *Dense
   103  	{{if isRangeable . -}}
   104  	T = New(WithBacking(Range({{reflectKind .}}, 0, 6)), WithShape(6))
   105  	{{else -}}
   106  	T = New(Of({{reflectKind .}}), WithShape(6))
   107  	{{end -}}
   108  	it, err := Vector{{short .}}(T)
   109  	if err != nil {
   110  		t.Fatal(err)
   111  	}
   112  
   113  	assert.Equal(6, len(it))
   114  }
   115  
   116  func Test_Matrix{{short .}}(t *testing.T) {
   117  	assert := assert.New(t)
   118  	var T *Dense
   119  	{{if isRangeable . -}}
   120  	T = New(WithBacking(Range({{reflectKind .}}, 0, 6)), WithShape(2, 3))
   121  	{{else -}}
   122  	T = New(Of({{reflectKind .}}), WithShape(2, 3))
   123  	{{end -}}
   124  	it, err := Matrix{{short .}}(T)
   125  	if err != nil {
   126  		t.Fatal(err)
   127  	}
   128  
   129  	assert.Equal(2, len(it))
   130  	assert.Equal(3, len(it[0]))
   131  }
   132  
   133  func Test_Tensor3{{short .}}(t *testing.T) {
   134  	assert := assert.New(t)
   135  	var T *Dense
   136  	{{if isRangeable . -}}
   137  	T = New(WithBacking(Range({{reflectKind .}}, 0, 24)), WithShape(2, 3, 4))
   138  	{{else -}}
   139  	T = New(Of({{reflectKind .}}), WithShape(2, 3, 4))
   140  	{{end -}}
   141  	it, err := Tensor3{{short .}}(T)
   142  	if err != nil {
   143  		t.Fatal(err)
   144  	}
   145  
   146  	assert.Equal(2, len(it))
   147  	assert.Equal(3, len(it[0]))
   148  	assert.Equal(4, len(it[0][0]))
   149  }
   150  `
   151  
   152  var (
   153  	NativeIter     *template.Template
   154  	NativeIterTest *template.Template
   155  )
   156  
   157  func init() {
   158  	NativeIter = template.Must(template.New("NativeIter").Funcs(funcs).Parse(nativeIterRaw))
   159  	NativeIterTest = template.Must(template.New("NativeIterTest").Funcs(funcs).Parse(nativeIterTestRaw))
   160  }
   161  
   162  func generateNativeIterators(f io.Writer, ak Kinds) {
   163  	fmt.Fprintf(f, importUnqualifiedTensor)
   164  	fmt.Fprintf(f, "%v\n", checkNativeiterable)
   165  	ks := filter(ak.Kinds, isSpecialized)
   166  	for _, k := range ks {
   167  		fmt.Fprintf(f, "/* Native Iterables for %v */\n\n", k)
   168  		NativeIter.Execute(f, k)
   169  		fmt.Fprint(f, "\n\n")
   170  	}
   171  }
   172  
   173  func generateNativeIteratorTests(f io.Writer, ak Kinds) {
   174  	fmt.Fprintf(f, importUnqualifiedTensor)
   175  	ks := filter(ak.Kinds, isSpecialized)
   176  	for _, k := range ks {
   177  		NativeIterTest.Execute(f, k)
   178  		fmt.Fprint(f, "\n\n")
   179  	}
   180  }