github.com/wzzhu/tensor@v0.9.24/dense_argmethods_test.go (about) 1 // Code generated by genlib2. DO NOT EDIT. 2 3 package tensor 4 5 import ( 6 "math" 7 "testing" 8 9 "github.com/chewxy/math32" 10 "github.com/stretchr/testify/assert" 11 ) 12 13 /* Test data */ 14 15 var basicDenseI = New(WithShape(2, 3, 4, 5, 2), WithBacking([]int{3, 4, 2, 4, 3, 8, 3, 9, 7, 4, 3, 0, 3, 9, 9, 0, 6, 7, 3, 9, 4, 8, 5, 1, 1, 9, 4, 0, 4, 1, 6, 6, 4, 9, 3, 8, 1, 7, 0, 7, 4, 0, 6, 8, 2, 8, 0, 6, 1, 6, 2, 3, 7, 5, 7, 3, 0, 8, 6, 5, 6, 9, 7, 5, 6, 8, 7, 9, 5, 0, 8, 1, 4, 0, 6, 6, 3, 3, 8, 1, 1, 3, 2, 5, 9, 0, 4, 5, 3, 1, 9, 1, 9, 3, 9, 3, 3, 4, 5, 9, 4, 2, 2, 7, 9, 8, 1, 6, 9, 4, 4, 1, 8, 9, 8, 0, 9, 9, 4, 6, 7, 5, 9, 9, 4, 8, 5, 8, 2, 4, 8, 2, 7, 2, 8, 7, 2, 3, 7, 0, 9, 9, 8, 9, 2, 1, 7, 0, 7, 9, 0, 2, 4, 8, 7, 9, 6, 8, 3, 3, 7, 2, 9, 2, 8, 2, 3, 6, 0, 8, 7, 7, 0, 9, 0, 9, 3, 2, 6, 9, 5, 8, 6, 9, 5, 6, 1, 8, 7, 8, 1, 9, 9, 3, 7, 7, 6, 8, 2, 1, 1, 5, 1, 4, 0, 5, 1, 7, 9, 5, 6, 6, 8, 7, 5, 1, 3, 4, 0, 1, 8, 0, 2, 6, 9, 1, 4, 8, 0, 5, 6, 2, 9, 4, 4, 2, 4, 4, 4, 3})) 16 var basicDenseI8 = New(WithShape(2, 3, 4, 5, 2), WithBacking([]int8{3, 4, 2, 4, 3, 8, 3, 9, 7, 4, 3, 0, 3, 9, 9, 0, 6, 7, 3, 9, 4, 8, 5, 1, 1, 9, 4, 0, 4, 1, 6, 6, 4, 9, 3, 8, 1, 7, 0, 7, 4, 0, 6, 8, 2, 8, 0, 6, 1, 6, 2, 3, 7, 5, 7, 3, 0, 8, 6, 5, 6, 9, 7, 5, 6, 8, 7, 9, 5, 0, 8, 1, 4, 0, 6, 6, 3, 3, 8, 1, 1, 3, 2, 5, 9, 0, 4, 5, 3, 1, 9, 1, 9, 3, 9, 3, 3, 4, 5, 9, 4, 2, 2, 7, 9, 8, 1, 6, 9, 4, 4, 1, 8, 9, 8, 0, 9, 9, 4, 6, 7, 5, 9, 9, 4, 8, 5, 8, 2, 4, 8, 2, 7, 2, 8, 7, 2, 3, 7, 0, 9, 9, 8, 9, 2, 1, 7, 0, 7, 9, 0, 2, 4, 8, 7, 9, 6, 8, 3, 3, 7, 2, 9, 2, 8, 2, 3, 6, 0, 8, 7, 7, 0, 9, 0, 9, 3, 2, 6, 9, 5, 8, 6, 9, 5, 6, 1, 8, 7, 8, 1, 9, 9, 3, 7, 7, 6, 8, 2, 1, 1, 5, 1, 4, 0, 5, 1, 7, 9, 5, 6, 6, 8, 7, 5, 1, 3, 4, 0, 1, 8, 0, 2, 6, 9, 1, 4, 8, 0, 5, 6, 2, 9, 4, 4, 2, 4, 4, 4, 3})) 17 var basicDenseI16 = New(WithShape(2, 3, 4, 5, 2), WithBacking([]int16{3, 4, 2, 4, 3, 8, 3, 9, 7, 4, 3, 0, 3, 9, 9, 0, 6, 7, 3, 9, 4, 8, 5, 1, 1, 9, 4, 0, 4, 1, 6, 6, 4, 9, 3, 8, 1, 7, 0, 7, 4, 0, 6, 8, 2, 8, 0, 6, 1, 6, 2, 3, 7, 5, 7, 3, 0, 8, 6, 5, 6, 9, 7, 5, 6, 8, 7, 9, 5, 0, 8, 1, 4, 0, 6, 6, 3, 3, 8, 1, 1, 3, 2, 5, 9, 0, 4, 5, 3, 1, 9, 1, 9, 3, 9, 3, 3, 4, 5, 9, 4, 2, 2, 7, 9, 8, 1, 6, 9, 4, 4, 1, 8, 9, 8, 0, 9, 9, 4, 6, 7, 5, 9, 9, 4, 8, 5, 8, 2, 4, 8, 2, 7, 2, 8, 7, 2, 3, 7, 0, 9, 9, 8, 9, 2, 1, 7, 0, 7, 9, 0, 2, 4, 8, 7, 9, 6, 8, 3, 3, 7, 2, 9, 2, 8, 2, 3, 6, 0, 8, 7, 7, 0, 9, 0, 9, 3, 2, 6, 9, 5, 8, 6, 9, 5, 6, 1, 8, 7, 8, 1, 9, 9, 3, 7, 7, 6, 8, 2, 1, 1, 5, 1, 4, 0, 5, 1, 7, 9, 5, 6, 6, 8, 7, 5, 1, 3, 4, 0, 1, 8, 0, 2, 6, 9, 1, 4, 8, 0, 5, 6, 2, 9, 4, 4, 2, 4, 4, 4, 3})) 18 var basicDenseI32 = New(WithShape(2, 3, 4, 5, 2), WithBacking([]int32{3, 4, 2, 4, 3, 8, 3, 9, 7, 4, 3, 0, 3, 9, 9, 0, 6, 7, 3, 9, 4, 8, 5, 1, 1, 9, 4, 0, 4, 1, 6, 6, 4, 9, 3, 8, 1, 7, 0, 7, 4, 0, 6, 8, 2, 8, 0, 6, 1, 6, 2, 3, 7, 5, 7, 3, 0, 8, 6, 5, 6, 9, 7, 5, 6, 8, 7, 9, 5, 0, 8, 1, 4, 0, 6, 6, 3, 3, 8, 1, 1, 3, 2, 5, 9, 0, 4, 5, 3, 1, 9, 1, 9, 3, 9, 3, 3, 4, 5, 9, 4, 2, 2, 7, 9, 8, 1, 6, 9, 4, 4, 1, 8, 9, 8, 0, 9, 9, 4, 6, 7, 5, 9, 9, 4, 8, 5, 8, 2, 4, 8, 2, 7, 2, 8, 7, 2, 3, 7, 0, 9, 9, 8, 9, 2, 1, 7, 0, 7, 9, 0, 2, 4, 8, 7, 9, 6, 8, 3, 3, 7, 2, 9, 2, 8, 2, 3, 6, 0, 8, 7, 7, 0, 9, 0, 9, 3, 2, 6, 9, 5, 8, 6, 9, 5, 6, 1, 8, 7, 8, 1, 9, 9, 3, 7, 7, 6, 8, 2, 1, 1, 5, 1, 4, 0, 5, 1, 7, 9, 5, 6, 6, 8, 7, 5, 1, 3, 4, 0, 1, 8, 0, 2, 6, 9, 1, 4, 8, 0, 5, 6, 2, 9, 4, 4, 2, 4, 4, 4, 3})) 19 var basicDenseI64 = New(WithShape(2, 3, 4, 5, 2), WithBacking([]int64{3, 4, 2, 4, 3, 8, 3, 9, 7, 4, 3, 0, 3, 9, 9, 0, 6, 7, 3, 9, 4, 8, 5, 1, 1, 9, 4, 0, 4, 1, 6, 6, 4, 9, 3, 8, 1, 7, 0, 7, 4, 0, 6, 8, 2, 8, 0, 6, 1, 6, 2, 3, 7, 5, 7, 3, 0, 8, 6, 5, 6, 9, 7, 5, 6, 8, 7, 9, 5, 0, 8, 1, 4, 0, 6, 6, 3, 3, 8, 1, 1, 3, 2, 5, 9, 0, 4, 5, 3, 1, 9, 1, 9, 3, 9, 3, 3, 4, 5, 9, 4, 2, 2, 7, 9, 8, 1, 6, 9, 4, 4, 1, 8, 9, 8, 0, 9, 9, 4, 6, 7, 5, 9, 9, 4, 8, 5, 8, 2, 4, 8, 2, 7, 2, 8, 7, 2, 3, 7, 0, 9, 9, 8, 9, 2, 1, 7, 0, 7, 9, 0, 2, 4, 8, 7, 9, 6, 8, 3, 3, 7, 2, 9, 2, 8, 2, 3, 6, 0, 8, 7, 7, 0, 9, 0, 9, 3, 2, 6, 9, 5, 8, 6, 9, 5, 6, 1, 8, 7, 8, 1, 9, 9, 3, 7, 7, 6, 8, 2, 1, 1, 5, 1, 4, 0, 5, 1, 7, 9, 5, 6, 6, 8, 7, 5, 1, 3, 4, 0, 1, 8, 0, 2, 6, 9, 1, 4, 8, 0, 5, 6, 2, 9, 4, 4, 2, 4, 4, 4, 3})) 20 var basicDenseU = New(WithShape(2, 3, 4, 5, 2), WithBacking([]uint{3, 4, 2, 4, 3, 8, 3, 9, 7, 4, 3, 0, 3, 9, 9, 0, 6, 7, 3, 9, 4, 8, 5, 1, 1, 9, 4, 0, 4, 1, 6, 6, 4, 9, 3, 8, 1, 7, 0, 7, 4, 0, 6, 8, 2, 8, 0, 6, 1, 6, 2, 3, 7, 5, 7, 3, 0, 8, 6, 5, 6, 9, 7, 5, 6, 8, 7, 9, 5, 0, 8, 1, 4, 0, 6, 6, 3, 3, 8, 1, 1, 3, 2, 5, 9, 0, 4, 5, 3, 1, 9, 1, 9, 3, 9, 3, 3, 4, 5, 9, 4, 2, 2, 7, 9, 8, 1, 6, 9, 4, 4, 1, 8, 9, 8, 0, 9, 9, 4, 6, 7, 5, 9, 9, 4, 8, 5, 8, 2, 4, 8, 2, 7, 2, 8, 7, 2, 3, 7, 0, 9, 9, 8, 9, 2, 1, 7, 0, 7, 9, 0, 2, 4, 8, 7, 9, 6, 8, 3, 3, 7, 2, 9, 2, 8, 2, 3, 6, 0, 8, 7, 7, 0, 9, 0, 9, 3, 2, 6, 9, 5, 8, 6, 9, 5, 6, 1, 8, 7, 8, 1, 9, 9, 3, 7, 7, 6, 8, 2, 1, 1, 5, 1, 4, 0, 5, 1, 7, 9, 5, 6, 6, 8, 7, 5, 1, 3, 4, 0, 1, 8, 0, 2, 6, 9, 1, 4, 8, 0, 5, 6, 2, 9, 4, 4, 2, 4, 4, 4, 3})) 21 var basicDenseU8 = New(WithShape(2, 3, 4, 5, 2), WithBacking([]uint8{3, 4, 2, 4, 3, 8, 3, 9, 7, 4, 3, 0, 3, 9, 9, 0, 6, 7, 3, 9, 4, 8, 5, 1, 1, 9, 4, 0, 4, 1, 6, 6, 4, 9, 3, 8, 1, 7, 0, 7, 4, 0, 6, 8, 2, 8, 0, 6, 1, 6, 2, 3, 7, 5, 7, 3, 0, 8, 6, 5, 6, 9, 7, 5, 6, 8, 7, 9, 5, 0, 8, 1, 4, 0, 6, 6, 3, 3, 8, 1, 1, 3, 2, 5, 9, 0, 4, 5, 3, 1, 9, 1, 9, 3, 9, 3, 3, 4, 5, 9, 4, 2, 2, 7, 9, 8, 1, 6, 9, 4, 4, 1, 8, 9, 8, 0, 9, 9, 4, 6, 7, 5, 9, 9, 4, 8, 5, 8, 2, 4, 8, 2, 7, 2, 8, 7, 2, 3, 7, 0, 9, 9, 8, 9, 2, 1, 7, 0, 7, 9, 0, 2, 4, 8, 7, 9, 6, 8, 3, 3, 7, 2, 9, 2, 8, 2, 3, 6, 0, 8, 7, 7, 0, 9, 0, 9, 3, 2, 6, 9, 5, 8, 6, 9, 5, 6, 1, 8, 7, 8, 1, 9, 9, 3, 7, 7, 6, 8, 2, 1, 1, 5, 1, 4, 0, 5, 1, 7, 9, 5, 6, 6, 8, 7, 5, 1, 3, 4, 0, 1, 8, 0, 2, 6, 9, 1, 4, 8, 0, 5, 6, 2, 9, 4, 4, 2, 4, 4, 4, 3})) 22 var basicDenseU16 = New(WithShape(2, 3, 4, 5, 2), WithBacking([]uint16{3, 4, 2, 4, 3, 8, 3, 9, 7, 4, 3, 0, 3, 9, 9, 0, 6, 7, 3, 9, 4, 8, 5, 1, 1, 9, 4, 0, 4, 1, 6, 6, 4, 9, 3, 8, 1, 7, 0, 7, 4, 0, 6, 8, 2, 8, 0, 6, 1, 6, 2, 3, 7, 5, 7, 3, 0, 8, 6, 5, 6, 9, 7, 5, 6, 8, 7, 9, 5, 0, 8, 1, 4, 0, 6, 6, 3, 3, 8, 1, 1, 3, 2, 5, 9, 0, 4, 5, 3, 1, 9, 1, 9, 3, 9, 3, 3, 4, 5, 9, 4, 2, 2, 7, 9, 8, 1, 6, 9, 4, 4, 1, 8, 9, 8, 0, 9, 9, 4, 6, 7, 5, 9, 9, 4, 8, 5, 8, 2, 4, 8, 2, 7, 2, 8, 7, 2, 3, 7, 0, 9, 9, 8, 9, 2, 1, 7, 0, 7, 9, 0, 2, 4, 8, 7, 9, 6, 8, 3, 3, 7, 2, 9, 2, 8, 2, 3, 6, 0, 8, 7, 7, 0, 9, 0, 9, 3, 2, 6, 9, 5, 8, 6, 9, 5, 6, 1, 8, 7, 8, 1, 9, 9, 3, 7, 7, 6, 8, 2, 1, 1, 5, 1, 4, 0, 5, 1, 7, 9, 5, 6, 6, 8, 7, 5, 1, 3, 4, 0, 1, 8, 0, 2, 6, 9, 1, 4, 8, 0, 5, 6, 2, 9, 4, 4, 2, 4, 4, 4, 3})) 23 var basicDenseU32 = New(WithShape(2, 3, 4, 5, 2), WithBacking([]uint32{3, 4, 2, 4, 3, 8, 3, 9, 7, 4, 3, 0, 3, 9, 9, 0, 6, 7, 3, 9, 4, 8, 5, 1, 1, 9, 4, 0, 4, 1, 6, 6, 4, 9, 3, 8, 1, 7, 0, 7, 4, 0, 6, 8, 2, 8, 0, 6, 1, 6, 2, 3, 7, 5, 7, 3, 0, 8, 6, 5, 6, 9, 7, 5, 6, 8, 7, 9, 5, 0, 8, 1, 4, 0, 6, 6, 3, 3, 8, 1, 1, 3, 2, 5, 9, 0, 4, 5, 3, 1, 9, 1, 9, 3, 9, 3, 3, 4, 5, 9, 4, 2, 2, 7, 9, 8, 1, 6, 9, 4, 4, 1, 8, 9, 8, 0, 9, 9, 4, 6, 7, 5, 9, 9, 4, 8, 5, 8, 2, 4, 8, 2, 7, 2, 8, 7, 2, 3, 7, 0, 9, 9, 8, 9, 2, 1, 7, 0, 7, 9, 0, 2, 4, 8, 7, 9, 6, 8, 3, 3, 7, 2, 9, 2, 8, 2, 3, 6, 0, 8, 7, 7, 0, 9, 0, 9, 3, 2, 6, 9, 5, 8, 6, 9, 5, 6, 1, 8, 7, 8, 1, 9, 9, 3, 7, 7, 6, 8, 2, 1, 1, 5, 1, 4, 0, 5, 1, 7, 9, 5, 6, 6, 8, 7, 5, 1, 3, 4, 0, 1, 8, 0, 2, 6, 9, 1, 4, 8, 0, 5, 6, 2, 9, 4, 4, 2, 4, 4, 4, 3})) 24 var basicDenseU64 = New(WithShape(2, 3, 4, 5, 2), WithBacking([]uint64{3, 4, 2, 4, 3, 8, 3, 9, 7, 4, 3, 0, 3, 9, 9, 0, 6, 7, 3, 9, 4, 8, 5, 1, 1, 9, 4, 0, 4, 1, 6, 6, 4, 9, 3, 8, 1, 7, 0, 7, 4, 0, 6, 8, 2, 8, 0, 6, 1, 6, 2, 3, 7, 5, 7, 3, 0, 8, 6, 5, 6, 9, 7, 5, 6, 8, 7, 9, 5, 0, 8, 1, 4, 0, 6, 6, 3, 3, 8, 1, 1, 3, 2, 5, 9, 0, 4, 5, 3, 1, 9, 1, 9, 3, 9, 3, 3, 4, 5, 9, 4, 2, 2, 7, 9, 8, 1, 6, 9, 4, 4, 1, 8, 9, 8, 0, 9, 9, 4, 6, 7, 5, 9, 9, 4, 8, 5, 8, 2, 4, 8, 2, 7, 2, 8, 7, 2, 3, 7, 0, 9, 9, 8, 9, 2, 1, 7, 0, 7, 9, 0, 2, 4, 8, 7, 9, 6, 8, 3, 3, 7, 2, 9, 2, 8, 2, 3, 6, 0, 8, 7, 7, 0, 9, 0, 9, 3, 2, 6, 9, 5, 8, 6, 9, 5, 6, 1, 8, 7, 8, 1, 9, 9, 3, 7, 7, 6, 8, 2, 1, 1, 5, 1, 4, 0, 5, 1, 7, 9, 5, 6, 6, 8, 7, 5, 1, 3, 4, 0, 1, 8, 0, 2, 6, 9, 1, 4, 8, 0, 5, 6, 2, 9, 4, 4, 2, 4, 4, 4, 3})) 25 var basicDenseF32 = New(WithShape(2, 3, 4, 5, 2), WithBacking([]float32{3, 4, 2, 4, 3, 8, 3, 9, 7, 4, 3, 0, 3, 9, 9, 0, 6, 7, 3, 9, 4, 8, 5, 1, 1, 9, 4, 0, 4, 1, 6, 6, 4, 9, 3, 8, 1, 7, 0, 7, 4, 0, 6, 8, 2, 8, 0, 6, 1, 6, 2, 3, 7, 5, 7, 3, 0, 8, 6, 5, 6, 9, 7, 5, 6, 8, 7, 9, 5, 0, 8, 1, 4, 0, 6, 6, 3, 3, 8, 1, 1, 3, 2, 5, 9, 0, 4, 5, 3, 1, 9, 1, 9, 3, 9, 3, 3, 4, 5, 9, 4, 2, 2, 7, 9, 8, 1, 6, 9, 4, 4, 1, 8, 9, 8, 0, 9, 9, 4, 6, 7, 5, 9, 9, 4, 8, 5, 8, 2, 4, 8, 2, 7, 2, 8, 7, 2, 3, 7, 0, 9, 9, 8, 9, 2, 1, 7, 0, 7, 9, 0, 2, 4, 8, 7, 9, 6, 8, 3, 3, 7, 2, 9, 2, 8, 2, 3, 6, 0, 8, 7, 7, 0, 9, 0, 9, 3, 2, 6, 9, 5, 8, 6, 9, 5, 6, 1, 8, 7, 8, 1, 9, 9, 3, 7, 7, 6, 8, 2, 1, 1, 5, 1, 4, 0, 5, 1, 7, 9, 5, 6, 6, 8, 7, 5, 1, 3, 4, 0, 1, 8, 0, 2, 6, 9, 1, 4, 8, 0, 5, 6, 2, 9, 4, 4, 2, 4, 4, 4, 3})) 26 var basicDenseF64 = New(WithShape(2, 3, 4, 5, 2), WithBacking([]float64{3, 4, 2, 4, 3, 8, 3, 9, 7, 4, 3, 0, 3, 9, 9, 0, 6, 7, 3, 9, 4, 8, 5, 1, 1, 9, 4, 0, 4, 1, 6, 6, 4, 9, 3, 8, 1, 7, 0, 7, 4, 0, 6, 8, 2, 8, 0, 6, 1, 6, 2, 3, 7, 5, 7, 3, 0, 8, 6, 5, 6, 9, 7, 5, 6, 8, 7, 9, 5, 0, 8, 1, 4, 0, 6, 6, 3, 3, 8, 1, 1, 3, 2, 5, 9, 0, 4, 5, 3, 1, 9, 1, 9, 3, 9, 3, 3, 4, 5, 9, 4, 2, 2, 7, 9, 8, 1, 6, 9, 4, 4, 1, 8, 9, 8, 0, 9, 9, 4, 6, 7, 5, 9, 9, 4, 8, 5, 8, 2, 4, 8, 2, 7, 2, 8, 7, 2, 3, 7, 0, 9, 9, 8, 9, 2, 1, 7, 0, 7, 9, 0, 2, 4, 8, 7, 9, 6, 8, 3, 3, 7, 2, 9, 2, 8, 2, 3, 6, 0, 8, 7, 7, 0, 9, 0, 9, 3, 2, 6, 9, 5, 8, 6, 9, 5, 6, 1, 8, 7, 8, 1, 9, 9, 3, 7, 7, 6, 8, 2, 1, 1, 5, 1, 4, 0, 5, 1, 7, 9, 5, 6, 6, 8, 7, 5, 1, 3, 4, 0, 1, 8, 0, 2, 6, 9, 1, 4, 8, 0, 5, 6, 2, 9, 4, 4, 2, 4, 4, 4, 3})) 27 28 var argmaxCorrect = []struct { 29 shape Shape 30 data []int 31 }{ 32 {Shape{3, 4, 5, 2}, []int{ 33 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 34 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 35 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 36 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 37 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 38 1, 0, 0, 0, 0, 39 }}, 40 {Shape{2, 4, 5, 2}, []int{ 41 1, 0, 1, 1, 2, 0, 2, 0, 0, 1, 2, 1, 2, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 42 2, 2, 0, 1, 1, 2, 2, 1, 0, 2, 0, 2, 0, 2, 2, 1, 0, 0, 0, 0, 0, 1, 0, 43 0, 0, 2, 1, 0, 1, 2, 1, 0, 1, 1, 2, 0, 1, 0, 0, 0, 0, 2, 1, 0, 1, 0, 44 0, 2, 1, 1, 0, 0, 0, 0, 0, 2, 0, 45 }}, 46 {Shape{2, 3, 5, 2}, []int{ 47 3, 2, 2, 1, 1, 2, 1, 0, 0, 1, 3, 2, 1, 0, 1, 0, 2, 2, 3, 0, 1, 0, 1, 48 3, 0, 2, 3, 3, 2, 1, 2, 2, 0, 0, 1, 3, 2, 0, 1, 2, 0, 3, 0, 1, 0, 1, 49 3, 2, 2, 1, 2, 1, 3, 1, 2, 0, 2, 2, 0, 0, 50 }}, 51 {Shape{2, 3, 4, 2}, []int{ 52 4, 3, 2, 1, 1, 2, 0, 1, 1, 1, 1, 3, 1, 0, 0, 2, 2, 1, 0, 4, 2, 2, 3, 53 1, 1, 1, 0, 2, 0, 0, 2, 2, 1, 4, 0, 1, 4, 1, 1, 0, 4, 3, 1, 1, 2, 3, 54 1, 1, 55 }}, 56 {Shape{2, 3, 4, 5}, []int{ 57 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 58 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 59 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 60 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 61 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 62 0, 0, 0, 0, 0, 63 }}, 64 } 65 66 var argminCorrect = []struct { 67 shape Shape 68 data []int 69 }{ 70 {Shape{3, 4, 5, 2}, []int{ 71 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 72 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 73 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 74 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 75 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 76 0, 1, 1, 0, 1, 77 }}, 78 {Shape{2, 4, 5, 2}, []int{ 79 2, 1, 0, 0, 1, 2, 1, 2, 1, 2, 1, 0, 0, 2, 1, 0, 1, 2, 0, 1, 0, 2, 2, 80 0, 0, 1, 2, 0, 0, 1, 2, 1, 0, 1, 0, 2, 0, 1, 0, 1, 2, 1, 2, 1, 2, 1, 81 2, 1, 1, 0, 2, 0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2, 2, 2, 0, 0, 1, 0, 2, 82 2, 0, 0, 0, 1, 2, 2, 2, 2, 1, 1, 83 }}, 84 {Shape{2, 3, 5, 2}, []int{ 85 0, 1, 0, 2, 2, 1, 3, 2, 3, 2, 1, 0, 3, 3, 0, 1, 0, 3, 0, 2, 0, 1, 0, 86 1, 3, 0, 2, 1, 0, 0, 3, 1, 3, 1, 2, 2, 1, 2, 0, 1, 3, 0, 1, 0, 1, 0, 87 2, 1, 0, 3, 0, 2, 0, 0, 0, 1, 0, 1, 1, 1, 88 }}, 89 {Shape{2, 3, 4, 2}, []int{ 90 1, 0, 0, 0, 2, 3, 4, 0, 3, 0, 3, 0, 4, 4, 3, 1, 0, 2, 3, 0, 3, 0, 0, 91 2, 4, 4, 3, 4, 2, 3, 0, 0, 4, 0, 1, 3, 3, 2, 0, 4, 2, 1, 4, 2, 4, 0, 92 2, 0, 93 }}, 94 {Shape{2, 3, 4, 5}, []int{ 95 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 96 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 97 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 98 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 99 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 100 1, 1, 1, 0, 1, 101 }}, 102 } 103 104 func TestDense_Argmax_I(t *testing.T) { 105 assert := assert.New(t) 106 var T, argmax *Dense 107 var err error 108 T = basicDenseI.Clone().(*Dense) 109 for i := 0; i < T.Dims(); i++ { 110 if argmax, err = T.Argmax(i); err != nil { 111 t.Error(err) 112 continue 113 } 114 115 assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape) 116 assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i) 117 } 118 // test all axes 119 if argmax, err = T.Argmax(AllAxes); err != nil { 120 t.Error(err) 121 return 122 } 123 assert.True(argmax.IsScalar()) 124 assert.Equal(7, argmax.ScalarValue()) 125 126 // with different engine 127 T = basicDenseI.Clone().(*Dense) 128 WithEngine(dummyEngine2{})(T) 129 for i := 0; i < T.Dims(); i++ { 130 if argmax, err = T.Argmax(i); err != nil { 131 t.Error(err) 132 continue 133 } 134 135 assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape) 136 assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i) 137 } 138 139 // idiotsville 140 _, err = T.Argmax(10000) 141 assert.NotNil(err) 142 143 } 144 func TestDense_Argmin_I(t *testing.T) { 145 assert := assert.New(t) 146 var T, argmin *Dense 147 var err error 148 T = basicDenseI.Clone().(*Dense) 149 for i := 0; i < T.Dims(); i++ { 150 if argmin, err = T.Argmin(i); err != nil { 151 t.Error(err) 152 continue 153 } 154 155 assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape) 156 assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i) 157 } 158 // test all axes 159 if argmin, err = T.Argmin(AllAxes); err != nil { 160 t.Error(err) 161 return 162 } 163 assert.True(argmin.IsScalar()) 164 assert.Equal(11, argmin.ScalarValue()) 165 166 // with different engine 167 T = basicDenseI.Clone().(*Dense) 168 WithEngine(dummyEngine2{})(T) 169 for i := 0; i < T.Dims(); i++ { 170 if argmin, err = T.Argmin(i); err != nil { 171 t.Error(err) 172 continue 173 } 174 175 assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape) 176 assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i) 177 } 178 179 // idiotsville 180 _, err = T.Argmin(10000) 181 assert.NotNil(err) 182 183 } 184 func TestDense_Argmax_I8(t *testing.T) { 185 assert := assert.New(t) 186 var T, argmax *Dense 187 var err error 188 T = basicDenseI8.Clone().(*Dense) 189 for i := 0; i < T.Dims(); i++ { 190 if argmax, err = T.Argmax(i); err != nil { 191 t.Error(err) 192 continue 193 } 194 195 assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape) 196 assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i) 197 } 198 // test all axes 199 if argmax, err = T.Argmax(AllAxes); err != nil { 200 t.Error(err) 201 return 202 } 203 assert.True(argmax.IsScalar()) 204 assert.Equal(7, argmax.ScalarValue()) 205 206 // with different engine 207 T = basicDenseI8.Clone().(*Dense) 208 WithEngine(dummyEngine2{})(T) 209 for i := 0; i < T.Dims(); i++ { 210 if argmax, err = T.Argmax(i); err != nil { 211 t.Error(err) 212 continue 213 } 214 215 assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape) 216 assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i) 217 } 218 219 // idiotsville 220 _, err = T.Argmax(10000) 221 assert.NotNil(err) 222 223 } 224 func TestDense_Argmin_I8(t *testing.T) { 225 assert := assert.New(t) 226 var T, argmin *Dense 227 var err error 228 T = basicDenseI8.Clone().(*Dense) 229 for i := 0; i < T.Dims(); i++ { 230 if argmin, err = T.Argmin(i); err != nil { 231 t.Error(err) 232 continue 233 } 234 235 assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape) 236 assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i) 237 } 238 // test all axes 239 if argmin, err = T.Argmin(AllAxes); err != nil { 240 t.Error(err) 241 return 242 } 243 assert.True(argmin.IsScalar()) 244 assert.Equal(11, argmin.ScalarValue()) 245 246 // with different engine 247 T = basicDenseI8.Clone().(*Dense) 248 WithEngine(dummyEngine2{})(T) 249 for i := 0; i < T.Dims(); i++ { 250 if argmin, err = T.Argmin(i); err != nil { 251 t.Error(err) 252 continue 253 } 254 255 assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape) 256 assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i) 257 } 258 259 // idiotsville 260 _, err = T.Argmin(10000) 261 assert.NotNil(err) 262 263 } 264 func TestDense_Argmax_I16(t *testing.T) { 265 assert := assert.New(t) 266 var T, argmax *Dense 267 var err error 268 T = basicDenseI16.Clone().(*Dense) 269 for i := 0; i < T.Dims(); i++ { 270 if argmax, err = T.Argmax(i); err != nil { 271 t.Error(err) 272 continue 273 } 274 275 assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape) 276 assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i) 277 } 278 // test all axes 279 if argmax, err = T.Argmax(AllAxes); err != nil { 280 t.Error(err) 281 return 282 } 283 assert.True(argmax.IsScalar()) 284 assert.Equal(7, argmax.ScalarValue()) 285 286 // with different engine 287 T = basicDenseI16.Clone().(*Dense) 288 WithEngine(dummyEngine2{})(T) 289 for i := 0; i < T.Dims(); i++ { 290 if argmax, err = T.Argmax(i); err != nil { 291 t.Error(err) 292 continue 293 } 294 295 assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape) 296 assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i) 297 } 298 299 // idiotsville 300 _, err = T.Argmax(10000) 301 assert.NotNil(err) 302 303 } 304 func TestDense_Argmin_I16(t *testing.T) { 305 assert := assert.New(t) 306 var T, argmin *Dense 307 var err error 308 T = basicDenseI16.Clone().(*Dense) 309 for i := 0; i < T.Dims(); i++ { 310 if argmin, err = T.Argmin(i); err != nil { 311 t.Error(err) 312 continue 313 } 314 315 assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape) 316 assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i) 317 } 318 // test all axes 319 if argmin, err = T.Argmin(AllAxes); err != nil { 320 t.Error(err) 321 return 322 } 323 assert.True(argmin.IsScalar()) 324 assert.Equal(11, argmin.ScalarValue()) 325 326 // with different engine 327 T = basicDenseI16.Clone().(*Dense) 328 WithEngine(dummyEngine2{})(T) 329 for i := 0; i < T.Dims(); i++ { 330 if argmin, err = T.Argmin(i); err != nil { 331 t.Error(err) 332 continue 333 } 334 335 assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape) 336 assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i) 337 } 338 339 // idiotsville 340 _, err = T.Argmin(10000) 341 assert.NotNil(err) 342 343 } 344 func TestDense_Argmax_I32(t *testing.T) { 345 assert := assert.New(t) 346 var T, argmax *Dense 347 var err error 348 T = basicDenseI32.Clone().(*Dense) 349 for i := 0; i < T.Dims(); i++ { 350 if argmax, err = T.Argmax(i); err != nil { 351 t.Error(err) 352 continue 353 } 354 355 assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape) 356 assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i) 357 } 358 // test all axes 359 if argmax, err = T.Argmax(AllAxes); err != nil { 360 t.Error(err) 361 return 362 } 363 assert.True(argmax.IsScalar()) 364 assert.Equal(7, argmax.ScalarValue()) 365 366 // with different engine 367 T = basicDenseI32.Clone().(*Dense) 368 WithEngine(dummyEngine2{})(T) 369 for i := 0; i < T.Dims(); i++ { 370 if argmax, err = T.Argmax(i); err != nil { 371 t.Error(err) 372 continue 373 } 374 375 assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape) 376 assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i) 377 } 378 379 // idiotsville 380 _, err = T.Argmax(10000) 381 assert.NotNil(err) 382 383 } 384 func TestDense_Argmin_I32(t *testing.T) { 385 assert := assert.New(t) 386 var T, argmin *Dense 387 var err error 388 T = basicDenseI32.Clone().(*Dense) 389 for i := 0; i < T.Dims(); i++ { 390 if argmin, err = T.Argmin(i); err != nil { 391 t.Error(err) 392 continue 393 } 394 395 assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape) 396 assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i) 397 } 398 // test all axes 399 if argmin, err = T.Argmin(AllAxes); err != nil { 400 t.Error(err) 401 return 402 } 403 assert.True(argmin.IsScalar()) 404 assert.Equal(11, argmin.ScalarValue()) 405 406 // with different engine 407 T = basicDenseI32.Clone().(*Dense) 408 WithEngine(dummyEngine2{})(T) 409 for i := 0; i < T.Dims(); i++ { 410 if argmin, err = T.Argmin(i); err != nil { 411 t.Error(err) 412 continue 413 } 414 415 assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape) 416 assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i) 417 } 418 419 // idiotsville 420 _, err = T.Argmin(10000) 421 assert.NotNil(err) 422 423 } 424 func TestDense_Argmax_I64(t *testing.T) { 425 assert := assert.New(t) 426 var T, argmax *Dense 427 var err error 428 T = basicDenseI64.Clone().(*Dense) 429 for i := 0; i < T.Dims(); i++ { 430 if argmax, err = T.Argmax(i); err != nil { 431 t.Error(err) 432 continue 433 } 434 435 assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape) 436 assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i) 437 } 438 // test all axes 439 if argmax, err = T.Argmax(AllAxes); err != nil { 440 t.Error(err) 441 return 442 } 443 assert.True(argmax.IsScalar()) 444 assert.Equal(7, argmax.ScalarValue()) 445 446 // with different engine 447 T = basicDenseI64.Clone().(*Dense) 448 WithEngine(dummyEngine2{})(T) 449 for i := 0; i < T.Dims(); i++ { 450 if argmax, err = T.Argmax(i); err != nil { 451 t.Error(err) 452 continue 453 } 454 455 assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape) 456 assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i) 457 } 458 459 // idiotsville 460 _, err = T.Argmax(10000) 461 assert.NotNil(err) 462 463 } 464 func TestDense_Argmin_I64(t *testing.T) { 465 assert := assert.New(t) 466 var T, argmin *Dense 467 var err error 468 T = basicDenseI64.Clone().(*Dense) 469 for i := 0; i < T.Dims(); i++ { 470 if argmin, err = T.Argmin(i); err != nil { 471 t.Error(err) 472 continue 473 } 474 475 assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape) 476 assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i) 477 } 478 // test all axes 479 if argmin, err = T.Argmin(AllAxes); err != nil { 480 t.Error(err) 481 return 482 } 483 assert.True(argmin.IsScalar()) 484 assert.Equal(11, argmin.ScalarValue()) 485 486 // with different engine 487 T = basicDenseI64.Clone().(*Dense) 488 WithEngine(dummyEngine2{})(T) 489 for i := 0; i < T.Dims(); i++ { 490 if argmin, err = T.Argmin(i); err != nil { 491 t.Error(err) 492 continue 493 } 494 495 assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape) 496 assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i) 497 } 498 499 // idiotsville 500 _, err = T.Argmin(10000) 501 assert.NotNil(err) 502 503 } 504 func TestDense_Argmax_U(t *testing.T) { 505 assert := assert.New(t) 506 var T, argmax *Dense 507 var err error 508 T = basicDenseU.Clone().(*Dense) 509 for i := 0; i < T.Dims(); i++ { 510 if argmax, err = T.Argmax(i); err != nil { 511 t.Error(err) 512 continue 513 } 514 515 assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape) 516 assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i) 517 } 518 // test all axes 519 if argmax, err = T.Argmax(AllAxes); err != nil { 520 t.Error(err) 521 return 522 } 523 assert.True(argmax.IsScalar()) 524 assert.Equal(7, argmax.ScalarValue()) 525 526 // with different engine 527 T = basicDenseU.Clone().(*Dense) 528 WithEngine(dummyEngine2{})(T) 529 for i := 0; i < T.Dims(); i++ { 530 if argmax, err = T.Argmax(i); err != nil { 531 t.Error(err) 532 continue 533 } 534 535 assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape) 536 assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i) 537 } 538 539 // idiotsville 540 _, err = T.Argmax(10000) 541 assert.NotNil(err) 542 543 } 544 func TestDense_Argmin_U(t *testing.T) { 545 assert := assert.New(t) 546 var T, argmin *Dense 547 var err error 548 T = basicDenseU.Clone().(*Dense) 549 for i := 0; i < T.Dims(); i++ { 550 if argmin, err = T.Argmin(i); err != nil { 551 t.Error(err) 552 continue 553 } 554 555 assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape) 556 assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i) 557 } 558 // test all axes 559 if argmin, err = T.Argmin(AllAxes); err != nil { 560 t.Error(err) 561 return 562 } 563 assert.True(argmin.IsScalar()) 564 assert.Equal(11, argmin.ScalarValue()) 565 566 // with different engine 567 T = basicDenseU.Clone().(*Dense) 568 WithEngine(dummyEngine2{})(T) 569 for i := 0; i < T.Dims(); i++ { 570 if argmin, err = T.Argmin(i); err != nil { 571 t.Error(err) 572 continue 573 } 574 575 assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape) 576 assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i) 577 } 578 579 // idiotsville 580 _, err = T.Argmin(10000) 581 assert.NotNil(err) 582 583 } 584 func TestDense_Argmax_U8(t *testing.T) { 585 assert := assert.New(t) 586 var T, argmax *Dense 587 var err error 588 T = basicDenseU8.Clone().(*Dense) 589 for i := 0; i < T.Dims(); i++ { 590 if argmax, err = T.Argmax(i); err != nil { 591 t.Error(err) 592 continue 593 } 594 595 assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape) 596 assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i) 597 } 598 // test all axes 599 if argmax, err = T.Argmax(AllAxes); err != nil { 600 t.Error(err) 601 return 602 } 603 assert.True(argmax.IsScalar()) 604 assert.Equal(7, argmax.ScalarValue()) 605 606 // with different engine 607 T = basicDenseU8.Clone().(*Dense) 608 WithEngine(dummyEngine2{})(T) 609 for i := 0; i < T.Dims(); i++ { 610 if argmax, err = T.Argmax(i); err != nil { 611 t.Error(err) 612 continue 613 } 614 615 assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape) 616 assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i) 617 } 618 619 // idiotsville 620 _, err = T.Argmax(10000) 621 assert.NotNil(err) 622 623 } 624 func TestDense_Argmin_U8(t *testing.T) { 625 assert := assert.New(t) 626 var T, argmin *Dense 627 var err error 628 T = basicDenseU8.Clone().(*Dense) 629 for i := 0; i < T.Dims(); i++ { 630 if argmin, err = T.Argmin(i); err != nil { 631 t.Error(err) 632 continue 633 } 634 635 assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape) 636 assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i) 637 } 638 // test all axes 639 if argmin, err = T.Argmin(AllAxes); err != nil { 640 t.Error(err) 641 return 642 } 643 assert.True(argmin.IsScalar()) 644 assert.Equal(11, argmin.ScalarValue()) 645 646 // with different engine 647 T = basicDenseU8.Clone().(*Dense) 648 WithEngine(dummyEngine2{})(T) 649 for i := 0; i < T.Dims(); i++ { 650 if argmin, err = T.Argmin(i); err != nil { 651 t.Error(err) 652 continue 653 } 654 655 assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape) 656 assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i) 657 } 658 659 // idiotsville 660 _, err = T.Argmin(10000) 661 assert.NotNil(err) 662 663 } 664 func TestDense_Argmax_U16(t *testing.T) { 665 assert := assert.New(t) 666 var T, argmax *Dense 667 var err error 668 T = basicDenseU16.Clone().(*Dense) 669 for i := 0; i < T.Dims(); i++ { 670 if argmax, err = T.Argmax(i); err != nil { 671 t.Error(err) 672 continue 673 } 674 675 assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape) 676 assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i) 677 } 678 // test all axes 679 if argmax, err = T.Argmax(AllAxes); err != nil { 680 t.Error(err) 681 return 682 } 683 assert.True(argmax.IsScalar()) 684 assert.Equal(7, argmax.ScalarValue()) 685 686 // with different engine 687 T = basicDenseU16.Clone().(*Dense) 688 WithEngine(dummyEngine2{})(T) 689 for i := 0; i < T.Dims(); i++ { 690 if argmax, err = T.Argmax(i); err != nil { 691 t.Error(err) 692 continue 693 } 694 695 assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape) 696 assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i) 697 } 698 699 // idiotsville 700 _, err = T.Argmax(10000) 701 assert.NotNil(err) 702 703 } 704 func TestDense_Argmin_U16(t *testing.T) { 705 assert := assert.New(t) 706 var T, argmin *Dense 707 var err error 708 T = basicDenseU16.Clone().(*Dense) 709 for i := 0; i < T.Dims(); i++ { 710 if argmin, err = T.Argmin(i); err != nil { 711 t.Error(err) 712 continue 713 } 714 715 assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape) 716 assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i) 717 } 718 // test all axes 719 if argmin, err = T.Argmin(AllAxes); err != nil { 720 t.Error(err) 721 return 722 } 723 assert.True(argmin.IsScalar()) 724 assert.Equal(11, argmin.ScalarValue()) 725 726 // with different engine 727 T = basicDenseU16.Clone().(*Dense) 728 WithEngine(dummyEngine2{})(T) 729 for i := 0; i < T.Dims(); i++ { 730 if argmin, err = T.Argmin(i); err != nil { 731 t.Error(err) 732 continue 733 } 734 735 assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape) 736 assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i) 737 } 738 739 // idiotsville 740 _, err = T.Argmin(10000) 741 assert.NotNil(err) 742 743 } 744 func TestDense_Argmax_U32(t *testing.T) { 745 assert := assert.New(t) 746 var T, argmax *Dense 747 var err error 748 T = basicDenseU32.Clone().(*Dense) 749 for i := 0; i < T.Dims(); i++ { 750 if argmax, err = T.Argmax(i); err != nil { 751 t.Error(err) 752 continue 753 } 754 755 assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape) 756 assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i) 757 } 758 // test all axes 759 if argmax, err = T.Argmax(AllAxes); err != nil { 760 t.Error(err) 761 return 762 } 763 assert.True(argmax.IsScalar()) 764 assert.Equal(7, argmax.ScalarValue()) 765 766 // with different engine 767 T = basicDenseU32.Clone().(*Dense) 768 WithEngine(dummyEngine2{})(T) 769 for i := 0; i < T.Dims(); i++ { 770 if argmax, err = T.Argmax(i); err != nil { 771 t.Error(err) 772 continue 773 } 774 775 assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape) 776 assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i) 777 } 778 779 // idiotsville 780 _, err = T.Argmax(10000) 781 assert.NotNil(err) 782 783 } 784 func TestDense_Argmin_U32(t *testing.T) { 785 assert := assert.New(t) 786 var T, argmin *Dense 787 var err error 788 T = basicDenseU32.Clone().(*Dense) 789 for i := 0; i < T.Dims(); i++ { 790 if argmin, err = T.Argmin(i); err != nil { 791 t.Error(err) 792 continue 793 } 794 795 assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape) 796 assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i) 797 } 798 // test all axes 799 if argmin, err = T.Argmin(AllAxes); err != nil { 800 t.Error(err) 801 return 802 } 803 assert.True(argmin.IsScalar()) 804 assert.Equal(11, argmin.ScalarValue()) 805 806 // with different engine 807 T = basicDenseU32.Clone().(*Dense) 808 WithEngine(dummyEngine2{})(T) 809 for i := 0; i < T.Dims(); i++ { 810 if argmin, err = T.Argmin(i); err != nil { 811 t.Error(err) 812 continue 813 } 814 815 assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape) 816 assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i) 817 } 818 819 // idiotsville 820 _, err = T.Argmin(10000) 821 assert.NotNil(err) 822 823 } 824 func TestDense_Argmax_U64(t *testing.T) { 825 assert := assert.New(t) 826 var T, argmax *Dense 827 var err error 828 T = basicDenseU64.Clone().(*Dense) 829 for i := 0; i < T.Dims(); i++ { 830 if argmax, err = T.Argmax(i); err != nil { 831 t.Error(err) 832 continue 833 } 834 835 assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape) 836 assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i) 837 } 838 // test all axes 839 if argmax, err = T.Argmax(AllAxes); err != nil { 840 t.Error(err) 841 return 842 } 843 assert.True(argmax.IsScalar()) 844 assert.Equal(7, argmax.ScalarValue()) 845 846 // with different engine 847 T = basicDenseU64.Clone().(*Dense) 848 WithEngine(dummyEngine2{})(T) 849 for i := 0; i < T.Dims(); i++ { 850 if argmax, err = T.Argmax(i); err != nil { 851 t.Error(err) 852 continue 853 } 854 855 assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape) 856 assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i) 857 } 858 859 // idiotsville 860 _, err = T.Argmax(10000) 861 assert.NotNil(err) 862 863 } 864 func TestDense_Argmin_U64(t *testing.T) { 865 assert := assert.New(t) 866 var T, argmin *Dense 867 var err error 868 T = basicDenseU64.Clone().(*Dense) 869 for i := 0; i < T.Dims(); i++ { 870 if argmin, err = T.Argmin(i); err != nil { 871 t.Error(err) 872 continue 873 } 874 875 assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape) 876 assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i) 877 } 878 // test all axes 879 if argmin, err = T.Argmin(AllAxes); err != nil { 880 t.Error(err) 881 return 882 } 883 assert.True(argmin.IsScalar()) 884 assert.Equal(11, argmin.ScalarValue()) 885 886 // with different engine 887 T = basicDenseU64.Clone().(*Dense) 888 WithEngine(dummyEngine2{})(T) 889 for i := 0; i < T.Dims(); i++ { 890 if argmin, err = T.Argmin(i); err != nil { 891 t.Error(err) 892 continue 893 } 894 895 assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape) 896 assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i) 897 } 898 899 // idiotsville 900 _, err = T.Argmin(10000) 901 assert.NotNil(err) 902 903 } 904 func TestDense_Argmax_F32(t *testing.T) { 905 assert := assert.New(t) 906 var T, argmax *Dense 907 var err error 908 T = basicDenseF32.Clone().(*Dense) 909 for i := 0; i < T.Dims(); i++ { 910 if argmax, err = T.Argmax(i); err != nil { 911 t.Error(err) 912 continue 913 } 914 915 assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape) 916 assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i) 917 } 918 // test all axes 919 if argmax, err = T.Argmax(AllAxes); err != nil { 920 t.Error(err) 921 return 922 } 923 assert.True(argmax.IsScalar()) 924 assert.Equal(7, argmax.ScalarValue()) 925 926 // test with NaN 927 T = New(WithShape(4), WithBacking([]float32{1, 2, math32.NaN(), 4})) 928 if argmax, err = T.Argmax(AllAxes); err != nil { 929 t.Errorf("Failed test with NaN: %v", err) 930 } 931 assert.True(argmax.IsScalar()) 932 assert.Equal(2, argmax.ScalarValue(), "NaN test") 933 934 // test with Mask and Nan 935 T = New(WithShape(4), WithBacking([]float32{1, 9, math32.NaN(), 4}, []bool{false, true, true, false})) 936 if argmax, err = T.Argmax(AllAxes); err != nil { 937 t.Errorf("Failed test with NaN: %v", err) 938 } 939 assert.True(argmax.IsScalar()) 940 assert.Equal(3, argmax.ScalarValue(), "Masked NaN test") 941 942 // test with +Inf 943 T = New(WithShape(4), WithBacking([]float32{1, 2, math32.Inf(1), 4})) 944 if argmax, err = T.Argmax(AllAxes); err != nil { 945 t.Errorf("Failed test with +Inf: %v", err) 946 } 947 assert.True(argmax.IsScalar()) 948 assert.Equal(2, argmax.ScalarValue(), "+Inf test") 949 950 // test with Mask and +Inf 951 T = New(WithShape(4), WithBacking([]float32{1, 9, math32.Inf(1), 4}, []bool{false, true, true, false})) 952 if argmax, err = T.Argmax(AllAxes); err != nil { 953 t.Errorf("Failed test with NaN: %v", err) 954 } 955 assert.True(argmax.IsScalar()) 956 assert.Equal(3, argmax.ScalarValue(), "Masked NaN test") 957 958 // test with -Inf 959 T = New(WithShape(4), WithBacking([]float32{1, 2, math32.Inf(-1), 4})) 960 if argmax, err = T.Argmax(AllAxes); err != nil { 961 t.Errorf("Failed test with -Inf: %v", err) 962 } 963 assert.True(argmax.IsScalar()) 964 assert.Equal(3, argmax.ScalarValue(), "+Inf test") 965 966 // test with Mask and -Inf 967 T = New(WithShape(4), WithBacking([]float32{1, 9, math32.Inf(-1), 4}, []bool{false, true, true, false})) 968 if argmax, err = T.Argmax(AllAxes); err != nil { 969 t.Errorf("Failed test with NaN: %v", err) 970 } 971 assert.True(argmax.IsScalar()) 972 assert.Equal(3, argmax.ScalarValue(), "Masked -Inf test") 973 974 // with different engine 975 T = basicDenseF32.Clone().(*Dense) 976 WithEngine(dummyEngine2{})(T) 977 for i := 0; i < T.Dims(); i++ { 978 if argmax, err = T.Argmax(i); err != nil { 979 t.Error(err) 980 continue 981 } 982 983 assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape) 984 assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i) 985 } 986 987 // idiotsville 988 _, err = T.Argmax(10000) 989 assert.NotNil(err) 990 991 } 992 func TestDense_Argmin_F32(t *testing.T) { 993 assert := assert.New(t) 994 var T, argmin *Dense 995 var err error 996 T = basicDenseF32.Clone().(*Dense) 997 for i := 0; i < T.Dims(); i++ { 998 if argmin, err = T.Argmin(i); err != nil { 999 t.Error(err) 1000 continue 1001 } 1002 1003 assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape) 1004 assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i) 1005 } 1006 // test all axes 1007 if argmin, err = T.Argmin(AllAxes); err != nil { 1008 t.Error(err) 1009 return 1010 } 1011 assert.True(argmin.IsScalar()) 1012 assert.Equal(11, argmin.ScalarValue()) 1013 1014 // test with NaN 1015 T = New(WithShape(4), WithBacking([]float32{1, 2, math32.NaN(), 4})) 1016 if argmin, err = T.Argmin(AllAxes); err != nil { 1017 t.Errorf("Failed test with NaN: %v", err) 1018 } 1019 assert.True(argmin.IsScalar()) 1020 assert.Equal(2, argmin.ScalarValue(), "NaN test") 1021 1022 // test with Mask and Nan 1023 T = New(WithShape(4), WithBacking([]float32{1, -9, math32.NaN(), 4}, []bool{false, true, true, false})) 1024 if argmin, err = T.Argmin(AllAxes); err != nil { 1025 t.Errorf("Failed test with NaN: %v", err) 1026 } 1027 assert.True(argmin.IsScalar()) 1028 assert.Equal(0, argmin.ScalarValue(), "Masked NaN test") 1029 1030 // test with +Inf 1031 T = New(WithShape(4), WithBacking([]float32{1, 2, math32.Inf(1), 4})) 1032 if argmin, err = T.Argmin(AllAxes); err != nil { 1033 t.Errorf("Failed test with +Inf: %v", err) 1034 } 1035 assert.True(argmin.IsScalar()) 1036 assert.Equal(0, argmin.ScalarValue(), "+Inf test") 1037 1038 // test with Mask and +Inf 1039 T = New(WithShape(4), WithBacking([]float32{1, -9, math32.Inf(1), 4}, []bool{false, true, true, false})) 1040 if argmin, err = T.Argmin(AllAxes); err != nil { 1041 t.Errorf("Failed test with NaN: %v", err) 1042 } 1043 assert.True(argmin.IsScalar()) 1044 assert.Equal(0, argmin.ScalarValue(), "Masked NaN test") 1045 1046 // test with -Inf 1047 T = New(WithShape(4), WithBacking([]float32{1, 2, math32.Inf(-1), 4})) 1048 if argmin, err = T.Argmin(AllAxes); err != nil { 1049 t.Errorf("Failed test with -Inf: %v", err) 1050 } 1051 assert.True(argmin.IsScalar()) 1052 assert.Equal(2, argmin.ScalarValue(), "+Inf test") 1053 1054 // test with Mask and -Inf 1055 T = New(WithShape(4), WithBacking([]float32{1, -9, math32.Inf(-1), 4}, []bool{false, true, true, false})) 1056 if argmin, err = T.Argmin(AllAxes); err != nil { 1057 t.Errorf("Failed test with NaN: %v", err) 1058 } 1059 assert.True(argmin.IsScalar()) 1060 assert.Equal(0, argmin.ScalarValue(), "Masked -Inf test") 1061 1062 // with different engine 1063 T = basicDenseF32.Clone().(*Dense) 1064 WithEngine(dummyEngine2{})(T) 1065 for i := 0; i < T.Dims(); i++ { 1066 if argmin, err = T.Argmin(i); err != nil { 1067 t.Error(err) 1068 continue 1069 } 1070 1071 assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape) 1072 assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i) 1073 } 1074 1075 // idiotsville 1076 _, err = T.Argmin(10000) 1077 assert.NotNil(err) 1078 1079 } 1080 func TestDense_Argmax_F64(t *testing.T) { 1081 assert := assert.New(t) 1082 var T, argmax *Dense 1083 var err error 1084 T = basicDenseF64.Clone().(*Dense) 1085 for i := 0; i < T.Dims(); i++ { 1086 if argmax, err = T.Argmax(i); err != nil { 1087 t.Error(err) 1088 continue 1089 } 1090 1091 assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape) 1092 assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i) 1093 } 1094 // test all axes 1095 if argmax, err = T.Argmax(AllAxes); err != nil { 1096 t.Error(err) 1097 return 1098 } 1099 assert.True(argmax.IsScalar()) 1100 assert.Equal(7, argmax.ScalarValue()) 1101 1102 // test with NaN 1103 T = New(WithShape(4), WithBacking([]float64{1, 2, math.NaN(), 4})) 1104 if argmax, err = T.Argmax(AllAxes); err != nil { 1105 t.Errorf("Failed test with NaN: %v", err) 1106 } 1107 assert.True(argmax.IsScalar()) 1108 assert.Equal(2, argmax.ScalarValue(), "NaN test") 1109 1110 // test with Mask and Nan 1111 T = New(WithShape(4), WithBacking([]float64{1, 9, math.NaN(), 4}, []bool{false, true, true, false})) 1112 if argmax, err = T.Argmax(AllAxes); err != nil { 1113 t.Errorf("Failed test with NaN: %v", err) 1114 } 1115 assert.True(argmax.IsScalar()) 1116 assert.Equal(3, argmax.ScalarValue(), "Masked NaN test") 1117 1118 // test with +Inf 1119 T = New(WithShape(4), WithBacking([]float64{1, 2, math.Inf(1), 4})) 1120 if argmax, err = T.Argmax(AllAxes); err != nil { 1121 t.Errorf("Failed test with +Inf: %v", err) 1122 } 1123 assert.True(argmax.IsScalar()) 1124 assert.Equal(2, argmax.ScalarValue(), "+Inf test") 1125 1126 // test with Mask and +Inf 1127 T = New(WithShape(4), WithBacking([]float64{1, 9, math.Inf(1), 4}, []bool{false, true, true, false})) 1128 if argmax, err = T.Argmax(AllAxes); err != nil { 1129 t.Errorf("Failed test with NaN: %v", err) 1130 } 1131 assert.True(argmax.IsScalar()) 1132 assert.Equal(3, argmax.ScalarValue(), "Masked NaN test") 1133 1134 // test with -Inf 1135 T = New(WithShape(4), WithBacking([]float64{1, 2, math.Inf(-1), 4})) 1136 if argmax, err = T.Argmax(AllAxes); err != nil { 1137 t.Errorf("Failed test with -Inf: %v", err) 1138 } 1139 assert.True(argmax.IsScalar()) 1140 assert.Equal(3, argmax.ScalarValue(), "+Inf test") 1141 1142 // test with Mask and -Inf 1143 T = New(WithShape(4), WithBacking([]float64{1, 9, math.Inf(-1), 4}, []bool{false, true, true, false})) 1144 if argmax, err = T.Argmax(AllAxes); err != nil { 1145 t.Errorf("Failed test with NaN: %v", err) 1146 } 1147 assert.True(argmax.IsScalar()) 1148 assert.Equal(3, argmax.ScalarValue(), "Masked -Inf test") 1149 1150 // with different engine 1151 T = basicDenseF64.Clone().(*Dense) 1152 WithEngine(dummyEngine2{})(T) 1153 for i := 0; i < T.Dims(); i++ { 1154 if argmax, err = T.Argmax(i); err != nil { 1155 t.Error(err) 1156 continue 1157 } 1158 1159 assert.True(argmaxCorrect[i].shape.Eq(argmax.Shape()), "Argmax(%d) error. Want shape %v. Got %v", i, argmaxCorrect[i].shape) 1160 assert.Equal(argmaxCorrect[i].data, argmax.Data(), "Argmax(%d) error. ", i) 1161 } 1162 1163 // idiotsville 1164 _, err = T.Argmax(10000) 1165 assert.NotNil(err) 1166 1167 } 1168 func TestDense_Argmin_F64(t *testing.T) { 1169 assert := assert.New(t) 1170 var T, argmin *Dense 1171 var err error 1172 T = basicDenseF64.Clone().(*Dense) 1173 for i := 0; i < T.Dims(); i++ { 1174 if argmin, err = T.Argmin(i); err != nil { 1175 t.Error(err) 1176 continue 1177 } 1178 1179 assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape) 1180 assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i) 1181 } 1182 // test all axes 1183 if argmin, err = T.Argmin(AllAxes); err != nil { 1184 t.Error(err) 1185 return 1186 } 1187 assert.True(argmin.IsScalar()) 1188 assert.Equal(11, argmin.ScalarValue()) 1189 1190 // test with NaN 1191 T = New(WithShape(4), WithBacking([]float64{1, 2, math.NaN(), 4})) 1192 if argmin, err = T.Argmin(AllAxes); err != nil { 1193 t.Errorf("Failed test with NaN: %v", err) 1194 } 1195 assert.True(argmin.IsScalar()) 1196 assert.Equal(2, argmin.ScalarValue(), "NaN test") 1197 1198 // test with Mask and Nan 1199 T = New(WithShape(4), WithBacking([]float64{1, -9, math.NaN(), 4}, []bool{false, true, true, false})) 1200 if argmin, err = T.Argmin(AllAxes); err != nil { 1201 t.Errorf("Failed test with NaN: %v", err) 1202 } 1203 assert.True(argmin.IsScalar()) 1204 assert.Equal(0, argmin.ScalarValue(), "Masked NaN test") 1205 1206 // test with +Inf 1207 T = New(WithShape(4), WithBacking([]float64{1, 2, math.Inf(1), 4})) 1208 if argmin, err = T.Argmin(AllAxes); err != nil { 1209 t.Errorf("Failed test with +Inf: %v", err) 1210 } 1211 assert.True(argmin.IsScalar()) 1212 assert.Equal(0, argmin.ScalarValue(), "+Inf test") 1213 1214 // test with Mask and +Inf 1215 T = New(WithShape(4), WithBacking([]float64{1, -9, math.Inf(1), 4}, []bool{false, true, true, false})) 1216 if argmin, err = T.Argmin(AllAxes); err != nil { 1217 t.Errorf("Failed test with NaN: %v", err) 1218 } 1219 assert.True(argmin.IsScalar()) 1220 assert.Equal(0, argmin.ScalarValue(), "Masked NaN test") 1221 1222 // test with -Inf 1223 T = New(WithShape(4), WithBacking([]float64{1, 2, math.Inf(-1), 4})) 1224 if argmin, err = T.Argmin(AllAxes); err != nil { 1225 t.Errorf("Failed test with -Inf: %v", err) 1226 } 1227 assert.True(argmin.IsScalar()) 1228 assert.Equal(2, argmin.ScalarValue(), "+Inf test") 1229 1230 // test with Mask and -Inf 1231 T = New(WithShape(4), WithBacking([]float64{1, -9, math.Inf(-1), 4}, []bool{false, true, true, false})) 1232 if argmin, err = T.Argmin(AllAxes); err != nil { 1233 t.Errorf("Failed test with NaN: %v", err) 1234 } 1235 assert.True(argmin.IsScalar()) 1236 assert.Equal(0, argmin.ScalarValue(), "Masked -Inf test") 1237 1238 // with different engine 1239 T = basicDenseF64.Clone().(*Dense) 1240 WithEngine(dummyEngine2{})(T) 1241 for i := 0; i < T.Dims(); i++ { 1242 if argmin, err = T.Argmin(i); err != nil { 1243 t.Error(err) 1244 continue 1245 } 1246 1247 assert.True(argminCorrect[i].shape.Eq(argmin.Shape()), "Argmin(%d) error. Want shape %v. Got %v", i, argminCorrect[i].shape) 1248 assert.Equal(argminCorrect[i].data, argmin.Data(), "Argmin(%d) error. ", i) 1249 } 1250 1251 // idiotsville 1252 _, err = T.Argmin(10000) 1253 assert.NotNil(err) 1254 1255 }