gorgonia.org/tensor@v0.9.24/dense_argmethods_test.go (about)

     1  // Code generated by genlib2. DO NOT EDIT.
     2  
     3  package tensor
     4  
     5  import (
     6  	"math"
     7  	"testing"
     8  
     9  	"github.com/chewxy/math32"
    10  	"github.com/stretchr/testify/assert"
    11  )
    12  
    13  /* Test data */
    14  
    15  var basicDenseI = New(WithShape(2, 3, 4, 5, 2), WithBacking([]int{3, 4, 2, 4, 3, 8, 3, 9, 7, 4, 3, 0, 3, 9, 9, 0, 6, 7, 3, 9, 4, 8, 5, 1, 1, 9, 4, 0, 4, 1, 6, 6, 4, 9, 3, 8, 1, 7, 0, 7, 4, 0, 6, 8, 2, 8, 0, 6, 1, 6, 2, 3, 7, 5, 7, 3, 0, 8, 6, 5, 6, 9, 7, 5, 6, 8, 7, 9, 5, 0, 8, 1, 4, 0, 6, 6, 3, 3, 8, 1, 1, 3, 2, 5, 9, 0, 4, 5, 3, 1, 9, 1, 9, 3, 9, 3, 3, 4, 5, 9, 4, 2, 2, 7, 9, 8, 1, 6, 9, 4, 4, 1, 8, 9, 8, 0, 9, 9, 4, 6, 7, 5, 9, 9, 4, 8, 5, 8, 2, 4, 8, 2, 7, 2, 8, 7, 2, 3, 7, 0, 9, 9, 8, 9, 2, 1, 7, 0, 7, 9, 0, 2, 4, 8, 7, 9, 6, 8, 3, 3, 7, 2, 9, 2, 8, 2, 3, 6, 0, 8, 7, 7, 0, 9, 0, 9, 3, 2, 6, 9, 5, 8, 6, 9, 5, 6, 1, 8, 7, 8, 1, 9, 9, 3, 7, 7, 6, 8, 2, 1, 1, 5, 1, 4, 0, 5, 1, 7, 9, 5, 6, 6, 8, 7, 5, 1, 3, 4, 0, 1, 8, 0, 2, 6, 9, 1, 4, 8, 0, 5, 6, 2, 9, 4, 4, 2, 4, 4, 4, 3}))
    16  var basicDenseI8 = New(WithShape(2, 3, 4, 5, 2), WithBacking([]int8{3, 4, 2, 4, 3, 8, 3, 9, 7, 4, 3, 0, 3, 9, 9, 0, 6, 7, 3, 9, 4, 8, 5, 1, 1, 9, 4, 0, 4, 1, 6, 6, 4, 9, 3, 8, 1, 7, 0, 7, 4, 0, 6, 8, 2, 8, 0, 6, 1, 6, 2, 3, 7, 5, 7, 3, 0, 8, 6, 5, 6, 9, 7, 5, 6, 8, 7, 9, 5, 0, 8, 1, 4, 0, 6, 6, 3, 3, 8, 1, 1, 3, 2, 5, 9, 0, 4, 5, 3, 1, 9, 1, 9, 3, 9, 3, 3, 4, 5, 9, 4, 2, 2, 7, 9, 8, 1, 6, 9, 4, 4, 1, 8, 9, 8, 0, 9, 9, 4, 6, 7, 5, 9, 9, 4, 8, 5, 8, 2, 4, 8, 2, 7, 2, 8, 7, 2, 3, 7, 0, 9, 9, 8, 9, 2, 1, 7, 0, 7, 9, 0, 2, 4, 8, 7, 9, 6, 8, 3, 3, 7, 2, 9, 2, 8, 2, 3, 6, 0, 8, 7, 7, 0, 9, 0, 9, 3, 2, 6, 9, 5, 8, 6, 9, 5, 6, 1, 8, 7, 8, 1, 9, 9, 3, 7, 7, 6, 8, 2, 1, 1, 5, 1, 4, 0, 5, 1, 7, 9, 5, 6, 6, 8, 7, 5, 1, 3, 4, 0, 1, 8, 0, 2, 6, 9, 1, 4, 8, 0, 5, 6, 2, 9, 4, 4, 2, 4, 4, 4, 3}))
    17  var basicDenseI16 = New(WithShape(2, 3, 4, 5, 2), WithBacking([]int16{3, 4, 2, 4, 3, 8, 3, 9, 7, 4, 3, 0, 3, 9, 9, 0, 6, 7, 3, 9, 4, 8, 5, 1, 1, 9, 4, 0, 4, 1, 6, 6, 4, 9, 3, 8, 1, 7, 0, 7, 4, 0, 6, 8, 2, 8, 0, 6, 1, 6, 2, 3, 7, 5, 7, 3, 0, 8, 6, 5, 6, 9, 7, 5, 6, 8, 7, 9, 5, 0, 8, 1, 4, 0, 6, 6, 3, 3, 8, 1, 1, 3, 2, 5, 9, 0, 4, 5, 3, 1, 9, 1, 9, 3, 9, 3, 3, 4, 5, 9, 4, 2, 2, 7, 9, 8, 1, 6, 9, 4, 4, 1, 8, 9, 8, 0, 9, 9, 4, 6, 7, 5, 9, 9, 4, 8, 5, 8, 2, 4, 8, 2, 7, 2, 8, 7, 2, 3, 7, 0, 9, 9, 8, 9, 2, 1, 7, 0, 7, 9, 0, 2, 4, 8, 7, 9, 6, 8, 3, 3, 7, 2, 9, 2, 8, 2, 3, 6, 0, 8, 7, 7, 0, 9, 0, 9, 3, 2, 6, 9, 5, 8, 6, 9, 5, 6, 1, 8, 7, 8, 1, 9, 9, 3, 7, 7, 6, 8, 2, 1, 1, 5, 1, 4, 0, 5, 1, 7, 9, 5, 6, 6, 8, 7, 5, 1, 3, 4, 0, 1, 8, 0, 2, 6, 9, 1, 4, 8, 0, 5, 6, 2, 9, 4, 4, 2, 4, 4, 4, 3}))
    18  var basicDenseI32 = New(WithShape(2, 3, 4, 5, 2), WithBacking([]int32{3, 4, 2, 4, 3, 8, 3, 9, 7, 4, 3, 0, 3, 9, 9, 0, 6, 7, 3, 9, 4, 8, 5, 1, 1, 9, 4, 0, 4, 1, 6, 6, 4, 9, 3, 8, 1, 7, 0, 7, 4, 0, 6, 8, 2, 8, 0, 6, 1, 6, 2, 3, 7, 5, 7, 3, 0, 8, 6, 5, 6, 9, 7, 5, 6, 8, 7, 9, 5, 0, 8, 1, 4, 0, 6, 6, 3, 3, 8, 1, 1, 3, 2, 5, 9, 0, 4, 5, 3, 1, 9, 1, 9, 3, 9, 3, 3, 4, 5, 9, 4, 2, 2, 7, 9, 8, 1, 6, 9, 4, 4, 1, 8, 9, 8, 0, 9, 9, 4, 6, 7, 5, 9, 9, 4, 8, 5, 8, 2, 4, 8, 2, 7, 2, 8, 7, 2, 3, 7, 0, 9, 9, 8, 9, 2, 1, 7, 0, 7, 9, 0, 2, 4, 8, 7, 9, 6, 8, 3, 3, 7, 2, 9, 2, 8, 2, 3, 6, 0, 8, 7, 7, 0, 9, 0, 9, 3, 2, 6, 9, 5, 8, 6, 9, 5, 6, 1, 8, 7, 8, 1, 9, 9, 3, 7, 7, 6, 8, 2, 1, 1, 5, 1, 4, 0, 5, 1, 7, 9, 5, 6, 6, 8, 7, 5, 1, 3, 4, 0, 1, 8, 0, 2, 6, 9, 1, 4, 8, 0, 5, 6, 2, 9, 4, 4, 2, 4, 4, 4, 3}))
    19  var basicDenseI64 = New(WithShape(2, 3, 4, 5, 2), WithBacking([]int64{3, 4, 2, 4, 3, 8, 3, 9, 7, 4, 3, 0, 3, 9, 9, 0, 6, 7, 3, 9, 4, 8, 5, 1, 1, 9, 4, 0, 4, 1, 6, 6, 4, 9, 3, 8, 1, 7, 0, 7, 4, 0, 6, 8, 2, 8, 0, 6, 1, 6, 2, 3, 7, 5, 7, 3, 0, 8, 6, 5, 6, 9, 7, 5, 6, 8, 7, 9, 5, 0, 8, 1, 4, 0, 6, 6, 3, 3, 8, 1, 1, 3, 2, 5, 9, 0, 4, 5, 3, 1, 9, 1, 9, 3, 9, 3, 3, 4, 5, 9, 4, 2, 2, 7, 9, 8, 1, 6, 9, 4, 4, 1, 8, 9, 8, 0, 9, 9, 4, 6, 7, 5, 9, 9, 4, 8, 5, 8, 2, 4, 8, 2, 7, 2, 8, 7, 2, 3, 7, 0, 9, 9, 8, 9, 2, 1, 7, 0, 7, 9, 0, 2, 4, 8, 7, 9, 6, 8, 3, 3, 7, 2, 9, 2, 8, 2, 3, 6, 0, 8, 7, 7, 0, 9, 0, 9, 3, 2, 6, 9, 5, 8, 6, 9, 5, 6, 1, 8, 7, 8, 1, 9, 9, 3, 7, 7, 6, 8, 2, 1, 1, 5, 1, 4, 0, 5, 1, 7, 9, 5, 6, 6, 8, 7, 5, 1, 3, 4, 0, 1, 8, 0, 2, 6, 9, 1, 4, 8, 0, 5, 6, 2, 9, 4, 4, 2, 4, 4, 4, 3}))
    20  var basicDenseU = New(WithShape(2, 3, 4, 5, 2), WithBacking([]uint{3, 4, 2, 4, 3, 8, 3, 9, 7, 4, 3, 0, 3, 9, 9, 0, 6, 7, 3, 9, 4, 8, 5, 1, 1, 9, 4, 0, 4, 1, 6, 6, 4, 9, 3, 8, 1, 7, 0, 7, 4, 0, 6, 8, 2, 8, 0, 6, 1, 6, 2, 3, 7, 5, 7, 3, 0, 8, 6, 5, 6, 9, 7, 5, 6, 8, 7, 9, 5, 0, 8, 1, 4, 0, 6, 6, 3, 3, 8, 1, 1, 3, 2, 5, 9, 0, 4, 5, 3, 1, 9, 1, 9, 3, 9, 3, 3, 4, 5, 9, 4, 2, 2, 7, 9, 8, 1, 6, 9, 4, 4, 1, 8, 9, 8, 0, 9, 9, 4, 6, 7, 5, 9, 9, 4, 8, 5, 8, 2, 4, 8, 2, 7, 2, 8, 7, 2, 3, 7, 0, 9, 9, 8, 9, 2, 1, 7, 0, 7, 9, 0, 2, 4, 8, 7, 9, 6, 8, 3, 3, 7, 2, 9, 2, 8, 2, 3, 6, 0, 8, 7, 7, 0, 9, 0, 9, 3, 2, 6, 9, 5, 8, 6, 9, 5, 6, 1, 8, 7, 8, 1, 9, 9, 3, 7, 7, 6, 8, 2, 1, 1, 5, 1, 4, 0, 5, 1, 7, 9, 5, 6, 6, 8, 7, 5, 1, 3, 4, 0, 1, 8, 0, 2, 6, 9, 1, 4, 8, 0, 5, 6, 2, 9, 4, 4, 2, 4, 4, 4, 3}))
    21  var basicDenseU8 = New(WithShape(2, 3, 4, 5, 2), WithBacking([]uint8{3, 4, 2, 4, 3, 8, 3, 9, 7, 4, 3, 0, 3, 9, 9, 0, 6, 7, 3, 9, 4, 8, 5, 1, 1, 9, 4, 0, 4, 1, 6, 6, 4, 9, 3, 8, 1, 7, 0, 7, 4, 0, 6, 8, 2, 8, 0, 6, 1, 6, 2, 3, 7, 5, 7, 3, 0, 8, 6, 5, 6, 9, 7, 5, 6, 8, 7, 9, 5, 0, 8, 1, 4, 0, 6, 6, 3, 3, 8, 1, 1, 3, 2, 5, 9, 0, 4, 5, 3, 1, 9, 1, 9, 3, 9, 3, 3, 4, 5, 9, 4, 2, 2, 7, 9, 8, 1, 6, 9, 4, 4, 1, 8, 9, 8, 0, 9, 9, 4, 6, 7, 5, 9, 9, 4, 8, 5, 8, 2, 4, 8, 2, 7, 2, 8, 7, 2, 3, 7, 0, 9, 9, 8, 9, 2, 1, 7, 0, 7, 9, 0, 2, 4, 8, 7, 9, 6, 8, 3, 3, 7, 2, 9, 2, 8, 2, 3, 6, 0, 8, 7, 7, 0, 9, 0, 9, 3, 2, 6, 9, 5, 8, 6, 9, 5, 6, 1, 8, 7, 8, 1, 9, 9, 3, 7, 7, 6, 8, 2, 1, 1, 5, 1, 4, 0, 5, 1, 7, 9, 5, 6, 6, 8, 7, 5, 1, 3, 4, 0, 1, 8, 0, 2, 6, 9, 1, 4, 8, 0, 5, 6, 2, 9, 4, 4, 2, 4, 4, 4, 3}))
    22  var basicDenseU16 = New(WithShape(2, 3, 4, 5, 2), WithBacking([]uint16{3, 4, 2, 4, 3, 8, 3, 9, 7, 4, 3, 0, 3, 9, 9, 0, 6, 7, 3, 9, 4, 8, 5, 1, 1, 9, 4, 0, 4, 1, 6, 6, 4, 9, 3, 8, 1, 7, 0, 7, 4, 0, 6, 8, 2, 8, 0, 6, 1, 6, 2, 3, 7, 5, 7, 3, 0, 8, 6, 5, 6, 9, 7, 5, 6, 8, 7, 9, 5, 0, 8, 1, 4, 0, 6, 6, 3, 3, 8, 1, 1, 3, 2, 5, 9, 0, 4, 5, 3, 1, 9, 1, 9, 3, 9, 3, 3, 4, 5, 9, 4, 2, 2, 7, 9, 8, 1, 6, 9, 4, 4, 1, 8, 9, 8, 0, 9, 9, 4, 6, 7, 5, 9, 9, 4, 8, 5, 8, 2, 4, 8, 2, 7, 2, 8, 7, 2, 3, 7, 0, 9, 9, 8, 9, 2, 1, 7, 0, 7, 9, 0, 2, 4, 8, 7, 9, 6, 8, 3, 3, 7, 2, 9, 2, 8, 2, 3, 6, 0, 8, 7, 7, 0, 9, 0, 9, 3, 2, 6, 9, 5, 8, 6, 9, 5, 6, 1, 8, 7, 8, 1, 9, 9, 3, 7, 7, 6, 8, 2, 1, 1, 5, 1, 4, 0, 5, 1, 7, 9, 5, 6, 6, 8, 7, 5, 1, 3, 4, 0, 1, 8, 0, 2, 6, 9, 1, 4, 8, 0, 5, 6, 2, 9, 4, 4, 2, 4, 4, 4, 3}))
    23  var basicDenseU32 = New(WithShape(2, 3, 4, 5, 2), WithBacking([]uint32{3, 4, 2, 4, 3, 8, 3, 9, 7, 4, 3, 0, 3, 9, 9, 0, 6, 7, 3, 9, 4, 8, 5, 1, 1, 9, 4, 0, 4, 1, 6, 6, 4, 9, 3, 8, 1, 7, 0, 7, 4, 0, 6, 8, 2, 8, 0, 6, 1, 6, 2, 3, 7, 5, 7, 3, 0, 8, 6, 5, 6, 9, 7, 5, 6, 8, 7, 9, 5, 0, 8, 1, 4, 0, 6, 6, 3, 3, 8, 1, 1, 3, 2, 5, 9, 0, 4, 5, 3, 1, 9, 1, 9, 3, 9, 3, 3, 4, 5, 9, 4, 2, 2, 7, 9, 8, 1, 6, 9, 4, 4, 1, 8, 9, 8, 0, 9, 9, 4, 6, 7, 5, 9, 9, 4, 8, 5, 8, 2, 4, 8, 2, 7, 2, 8, 7, 2, 3, 7, 0, 9, 9, 8, 9, 2, 1, 7, 0, 7, 9, 0, 2, 4, 8, 7, 9, 6, 8, 3, 3, 7, 2, 9, 2, 8, 2, 3, 6, 0, 8, 7, 7, 0, 9, 0, 9, 3, 2, 6, 9, 5, 8, 6, 9, 5, 6, 1, 8, 7, 8, 1, 9, 9, 3, 7, 7, 6, 8, 2, 1, 1, 5, 1, 4, 0, 5, 1, 7, 9, 5, 6, 6, 8, 7, 5, 1, 3, 4, 0, 1, 8, 0, 2, 6, 9, 1, 4, 8, 0, 5, 6, 2, 9, 4, 4, 2, 4, 4, 4, 3}))
    24  var basicDenseU64 = New(WithShape(2, 3, 4, 5, 2), WithBacking([]uint64{3, 4, 2, 4, 3, 8, 3, 9, 7, 4, 3, 0, 3, 9, 9, 0, 6, 7, 3, 9, 4, 8, 5, 1, 1, 9, 4, 0, 4, 1, 6, 6, 4, 9, 3, 8, 1, 7, 0, 7, 4, 0, 6, 8, 2, 8, 0, 6, 1, 6, 2, 3, 7, 5, 7, 3, 0, 8, 6, 5, 6, 9, 7, 5, 6, 8, 7, 9, 5, 0, 8, 1, 4, 0, 6, 6, 3, 3, 8, 1, 1, 3, 2, 5, 9, 0, 4, 5, 3, 1, 9, 1, 9, 3, 9, 3, 3, 4, 5, 9, 4, 2, 2, 7, 9, 8, 1, 6, 9, 4, 4, 1, 8, 9, 8, 0, 9, 9, 4, 6, 7, 5, 9, 9, 4, 8, 5, 8, 2, 4, 8, 2, 7, 2, 8, 7, 2, 3, 7, 0, 9, 9, 8, 9, 2, 1, 7, 0, 7, 9, 0, 2, 4, 8, 7, 9, 6, 8, 3, 3, 7, 2, 9, 2, 8, 2, 3, 6, 0, 8, 7, 7, 0, 9, 0, 9, 3, 2, 6, 9, 5, 8, 6, 9, 5, 6, 1, 8, 7, 8, 1, 9, 9, 3, 7, 7, 6, 8, 2, 1, 1, 5, 1, 4, 0, 5, 1, 7, 9, 5, 6, 6, 8, 7, 5, 1, 3, 4, 0, 1, 8, 0, 2, 6, 9, 1, 4, 8, 0, 5, 6, 2, 9, 4, 4, 2, 4, 4, 4, 3}))
    25  var basicDenseF32 = New(WithShape(2, 3, 4, 5, 2), WithBacking([]float32{3, 4, 2, 4, 3, 8, 3, 9, 7, 4, 3, 0, 3, 9, 9, 0, 6, 7, 3, 9, 4, 8, 5, 1, 1, 9, 4, 0, 4, 1, 6, 6, 4, 9, 3, 8, 1, 7, 0, 7, 4, 0, 6, 8, 2, 8, 0, 6, 1, 6, 2, 3, 7, 5, 7, 3, 0, 8, 6, 5, 6, 9, 7, 5, 6, 8, 7, 9, 5, 0, 8, 1, 4, 0, 6, 6, 3, 3, 8, 1, 1, 3, 2, 5, 9, 0, 4, 5, 3, 1, 9, 1, 9, 3, 9, 3, 3, 4, 5, 9, 4, 2, 2, 7, 9, 8, 1, 6, 9, 4, 4, 1, 8, 9, 8, 0, 9, 9, 4, 6, 7, 5, 9, 9, 4, 8, 5, 8, 2, 4, 8, 2, 7, 2, 8, 7, 2, 3, 7, 0, 9, 9, 8, 9, 2, 1, 7, 0, 7, 9, 0, 2, 4, 8, 7, 9, 6, 8, 3, 3, 7, 2, 9, 2, 8, 2, 3, 6, 0, 8, 7, 7, 0, 9, 0, 9, 3, 2, 6, 9, 5, 8, 6, 9, 5, 6, 1, 8, 7, 8, 1, 9, 9, 3, 7, 7, 6, 8, 2, 1, 1, 5, 1, 4, 0, 5, 1, 7, 9, 5, 6, 6, 8, 7, 5, 1, 3, 4, 0, 1, 8, 0, 2, 6, 9, 1, 4, 8, 0, 5, 6, 2, 9, 4, 4, 2, 4, 4, 4, 3}))
    26  var basicDenseF64 = New(WithShape(2, 3, 4, 5, 2), WithBacking([]float64{3, 4, 2, 4, 3, 8, 3, 9, 7, 4, 3, 0, 3, 9, 9, 0, 6, 7, 3, 9, 4, 8, 5, 1, 1, 9, 4, 0, 4, 1, 6, 6, 4, 9, 3, 8, 1, 7, 0, 7, 4, 0, 6, 8, 2, 8, 0, 6, 1, 6, 2, 3, 7, 5, 7, 3, 0, 8, 6, 5, 6, 9, 7, 5, 6, 8, 7, 9, 5, 0, 8, 1, 4, 0, 6, 6, 3, 3, 8, 1, 1, 3, 2, 5, 9, 0, 4, 5, 3, 1, 9, 1, 9, 3, 9, 3, 3, 4, 5, 9, 4, 2, 2, 7, 9, 8, 1, 6, 9, 4, 4, 1, 8, 9, 8, 0, 9, 9, 4, 6, 7, 5, 9, 9, 4, 8, 5, 8, 2, 4, 8, 2, 7, 2, 8, 7, 2, 3, 7, 0, 9, 9, 8, 9, 2, 1, 7, 0, 7, 9, 0, 2, 4, 8, 7, 9, 6, 8, 3, 3, 7, 2, 9, 2, 8, 2, 3, 6, 0, 8, 7, 7, 0, 9, 0, 9, 3, 2, 6, 9, 5, 8, 6, 9, 5, 6, 1, 8, 7, 8, 1, 9, 9, 3, 7, 7, 6, 8, 2, 1, 1, 5, 1, 4, 0, 5, 1, 7, 9, 5, 6, 6, 8, 7, 5, 1, 3, 4, 0, 1, 8, 0, 2, 6, 9, 1, 4, 8, 0, 5, 6, 2, 9, 4, 4, 2, 4, 4, 4, 3}))
    27  
    28  var argmaxCorrect = []struct {
    29  	shape Shape
    30  	data  []int
    31  }{
    32  	{Shape{3, 4, 5, 2}, []int{
    33  		1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1,
    34  		1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0,
    35  		1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1,
    36  		1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1,
    37  		0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0,
    38  		1, 0, 0, 0, 0,
    39  	}},
    40  	{Shape{2, 4, 5, 2}, []int{
    41  		1, 0, 1, 1, 2, 0, 2, 0, 0, 1, 2, 1, 2, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1,
    42  		2, 2, 0, 1, 1, 2, 2, 1, 0, 2, 0, 2, 0, 2, 2, 1, 0, 0, 0, 0, 0, 1, 0,
    43  		0, 0, 2, 1, 0, 1, 2, 1, 0, 1, 1, 2, 0, 1, 0, 0, 0, 0, 2, 1, 0, 1, 0,
    44  		0, 2, 1, 1, 0, 0, 0, 0, 0, 2, 0,
    45  	}},
    46  	{Shape{2, 3, 5, 2}, []int{
    47  		3, 2, 2, 1, 1, 2, 1, 0, 0, 1, 3, 2, 1, 0, 1, 0, 2, 2, 3, 0, 1, 0, 1,
    48  		3, 0, 2, 3, 3, 2, 1, 2, 2, 0, 0, 1, 3, 2, 0, 1, 2, 0, 3, 0, 1, 0, 1,
    49  		3, 2, 2, 1, 2, 1, 3, 1, 2, 0, 2, 2, 0, 0,
    50  	}},
    51  	{Shape{2, 3, 4, 2}, []int{
    52  		4, 3, 2, 1, 1, 2, 0, 1, 1, 1, 1, 3, 1, 0, 0, 2, 2, 1, 0, 4, 2, 2, 3,
    53  		1, 1, 1, 0, 2, 0, 0, 2, 2, 1, 4, 0, 1, 4, 1, 1, 0, 4, 3, 1, 1, 2, 3,
    54  		1, 1,
    55  	}},
    56  	{Shape{2, 3, 4, 5}, []int{
    57  		1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1,
    58  		1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0,
    59  		0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1,
    60  		0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1,
    61  		1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1,
    62  		0, 0, 0, 0, 0,
    63  	}},
    64  }
    65  
    66  var argminCorrect = []struct {
    67  	shape Shape
    68  	data  []int
    69  }{
    70  	{Shape{3, 4, 5, 2}, []int{
    71  		0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0,
    72  		0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1,
    73  		0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0,
    74  		0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0,
    75  		1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1,
    76  		0, 1, 1, 0, 1,
    77  	}},
    78  	{Shape{2, 4, 5, 2}, []int{
    79  		2, 1, 0, 0, 1, 2, 1, 2, 1, 2, 1, 0, 0, 2, 1, 0, 1, 2, 0, 1, 0, 2, 2,
    80  		0, 0, 1, 2, 0, 0, 1, 2, 1, 0, 1, 0, 2, 0, 1, 0, 1, 2, 1, 2, 1, 2, 1,
    81  		2, 1, 1, 0, 2, 0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 2, 2, 0, 0, 1, 0, 2,
    82  		2, 0, 0, 0, 1, 2, 2, 2, 2, 1, 1,
    83  	}},
    84  	{Shape{2, 3, 5, 2}, []int{
    85  		0, 1, 0, 2, 2, 1, 3, 2, 3, 2, 1, 0, 3, 3, 0, 1, 0, 3, 0, 2, 0, 1, 0,
    86  		1, 3, 0, 2, 1, 0, 0, 3, 1, 3, 1, 2, 2, 1, 2, 0, 1, 3, 0, 1, 0, 1, 0,
    87  		2, 1, 0, 3, 0, 2, 0, 0, 0, 1, 0, 1, 1, 1,
    88  	}},
    89  	{Shape{2, 3, 4, 2}, []int{
    90  		1, 0, 0, 0, 2, 3, 4, 0, 3, 0, 3, 0, 4, 4, 3, 1, 0, 2, 3, 0, 3, 0, 0,
    91  		2, 4, 4, 3, 4, 2, 3, 0, 0, 4, 0, 1, 3, 3, 2, 0, 4, 2, 1, 4, 2, 4, 0,
    92  		2, 0,
    93  	}},
    94  	{Shape{2, 3, 4, 5}, []int{
    95  		0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0,
    96  		0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1,
    97  		1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0,
    98  		1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0,
    99  		0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0,
   100  		1, 1, 1, 0, 1,
   101  	}},
   102  }
   103  
   104  func TestDense_Argmax_I(t *testing.T) {
   105  	assert := assert.New(t)
   106  	var T, argmax *Dense
   107  	var err error
   108  	T = basicDenseI.Clone().(*Dense)
   109  	for i := 0; i < T.Dims(); i++ {
   110  		if argmax, err = T.Argmax(i); err != nil {
   111  			t.Error(err)
   112  			continue
   113  		}
   114  
   115  		assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape)
   116  		assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i)
   117  	}
   118  	// test all axes
   119  	if argmax, err = T.Argmax(AllAxes); err != nil {
   120  		t.Error(err)
   121  		return
   122  	}
   123  	assert.True(argmax.IsScalar())
   124  	assert.Equal(7, argmax.ScalarValue())
   125  
   126  	// with different engine
   127  	T = basicDenseI.Clone().(*Dense)
   128  	WithEngine(dummyEngine2{})(T)
   129  	for i := 0; i < T.Dims(); i++ {
   130  		if argmax, err = T.Argmax(i); err != nil {
   131  			t.Error(err)
   132  			continue
   133  		}
   134  
   135  		assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape)
   136  		assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i)
   137  	}
   138  
   139  	// idiotsville
   140  	_, err = T.Argmax(10000)
   141  	assert.NotNil(err)
   142  
   143  }
   144  func TestDense_Argmin_I(t *testing.T) {
   145  	assert := assert.New(t)
   146  	var T, argmin *Dense
   147  	var err error
   148  	T = basicDenseI.Clone().(*Dense)
   149  	for i := 0; i < T.Dims(); i++ {
   150  		if argmin, err = T.Argmin(i); err != nil {
   151  			t.Error(err)
   152  			continue
   153  		}
   154  
   155  		assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape)
   156  		assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i)
   157  	}
   158  	// test all axes
   159  	if argmin, err = T.Argmin(AllAxes); err != nil {
   160  		t.Error(err)
   161  		return
   162  	}
   163  	assert.True(argmin.IsScalar())
   164  	assert.Equal(11, argmin.ScalarValue())
   165  
   166  	// with different engine
   167  	T = basicDenseI.Clone().(*Dense)
   168  	WithEngine(dummyEngine2{})(T)
   169  	for i := 0; i < T.Dims(); i++ {
   170  		if argmin, err = T.Argmin(i); err != nil {
   171  			t.Error(err)
   172  			continue
   173  		}
   174  
   175  		assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape)
   176  		assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i)
   177  	}
   178  
   179  	// idiotsville
   180  	_, err = T.Argmin(10000)
   181  	assert.NotNil(err)
   182  
   183  }
   184  func TestDense_Argmax_I8(t *testing.T) {
   185  	assert := assert.New(t)
   186  	var T, argmax *Dense
   187  	var err error
   188  	T = basicDenseI8.Clone().(*Dense)
   189  	for i := 0; i < T.Dims(); i++ {
   190  		if argmax, err = T.Argmax(i); err != nil {
   191  			t.Error(err)
   192  			continue
   193  		}
   194  
   195  		assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape)
   196  		assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i)
   197  	}
   198  	// test all axes
   199  	if argmax, err = T.Argmax(AllAxes); err != nil {
   200  		t.Error(err)
   201  		return
   202  	}
   203  	assert.True(argmax.IsScalar())
   204  	assert.Equal(7, argmax.ScalarValue())
   205  
   206  	// with different engine
   207  	T = basicDenseI8.Clone().(*Dense)
   208  	WithEngine(dummyEngine2{})(T)
   209  	for i := 0; i < T.Dims(); i++ {
   210  		if argmax, err = T.Argmax(i); err != nil {
   211  			t.Error(err)
   212  			continue
   213  		}
   214  
   215  		assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape)
   216  		assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i)
   217  	}
   218  
   219  	// idiotsville
   220  	_, err = T.Argmax(10000)
   221  	assert.NotNil(err)
   222  
   223  }
   224  func TestDense_Argmin_I8(t *testing.T) {
   225  	assert := assert.New(t)
   226  	var T, argmin *Dense
   227  	var err error
   228  	T = basicDenseI8.Clone().(*Dense)
   229  	for i := 0; i < T.Dims(); i++ {
   230  		if argmin, err = T.Argmin(i); err != nil {
   231  			t.Error(err)
   232  			continue
   233  		}
   234  
   235  		assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape)
   236  		assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i)
   237  	}
   238  	// test all axes
   239  	if argmin, err = T.Argmin(AllAxes); err != nil {
   240  		t.Error(err)
   241  		return
   242  	}
   243  	assert.True(argmin.IsScalar())
   244  	assert.Equal(11, argmin.ScalarValue())
   245  
   246  	// with different engine
   247  	T = basicDenseI8.Clone().(*Dense)
   248  	WithEngine(dummyEngine2{})(T)
   249  	for i := 0; i < T.Dims(); i++ {
   250  		if argmin, err = T.Argmin(i); err != nil {
   251  			t.Error(err)
   252  			continue
   253  		}
   254  
   255  		assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape)
   256  		assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i)
   257  	}
   258  
   259  	// idiotsville
   260  	_, err = T.Argmin(10000)
   261  	assert.NotNil(err)
   262  
   263  }
   264  func TestDense_Argmax_I16(t *testing.T) {
   265  	assert := assert.New(t)
   266  	var T, argmax *Dense
   267  	var err error
   268  	T = basicDenseI16.Clone().(*Dense)
   269  	for i := 0; i < T.Dims(); i++ {
   270  		if argmax, err = T.Argmax(i); err != nil {
   271  			t.Error(err)
   272  			continue
   273  		}
   274  
   275  		assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape)
   276  		assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i)
   277  	}
   278  	// test all axes
   279  	if argmax, err = T.Argmax(AllAxes); err != nil {
   280  		t.Error(err)
   281  		return
   282  	}
   283  	assert.True(argmax.IsScalar())
   284  	assert.Equal(7, argmax.ScalarValue())
   285  
   286  	// with different engine
   287  	T = basicDenseI16.Clone().(*Dense)
   288  	WithEngine(dummyEngine2{})(T)
   289  	for i := 0; i < T.Dims(); i++ {
   290  		if argmax, err = T.Argmax(i); err != nil {
   291  			t.Error(err)
   292  			continue
   293  		}
   294  
   295  		assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape)
   296  		assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i)
   297  	}
   298  
   299  	// idiotsville
   300  	_, err = T.Argmax(10000)
   301  	assert.NotNil(err)
   302  
   303  }
   304  func TestDense_Argmin_I16(t *testing.T) {
   305  	assert := assert.New(t)
   306  	var T, argmin *Dense
   307  	var err error
   308  	T = basicDenseI16.Clone().(*Dense)
   309  	for i := 0; i < T.Dims(); i++ {
   310  		if argmin, err = T.Argmin(i); err != nil {
   311  			t.Error(err)
   312  			continue
   313  		}
   314  
   315  		assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape)
   316  		assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i)
   317  	}
   318  	// test all axes
   319  	if argmin, err = T.Argmin(AllAxes); err != nil {
   320  		t.Error(err)
   321  		return
   322  	}
   323  	assert.True(argmin.IsScalar())
   324  	assert.Equal(11, argmin.ScalarValue())
   325  
   326  	// with different engine
   327  	T = basicDenseI16.Clone().(*Dense)
   328  	WithEngine(dummyEngine2{})(T)
   329  	for i := 0; i < T.Dims(); i++ {
   330  		if argmin, err = T.Argmin(i); err != nil {
   331  			t.Error(err)
   332  			continue
   333  		}
   334  
   335  		assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape)
   336  		assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i)
   337  	}
   338  
   339  	// idiotsville
   340  	_, err = T.Argmin(10000)
   341  	assert.NotNil(err)
   342  
   343  }
   344  func TestDense_Argmax_I32(t *testing.T) {
   345  	assert := assert.New(t)
   346  	var T, argmax *Dense
   347  	var err error
   348  	T = basicDenseI32.Clone().(*Dense)
   349  	for i := 0; i < T.Dims(); i++ {
   350  		if argmax, err = T.Argmax(i); err != nil {
   351  			t.Error(err)
   352  			continue
   353  		}
   354  
   355  		assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape)
   356  		assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i)
   357  	}
   358  	// test all axes
   359  	if argmax, err = T.Argmax(AllAxes); err != nil {
   360  		t.Error(err)
   361  		return
   362  	}
   363  	assert.True(argmax.IsScalar())
   364  	assert.Equal(7, argmax.ScalarValue())
   365  
   366  	// with different engine
   367  	T = basicDenseI32.Clone().(*Dense)
   368  	WithEngine(dummyEngine2{})(T)
   369  	for i := 0; i < T.Dims(); i++ {
   370  		if argmax, err = T.Argmax(i); err != nil {
   371  			t.Error(err)
   372  			continue
   373  		}
   374  
   375  		assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape)
   376  		assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i)
   377  	}
   378  
   379  	// idiotsville
   380  	_, err = T.Argmax(10000)
   381  	assert.NotNil(err)
   382  
   383  }
   384  func TestDense_Argmin_I32(t *testing.T) {
   385  	assert := assert.New(t)
   386  	var T, argmin *Dense
   387  	var err error
   388  	T = basicDenseI32.Clone().(*Dense)
   389  	for i := 0; i < T.Dims(); i++ {
   390  		if argmin, err = T.Argmin(i); err != nil {
   391  			t.Error(err)
   392  			continue
   393  		}
   394  
   395  		assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape)
   396  		assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i)
   397  	}
   398  	// test all axes
   399  	if argmin, err = T.Argmin(AllAxes); err != nil {
   400  		t.Error(err)
   401  		return
   402  	}
   403  	assert.True(argmin.IsScalar())
   404  	assert.Equal(11, argmin.ScalarValue())
   405  
   406  	// with different engine
   407  	T = basicDenseI32.Clone().(*Dense)
   408  	WithEngine(dummyEngine2{})(T)
   409  	for i := 0; i < T.Dims(); i++ {
   410  		if argmin, err = T.Argmin(i); err != nil {
   411  			t.Error(err)
   412  			continue
   413  		}
   414  
   415  		assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape)
   416  		assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i)
   417  	}
   418  
   419  	// idiotsville
   420  	_, err = T.Argmin(10000)
   421  	assert.NotNil(err)
   422  
   423  }
   424  func TestDense_Argmax_I64(t *testing.T) {
   425  	assert := assert.New(t)
   426  	var T, argmax *Dense
   427  	var err error
   428  	T = basicDenseI64.Clone().(*Dense)
   429  	for i := 0; i < T.Dims(); i++ {
   430  		if argmax, err = T.Argmax(i); err != nil {
   431  			t.Error(err)
   432  			continue
   433  		}
   434  
   435  		assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape)
   436  		assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i)
   437  	}
   438  	// test all axes
   439  	if argmax, err = T.Argmax(AllAxes); err != nil {
   440  		t.Error(err)
   441  		return
   442  	}
   443  	assert.True(argmax.IsScalar())
   444  	assert.Equal(7, argmax.ScalarValue())
   445  
   446  	// with different engine
   447  	T = basicDenseI64.Clone().(*Dense)
   448  	WithEngine(dummyEngine2{})(T)
   449  	for i := 0; i < T.Dims(); i++ {
   450  		if argmax, err = T.Argmax(i); err != nil {
   451  			t.Error(err)
   452  			continue
   453  		}
   454  
   455  		assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape)
   456  		assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i)
   457  	}
   458  
   459  	// idiotsville
   460  	_, err = T.Argmax(10000)
   461  	assert.NotNil(err)
   462  
   463  }
   464  func TestDense_Argmin_I64(t *testing.T) {
   465  	assert := assert.New(t)
   466  	var T, argmin *Dense
   467  	var err error
   468  	T = basicDenseI64.Clone().(*Dense)
   469  	for i := 0; i < T.Dims(); i++ {
   470  		if argmin, err = T.Argmin(i); err != nil {
   471  			t.Error(err)
   472  			continue
   473  		}
   474  
   475  		assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape)
   476  		assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i)
   477  	}
   478  	// test all axes
   479  	if argmin, err = T.Argmin(AllAxes); err != nil {
   480  		t.Error(err)
   481  		return
   482  	}
   483  	assert.True(argmin.IsScalar())
   484  	assert.Equal(11, argmin.ScalarValue())
   485  
   486  	// with different engine
   487  	T = basicDenseI64.Clone().(*Dense)
   488  	WithEngine(dummyEngine2{})(T)
   489  	for i := 0; i < T.Dims(); i++ {
   490  		if argmin, err = T.Argmin(i); err != nil {
   491  			t.Error(err)
   492  			continue
   493  		}
   494  
   495  		assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape)
   496  		assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i)
   497  	}
   498  
   499  	// idiotsville
   500  	_, err = T.Argmin(10000)
   501  	assert.NotNil(err)
   502  
   503  }
   504  func TestDense_Argmax_U(t *testing.T) {
   505  	assert := assert.New(t)
   506  	var T, argmax *Dense
   507  	var err error
   508  	T = basicDenseU.Clone().(*Dense)
   509  	for i := 0; i < T.Dims(); i++ {
   510  		if argmax, err = T.Argmax(i); err != nil {
   511  			t.Error(err)
   512  			continue
   513  		}
   514  
   515  		assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape)
   516  		assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i)
   517  	}
   518  	// test all axes
   519  	if argmax, err = T.Argmax(AllAxes); err != nil {
   520  		t.Error(err)
   521  		return
   522  	}
   523  	assert.True(argmax.IsScalar())
   524  	assert.Equal(7, argmax.ScalarValue())
   525  
   526  	// with different engine
   527  	T = basicDenseU.Clone().(*Dense)
   528  	WithEngine(dummyEngine2{})(T)
   529  	for i := 0; i < T.Dims(); i++ {
   530  		if argmax, err = T.Argmax(i); err != nil {
   531  			t.Error(err)
   532  			continue
   533  		}
   534  
   535  		assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape)
   536  		assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i)
   537  	}
   538  
   539  	// idiotsville
   540  	_, err = T.Argmax(10000)
   541  	assert.NotNil(err)
   542  
   543  }
   544  func TestDense_Argmin_U(t *testing.T) {
   545  	assert := assert.New(t)
   546  	var T, argmin *Dense
   547  	var err error
   548  	T = basicDenseU.Clone().(*Dense)
   549  	for i := 0; i < T.Dims(); i++ {
   550  		if argmin, err = T.Argmin(i); err != nil {
   551  			t.Error(err)
   552  			continue
   553  		}
   554  
   555  		assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape)
   556  		assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i)
   557  	}
   558  	// test all axes
   559  	if argmin, err = T.Argmin(AllAxes); err != nil {
   560  		t.Error(err)
   561  		return
   562  	}
   563  	assert.True(argmin.IsScalar())
   564  	assert.Equal(11, argmin.ScalarValue())
   565  
   566  	// with different engine
   567  	T = basicDenseU.Clone().(*Dense)
   568  	WithEngine(dummyEngine2{})(T)
   569  	for i := 0; i < T.Dims(); i++ {
   570  		if argmin, err = T.Argmin(i); err != nil {
   571  			t.Error(err)
   572  			continue
   573  		}
   574  
   575  		assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape)
   576  		assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i)
   577  	}
   578  
   579  	// idiotsville
   580  	_, err = T.Argmin(10000)
   581  	assert.NotNil(err)
   582  
   583  }
   584  func TestDense_Argmax_U8(t *testing.T) {
   585  	assert := assert.New(t)
   586  	var T, argmax *Dense
   587  	var err error
   588  	T = basicDenseU8.Clone().(*Dense)
   589  	for i := 0; i < T.Dims(); i++ {
   590  		if argmax, err = T.Argmax(i); err != nil {
   591  			t.Error(err)
   592  			continue
   593  		}
   594  
   595  		assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape)
   596  		assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i)
   597  	}
   598  	// test all axes
   599  	if argmax, err = T.Argmax(AllAxes); err != nil {
   600  		t.Error(err)
   601  		return
   602  	}
   603  	assert.True(argmax.IsScalar())
   604  	assert.Equal(7, argmax.ScalarValue())
   605  
   606  	// with different engine
   607  	T = basicDenseU8.Clone().(*Dense)
   608  	WithEngine(dummyEngine2{})(T)
   609  	for i := 0; i < T.Dims(); i++ {
   610  		if argmax, err = T.Argmax(i); err != nil {
   611  			t.Error(err)
   612  			continue
   613  		}
   614  
   615  		assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape)
   616  		assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i)
   617  	}
   618  
   619  	// idiotsville
   620  	_, err = T.Argmax(10000)
   621  	assert.NotNil(err)
   622  
   623  }
   624  func TestDense_Argmin_U8(t *testing.T) {
   625  	assert := assert.New(t)
   626  	var T, argmin *Dense
   627  	var err error
   628  	T = basicDenseU8.Clone().(*Dense)
   629  	for i := 0; i < T.Dims(); i++ {
   630  		if argmin, err = T.Argmin(i); err != nil {
   631  			t.Error(err)
   632  			continue
   633  		}
   634  
   635  		assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape)
   636  		assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i)
   637  	}
   638  	// test all axes
   639  	if argmin, err = T.Argmin(AllAxes); err != nil {
   640  		t.Error(err)
   641  		return
   642  	}
   643  	assert.True(argmin.IsScalar())
   644  	assert.Equal(11, argmin.ScalarValue())
   645  
   646  	// with different engine
   647  	T = basicDenseU8.Clone().(*Dense)
   648  	WithEngine(dummyEngine2{})(T)
   649  	for i := 0; i < T.Dims(); i++ {
   650  		if argmin, err = T.Argmin(i); err != nil {
   651  			t.Error(err)
   652  			continue
   653  		}
   654  
   655  		assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape)
   656  		assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i)
   657  	}
   658  
   659  	// idiotsville
   660  	_, err = T.Argmin(10000)
   661  	assert.NotNil(err)
   662  
   663  }
   664  func TestDense_Argmax_U16(t *testing.T) {
   665  	assert := assert.New(t)
   666  	var T, argmax *Dense
   667  	var err error
   668  	T = basicDenseU16.Clone().(*Dense)
   669  	for i := 0; i < T.Dims(); i++ {
   670  		if argmax, err = T.Argmax(i); err != nil {
   671  			t.Error(err)
   672  			continue
   673  		}
   674  
   675  		assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape)
   676  		assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i)
   677  	}
   678  	// test all axes
   679  	if argmax, err = T.Argmax(AllAxes); err != nil {
   680  		t.Error(err)
   681  		return
   682  	}
   683  	assert.True(argmax.IsScalar())
   684  	assert.Equal(7, argmax.ScalarValue())
   685  
   686  	// with different engine
   687  	T = basicDenseU16.Clone().(*Dense)
   688  	WithEngine(dummyEngine2{})(T)
   689  	for i := 0; i < T.Dims(); i++ {
   690  		if argmax, err = T.Argmax(i); err != nil {
   691  			t.Error(err)
   692  			continue
   693  		}
   694  
   695  		assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape)
   696  		assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i)
   697  	}
   698  
   699  	// idiotsville
   700  	_, err = T.Argmax(10000)
   701  	assert.NotNil(err)
   702  
   703  }
   704  func TestDense_Argmin_U16(t *testing.T) {
   705  	assert := assert.New(t)
   706  	var T, argmin *Dense
   707  	var err error
   708  	T = basicDenseU16.Clone().(*Dense)
   709  	for i := 0; i < T.Dims(); i++ {
   710  		if argmin, err = T.Argmin(i); err != nil {
   711  			t.Error(err)
   712  			continue
   713  		}
   714  
   715  		assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape)
   716  		assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i)
   717  	}
   718  	// test all axes
   719  	if argmin, err = T.Argmin(AllAxes); err != nil {
   720  		t.Error(err)
   721  		return
   722  	}
   723  	assert.True(argmin.IsScalar())
   724  	assert.Equal(11, argmin.ScalarValue())
   725  
   726  	// with different engine
   727  	T = basicDenseU16.Clone().(*Dense)
   728  	WithEngine(dummyEngine2{})(T)
   729  	for i := 0; i < T.Dims(); i++ {
   730  		if argmin, err = T.Argmin(i); err != nil {
   731  			t.Error(err)
   732  			continue
   733  		}
   734  
   735  		assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape)
   736  		assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i)
   737  	}
   738  
   739  	// idiotsville
   740  	_, err = T.Argmin(10000)
   741  	assert.NotNil(err)
   742  
   743  }
   744  func TestDense_Argmax_U32(t *testing.T) {
   745  	assert := assert.New(t)
   746  	var T, argmax *Dense
   747  	var err error
   748  	T = basicDenseU32.Clone().(*Dense)
   749  	for i := 0; i < T.Dims(); i++ {
   750  		if argmax, err = T.Argmax(i); err != nil {
   751  			t.Error(err)
   752  			continue
   753  		}
   754  
   755  		assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape)
   756  		assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i)
   757  	}
   758  	// test all axes
   759  	if argmax, err = T.Argmax(AllAxes); err != nil {
   760  		t.Error(err)
   761  		return
   762  	}
   763  	assert.True(argmax.IsScalar())
   764  	assert.Equal(7, argmax.ScalarValue())
   765  
   766  	// with different engine
   767  	T = basicDenseU32.Clone().(*Dense)
   768  	WithEngine(dummyEngine2{})(T)
   769  	for i := 0; i < T.Dims(); i++ {
   770  		if argmax, err = T.Argmax(i); err != nil {
   771  			t.Error(err)
   772  			continue
   773  		}
   774  
   775  		assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape)
   776  		assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i)
   777  	}
   778  
   779  	// idiotsville
   780  	_, err = T.Argmax(10000)
   781  	assert.NotNil(err)
   782  
   783  }
   784  func TestDense_Argmin_U32(t *testing.T) {
   785  	assert := assert.New(t)
   786  	var T, argmin *Dense
   787  	var err error
   788  	T = basicDenseU32.Clone().(*Dense)
   789  	for i := 0; i < T.Dims(); i++ {
   790  		if argmin, err = T.Argmin(i); err != nil {
   791  			t.Error(err)
   792  			continue
   793  		}
   794  
   795  		assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape)
   796  		assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i)
   797  	}
   798  	// test all axes
   799  	if argmin, err = T.Argmin(AllAxes); err != nil {
   800  		t.Error(err)
   801  		return
   802  	}
   803  	assert.True(argmin.IsScalar())
   804  	assert.Equal(11, argmin.ScalarValue())
   805  
   806  	// with different engine
   807  	T = basicDenseU32.Clone().(*Dense)
   808  	WithEngine(dummyEngine2{})(T)
   809  	for i := 0; i < T.Dims(); i++ {
   810  		if argmin, err = T.Argmin(i); err != nil {
   811  			t.Error(err)
   812  			continue
   813  		}
   814  
   815  		assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape)
   816  		assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i)
   817  	}
   818  
   819  	// idiotsville
   820  	_, err = T.Argmin(10000)
   821  	assert.NotNil(err)
   822  
   823  }
   824  func TestDense_Argmax_U64(t *testing.T) {
   825  	assert := assert.New(t)
   826  	var T, argmax *Dense
   827  	var err error
   828  	T = basicDenseU64.Clone().(*Dense)
   829  	for i := 0; i < T.Dims(); i++ {
   830  		if argmax, err = T.Argmax(i); err != nil {
   831  			t.Error(err)
   832  			continue
   833  		}
   834  
   835  		assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape)
   836  		assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i)
   837  	}
   838  	// test all axes
   839  	if argmax, err = T.Argmax(AllAxes); err != nil {
   840  		t.Error(err)
   841  		return
   842  	}
   843  	assert.True(argmax.IsScalar())
   844  	assert.Equal(7, argmax.ScalarValue())
   845  
   846  	// with different engine
   847  	T = basicDenseU64.Clone().(*Dense)
   848  	WithEngine(dummyEngine2{})(T)
   849  	for i := 0; i < T.Dims(); i++ {
   850  		if argmax, err = T.Argmax(i); err != nil {
   851  			t.Error(err)
   852  			continue
   853  		}
   854  
   855  		assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape)
   856  		assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i)
   857  	}
   858  
   859  	// idiotsville
   860  	_, err = T.Argmax(10000)
   861  	assert.NotNil(err)
   862  
   863  }
   864  func TestDense_Argmin_U64(t *testing.T) {
   865  	assert := assert.New(t)
   866  	var T, argmin *Dense
   867  	var err error
   868  	T = basicDenseU64.Clone().(*Dense)
   869  	for i := 0; i < T.Dims(); i++ {
   870  		if argmin, err = T.Argmin(i); err != nil {
   871  			t.Error(err)
   872  			continue
   873  		}
   874  
   875  		assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape)
   876  		assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i)
   877  	}
   878  	// test all axes
   879  	if argmin, err = T.Argmin(AllAxes); err != nil {
   880  		t.Error(err)
   881  		return
   882  	}
   883  	assert.True(argmin.IsScalar())
   884  	assert.Equal(11, argmin.ScalarValue())
   885  
   886  	// with different engine
   887  	T = basicDenseU64.Clone().(*Dense)
   888  	WithEngine(dummyEngine2{})(T)
   889  	for i := 0; i < T.Dims(); i++ {
   890  		if argmin, err = T.Argmin(i); err != nil {
   891  			t.Error(err)
   892  			continue
   893  		}
   894  
   895  		assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape)
   896  		assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i)
   897  	}
   898  
   899  	// idiotsville
   900  	_, err = T.Argmin(10000)
   901  	assert.NotNil(err)
   902  
   903  }
   904  func TestDense_Argmax_F32(t *testing.T) {
   905  	assert := assert.New(t)
   906  	var T, argmax *Dense
   907  	var err error
   908  	T = basicDenseF32.Clone().(*Dense)
   909  	for i := 0; i < T.Dims(); i++ {
   910  		if argmax, err = T.Argmax(i); err != nil {
   911  			t.Error(err)
   912  			continue
   913  		}
   914  
   915  		assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape)
   916  		assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i)
   917  	}
   918  	// test all axes
   919  	if argmax, err = T.Argmax(AllAxes); err != nil {
   920  		t.Error(err)
   921  		return
   922  	}
   923  	assert.True(argmax.IsScalar())
   924  	assert.Equal(7, argmax.ScalarValue())
   925  
   926  	// test with NaN
   927  	T = New(WithShape(4), WithBacking([]float32{1, 2, math32.NaN(), 4}))
   928  	if argmax, err = T.Argmax(AllAxes); err != nil {
   929  		t.Errorf("Failed test with NaN: %v", err)
   930  	}
   931  	assert.True(argmax.IsScalar())
   932  	assert.Equal(2, argmax.ScalarValue(), "NaN test")
   933  
   934  	// test with Mask and Nan
   935  	T = New(WithShape(4), WithBacking([]float32{1, 9, math32.NaN(), 4}, []bool{false, true, true, false}))
   936  	if argmax, err = T.Argmax(AllAxes); err != nil {
   937  		t.Errorf("Failed test with NaN: %v", err)
   938  	}
   939  	assert.True(argmax.IsScalar())
   940  	assert.Equal(3, argmax.ScalarValue(), "Masked NaN test")
   941  
   942  	// test with +Inf
   943  	T = New(WithShape(4), WithBacking([]float32{1, 2, math32.Inf(1), 4}))
   944  	if argmax, err = T.Argmax(AllAxes); err != nil {
   945  		t.Errorf("Failed test with +Inf: %v", err)
   946  	}
   947  	assert.True(argmax.IsScalar())
   948  	assert.Equal(2, argmax.ScalarValue(), "+Inf test")
   949  
   950  	// test with Mask and +Inf
   951  	T = New(WithShape(4), WithBacking([]float32{1, 9, math32.Inf(1), 4}, []bool{false, true, true, false}))
   952  	if argmax, err = T.Argmax(AllAxes); err != nil {
   953  		t.Errorf("Failed test with NaN: %v", err)
   954  	}
   955  	assert.True(argmax.IsScalar())
   956  	assert.Equal(3, argmax.ScalarValue(), "Masked NaN test")
   957  
   958  	// test with -Inf
   959  	T = New(WithShape(4), WithBacking([]float32{1, 2, math32.Inf(-1), 4}))
   960  	if argmax, err = T.Argmax(AllAxes); err != nil {
   961  		t.Errorf("Failed test with -Inf: %v", err)
   962  	}
   963  	assert.True(argmax.IsScalar())
   964  	assert.Equal(3, argmax.ScalarValue(), "+Inf test")
   965  
   966  	// test with Mask and -Inf
   967  	T = New(WithShape(4), WithBacking([]float32{1, 9, math32.Inf(-1), 4}, []bool{false, true, true, false}))
   968  	if argmax, err = T.Argmax(AllAxes); err != nil {
   969  		t.Errorf("Failed test with NaN: %v", err)
   970  	}
   971  	assert.True(argmax.IsScalar())
   972  	assert.Equal(3, argmax.ScalarValue(), "Masked -Inf test")
   973  
   974  	// with different engine
   975  	T = basicDenseF32.Clone().(*Dense)
   976  	WithEngine(dummyEngine2{})(T)
   977  	for i := 0; i < T.Dims(); i++ {
   978  		if argmax, err = T.Argmax(i); err != nil {
   979  			t.Error(err)
   980  			continue
   981  		}
   982  
   983  		assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape)
   984  		assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i)
   985  	}
   986  
   987  	// idiotsville
   988  	_, err = T.Argmax(10000)
   989  	assert.NotNil(err)
   990  
   991  }
   992  func TestDense_Argmin_F32(t *testing.T) {
   993  	assert := assert.New(t)
   994  	var T, argmin *Dense
   995  	var err error
   996  	T = basicDenseF32.Clone().(*Dense)
   997  	for i := 0; i < T.Dims(); i++ {
   998  		if argmin, err = T.Argmin(i); err != nil {
   999  			t.Error(err)
  1000  			continue
  1001  		}
  1002  
  1003  		assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape)
  1004  		assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i)
  1005  	}
  1006  	// test all axes
  1007  	if argmin, err = T.Argmin(AllAxes); err != nil {
  1008  		t.Error(err)
  1009  		return
  1010  	}
  1011  	assert.True(argmin.IsScalar())
  1012  	assert.Equal(11, argmin.ScalarValue())
  1013  
  1014  	// test with NaN
  1015  	T = New(WithShape(4), WithBacking([]float32{1, 2, math32.NaN(), 4}))
  1016  	if argmin, err = T.Argmin(AllAxes); err != nil {
  1017  		t.Errorf("Failed test with NaN: %v", err)
  1018  	}
  1019  	assert.True(argmin.IsScalar())
  1020  	assert.Equal(2, argmin.ScalarValue(), "NaN test")
  1021  
  1022  	// test with Mask and Nan
  1023  	T = New(WithShape(4), WithBacking([]float32{1, -9, math32.NaN(), 4}, []bool{false, true, true, false}))
  1024  	if argmin, err = T.Argmin(AllAxes); err != nil {
  1025  		t.Errorf("Failed test with NaN: %v", err)
  1026  	}
  1027  	assert.True(argmin.IsScalar())
  1028  	assert.Equal(0, argmin.ScalarValue(), "Masked NaN test")
  1029  
  1030  	// test with +Inf
  1031  	T = New(WithShape(4), WithBacking([]float32{1, 2, math32.Inf(1), 4}))
  1032  	if argmin, err = T.Argmin(AllAxes); err != nil {
  1033  		t.Errorf("Failed test with +Inf: %v", err)
  1034  	}
  1035  	assert.True(argmin.IsScalar())
  1036  	assert.Equal(0, argmin.ScalarValue(), "+Inf test")
  1037  
  1038  	// test with Mask and +Inf
  1039  	T = New(WithShape(4), WithBacking([]float32{1, -9, math32.Inf(1), 4}, []bool{false, true, true, false}))
  1040  	if argmin, err = T.Argmin(AllAxes); err != nil {
  1041  		t.Errorf("Failed test with NaN: %v", err)
  1042  	}
  1043  	assert.True(argmin.IsScalar())
  1044  	assert.Equal(0, argmin.ScalarValue(), "Masked NaN test")
  1045  
  1046  	// test with -Inf
  1047  	T = New(WithShape(4), WithBacking([]float32{1, 2, math32.Inf(-1), 4}))
  1048  	if argmin, err = T.Argmin(AllAxes); err != nil {
  1049  		t.Errorf("Failed test with -Inf: %v", err)
  1050  	}
  1051  	assert.True(argmin.IsScalar())
  1052  	assert.Equal(2, argmin.ScalarValue(), "+Inf test")
  1053  
  1054  	// test with Mask and -Inf
  1055  	T = New(WithShape(4), WithBacking([]float32{1, -9, math32.Inf(-1), 4}, []bool{false, true, true, false}))
  1056  	if argmin, err = T.Argmin(AllAxes); err != nil {
  1057  		t.Errorf("Failed test with NaN: %v", err)
  1058  	}
  1059  	assert.True(argmin.IsScalar())
  1060  	assert.Equal(0, argmin.ScalarValue(), "Masked -Inf test")
  1061  
  1062  	// with different engine
  1063  	T = basicDenseF32.Clone().(*Dense)
  1064  	WithEngine(dummyEngine2{})(T)
  1065  	for i := 0; i < T.Dims(); i++ {
  1066  		if argmin, err = T.Argmin(i); err != nil {
  1067  			t.Error(err)
  1068  			continue
  1069  		}
  1070  
  1071  		assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape)
  1072  		assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i)
  1073  	}
  1074  
  1075  	// idiotsville
  1076  	_, err = T.Argmin(10000)
  1077  	assert.NotNil(err)
  1078  
  1079  }
  1080  func TestDense_Argmax_F64(t *testing.T) {
  1081  	assert := assert.New(t)
  1082  	var T, argmax *Dense
  1083  	var err error
  1084  	T = basicDenseF64.Clone().(*Dense)
  1085  	for i := 0; i < T.Dims(); i++ {
  1086  		if argmax, err = T.Argmax(i); err != nil {
  1087  			t.Error(err)
  1088  			continue
  1089  		}
  1090  
  1091  		assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape)
  1092  		assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i)
  1093  	}
  1094  	// test all axes
  1095  	if argmax, err = T.Argmax(AllAxes); err != nil {
  1096  		t.Error(err)
  1097  		return
  1098  	}
  1099  	assert.True(argmax.IsScalar())
  1100  	assert.Equal(7, argmax.ScalarValue())
  1101  
  1102  	// test with NaN
  1103  	T = New(WithShape(4), WithBacking([]float64{1, 2, math.NaN(), 4}))
  1104  	if argmax, err = T.Argmax(AllAxes); err != nil {
  1105  		t.Errorf("Failed test with NaN: %v", err)
  1106  	}
  1107  	assert.True(argmax.IsScalar())
  1108  	assert.Equal(2, argmax.ScalarValue(), "NaN test")
  1109  
  1110  	// test with Mask and Nan
  1111  	T = New(WithShape(4), WithBacking([]float64{1, 9, math.NaN(), 4}, []bool{false, true, true, false}))
  1112  	if argmax, err = T.Argmax(AllAxes); err != nil {
  1113  		t.Errorf("Failed test with NaN: %v", err)
  1114  	}
  1115  	assert.True(argmax.IsScalar())
  1116  	assert.Equal(3, argmax.ScalarValue(), "Masked NaN test")
  1117  
  1118  	// test with +Inf
  1119  	T = New(WithShape(4), WithBacking([]float64{1, 2, math.Inf(1), 4}))
  1120  	if argmax, err = T.Argmax(AllAxes); err != nil {
  1121  		t.Errorf("Failed test with +Inf: %v", err)
  1122  	}
  1123  	assert.True(argmax.IsScalar())
  1124  	assert.Equal(2, argmax.ScalarValue(), "+Inf test")
  1125  
  1126  	// test with Mask and +Inf
  1127  	T = New(WithShape(4), WithBacking([]float64{1, 9, math.Inf(1), 4}, []bool{false, true, true, false}))
  1128  	if argmax, err = T.Argmax(AllAxes); err != nil {
  1129  		t.Errorf("Failed test with NaN: %v", err)
  1130  	}
  1131  	assert.True(argmax.IsScalar())
  1132  	assert.Equal(3, argmax.ScalarValue(), "Masked NaN test")
  1133  
  1134  	// test with -Inf
  1135  	T = New(WithShape(4), WithBacking([]float64{1, 2, math.Inf(-1), 4}))
  1136  	if argmax, err = T.Argmax(AllAxes); err != nil {
  1137  		t.Errorf("Failed test with -Inf: %v", err)
  1138  	}
  1139  	assert.True(argmax.IsScalar())
  1140  	assert.Equal(3, argmax.ScalarValue(), "+Inf test")
  1141  
  1142  	// test with Mask and -Inf
  1143  	T = New(WithShape(4), WithBacking([]float64{1, 9, math.Inf(-1), 4}, []bool{false, true, true, false}))
  1144  	if argmax, err = T.Argmax(AllAxes); err != nil {
  1145  		t.Errorf("Failed test with NaN: %v", err)
  1146  	}
  1147  	assert.True(argmax.IsScalar())
  1148  	assert.Equal(3, argmax.ScalarValue(), "Masked -Inf test")
  1149  
  1150  	// with different engine
  1151  	T = basicDenseF64.Clone().(*Dense)
  1152  	WithEngine(dummyEngine2{})(T)
  1153  	for i := 0; i < T.Dims(); i++ {
  1154  		if argmax, err = T.Argmax(i); err != nil {
  1155  			t.Error(err)
  1156  			continue
  1157  		}
  1158  
  1159  		assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape)
  1160  		assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i)
  1161  	}
  1162  
  1163  	// idiotsville
  1164  	_, err = T.Argmax(10000)
  1165  	assert.NotNil(err)
  1166  
  1167  }
  1168  func TestDense_Argmin_F64(t *testing.T) {
  1169  	assert := assert.New(t)
  1170  	var T, argmin *Dense
  1171  	var err error
  1172  	T = basicDenseF64.Clone().(*Dense)
  1173  	for i := 0; i < T.Dims(); i++ {
  1174  		if argmin, err = T.Argmin(i); err != nil {
  1175  			t.Error(err)
  1176  			continue
  1177  		}
  1178  
  1179  		assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape)
  1180  		assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i)
  1181  	}
  1182  	// test all axes
  1183  	if argmin, err = T.Argmin(AllAxes); err != nil {
  1184  		t.Error(err)
  1185  		return
  1186  	}
  1187  	assert.True(argmin.IsScalar())
  1188  	assert.Equal(11, argmin.ScalarValue())
  1189  
  1190  	// test with NaN
  1191  	T = New(WithShape(4), WithBacking([]float64{1, 2, math.NaN(), 4}))
  1192  	if argmin, err = T.Argmin(AllAxes); err != nil {
  1193  		t.Errorf("Failed test with NaN: %v", err)
  1194  	}
  1195  	assert.True(argmin.IsScalar())
  1196  	assert.Equal(2, argmin.ScalarValue(), "NaN test")
  1197  
  1198  	// test with Mask and Nan
  1199  	T = New(WithShape(4), WithBacking([]float64{1, -9, math.NaN(), 4}, []bool{false, true, true, false}))
  1200  	if argmin, err = T.Argmin(AllAxes); err != nil {
  1201  		t.Errorf("Failed test with NaN: %v", err)
  1202  	}
  1203  	assert.True(argmin.IsScalar())
  1204  	assert.Equal(0, argmin.ScalarValue(), "Masked NaN test")
  1205  
  1206  	// test with +Inf
  1207  	T = New(WithShape(4), WithBacking([]float64{1, 2, math.Inf(1), 4}))
  1208  	if argmin, err = T.Argmin(AllAxes); err != nil {
  1209  		t.Errorf("Failed test with +Inf: %v", err)
  1210  	}
  1211  	assert.True(argmin.IsScalar())
  1212  	assert.Equal(0, argmin.ScalarValue(), "+Inf test")
  1213  
  1214  	// test with Mask and +Inf
  1215  	T = New(WithShape(4), WithBacking([]float64{1, -9, math.Inf(1), 4}, []bool{false, true, true, false}))
  1216  	if argmin, err = T.Argmin(AllAxes); err != nil {
  1217  		t.Errorf("Failed test with NaN: %v", err)
  1218  	}
  1219  	assert.True(argmin.IsScalar())
  1220  	assert.Equal(0, argmin.ScalarValue(), "Masked NaN test")
  1221  
  1222  	// test with -Inf
  1223  	T = New(WithShape(4), WithBacking([]float64{1, 2, math.Inf(-1), 4}))
  1224  	if argmin, err = T.Argmin(AllAxes); err != nil {
  1225  		t.Errorf("Failed test with -Inf: %v", err)
  1226  	}
  1227  	assert.True(argmin.IsScalar())
  1228  	assert.Equal(2, argmin.ScalarValue(), "+Inf test")
  1229  
  1230  	// test with Mask and -Inf
  1231  	T = New(WithShape(4), WithBacking([]float64{1, -9, math.Inf(-1), 4}, []bool{false, true, true, false}))
  1232  	if argmin, err = T.Argmin(AllAxes); err != nil {
  1233  		t.Errorf("Failed test with NaN: %v", err)
  1234  	}
  1235  	assert.True(argmin.IsScalar())
  1236  	assert.Equal(0, argmin.ScalarValue(), "Masked -Inf test")
  1237  
  1238  	// with different engine
  1239  	T = basicDenseF64.Clone().(*Dense)
  1240  	WithEngine(dummyEngine2{})(T)
  1241  	for i := 0; i < T.Dims(); i++ {
  1242  		if argmin, err = T.Argmin(i); err != nil {
  1243  			t.Error(err)
  1244  			continue
  1245  		}
  1246  
  1247  		assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape)
  1248  		assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i)
  1249  	}
  1250  
  1251  	// idiotsville
  1252  	_, err = T.Argmin(10000)
  1253  	assert.NotNil(err)
  1254  
  1255  }