github.com/wzzhu/tensor@v0.9.24/genlib2/native_select.go (about)

     1  package main
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"text/template"
     7  )
     8  
     9  const checkNativeSelectable = `func checkNativeSelectable(t *Dense, axis int, dt Dtype) error {
    10  	if !t.IsNativelyAccessible() {
    11  		return errors.New("Cannot select on non-natively accessible data")
    12  	}
    13  	if axis >= t.Shape().Dims() && !(t.IsScalar() && axis == 0) {
    14  		return errors.Errorf("Cannot select on axis %d. Shape is %v", axis, t.Shape())
    15  	}
    16  	if t.F() || t.RequiresIterator() {
    17  		return errors.Errorf("Not yet implemented: native select for colmajor or unpacked matrices")
    18  	}
    19  	if t.Dtype() != dt {
    20  		return errors.Errorf("Native selection only works on %v. Got %v", dt, t.Dtype())
    21  	}
    22  	return nil
    23  }
    24  `
    25  const nativeSelectRaw = `// Select{{short .}} creates a slice of flat data types. See Example of NativeSelectF64.
    26  func Select{{short .}}(t *Dense, axis int) (retVal [][]{{asType .}}, err error) {
    27  	if err := checkNativeSelectable(t, axis, {{reflectKind .}}); err != nil {
    28  		return nil, err
    29  	}
    30  
    31  	switch t.Shape().Dims() {
    32  	case 0, 1:
    33  		retVal = make([][]{{asType .}}, 1)
    34  		retVal[0] = t.{{sliceOf .}}
    35  	case 2:
    36  		if axis == 0 {
    37  			return Matrix{{short .}}(t)
    38  		}
    39  		fallthrough
    40  	default:
    41  		// size := t.Shape()[axis]
    42  		data := t.{{sliceOf .}}
    43  		stride := t.Strides()[axis]
    44  		upper := ProdInts(t.Shape()[:axis+1])
    45  		retVal = make([][]{{asType .}}, 0, upper)
    46  		for i, r := 0, 0; r < upper; i += stride {
    47  			s := make([]{{asType .}}, 0)
    48  			hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s))
    49  			hdr.Data = uintptr(unsafe.Pointer(&data[i]))
    50  			hdr.Len = stride
    51  			hdr.Cap = stride
    52  			retVal = append(retVal, s)
    53  			r++
    54  		}
    55  		return retVal, nil
    56  
    57  	}
    58  	return
    59  }
    60  `
    61  const nativeSelectTestRaw = `func TestSelect{{short .}}(t *testing.T) {
    62  	assert := assert.New(t)
    63  	var T *Dense
    64  	var err error
    65  	var x [][]{{asType .}}
    66  	T = New(Of({{reflectKind .}}), WithShape(2, 3, 4, 5), )
    67  	if x, err = Select{{short .}}(T, 1); err != nil {
    68  		t.Fatal(err)
    69  	}
    70  	assert.Equal(6, len(x))
    71  	assert.Equal(20, len(x[0]))
    72  
    73  	T = New(Of({{reflectKind .}}), WithShape(2, 3, 4, 5), )
    74  	if x, err = Select{{short .}}(T, 0); err != nil {
    75  		t.Fatal(err)
    76  	}
    77  	assert.Equal(2, len(x))
    78  	assert.Equal(60, len(x[0]))
    79  
    80  	T = New(Of({{reflectKind .}}), WithShape(2, 3, 4, 5), )
    81  	if x, err = Select{{short .}}(T, 3); err != nil {
    82  		t.Fatal(err)
    83  	}
    84  	assert.Equal(120, len(x))
    85  	assert.Equal(1, len(x[0]))
    86  
    87  	T = New(Of({{reflectKind .}}), WithShape(2, 3), )
    88  	if x, err = Select{{short .}}(T, 0); err != nil {
    89  		t.Fatal(err)
    90  	}
    91  	assert.Equal(2, len(x))
    92  	assert.Equal(3, len(x[0]))
    93  
    94  	T = New(Of({{reflectKind .}}), WithShape(2, 3), )
    95  	if x, err = Select{{short .}}(T, 1); err != nil {
    96  		t.Fatal(err)
    97  	}
    98  	assert.Equal(6, len(x))
    99  	assert.Equal(1, len(x[0]))
   100  
   101  	T = New(FromScalar({{if eq .String "bool" -}}false{{else if eq .String "string" -}}""{{else -}}{{asType .}}(0) {{end -}} ))
   102  	if x, err = Select{{short .}}(T, 0); err != nil {
   103  		t.Fatal(err)
   104  	}
   105  	assert.Equal(1, len(x))
   106  	assert.Equal(1, len(x[0]))
   107  
   108  	if _, err = Select{{short .}}(T, 10); err == nil{
   109  		t.Fatal("Expected errors")
   110  	}
   111  }
   112  `
   113  
   114  var (
   115  	NativeSelect     *template.Template
   116  	NativeSelectTest *template.Template
   117  )
   118  
   119  func init() {
   120  	NativeSelect = template.Must(template.New("NativeSelect").Funcs(funcs).Parse(nativeSelectRaw))
   121  	NativeSelectTest = template.Must(template.New("NativeSelectTest").Funcs(funcs).Parse(nativeSelectTestRaw))
   122  }
   123  
   124  func generateNativeSelect(f io.Writer, ak Kinds) {
   125  	fmt.Fprintf(f, importUnqualifiedTensor)
   126  	fmt.Fprintf(f, "%v\n", checkNativeSelectable)
   127  	ks := filter(ak.Kinds, isSpecialized)
   128  	for _, k := range ks {
   129  		fmt.Fprintf(f, "/* Native Select for %v */\n\n", k)
   130  		NativeSelect.Execute(f, k)
   131  		fmt.Fprint(f, "\n\n")
   132  	}
   133  }
   134  
   135  func generateNativeSelectTests(f io.Writer, ak Kinds) {
   136  	fmt.Fprintf(f, importUnqualifiedTensor)
   137  	ks := filter(ak.Kinds, isSpecialized)
   138  	for _, k := range ks {
   139  		NativeSelectTest.Execute(f, k)
   140  		fmt.Fprint(f, "\n\n")
   141  	}
   142  }