// Source: gorgonia.org/tensor@v0.9.24/internal/storage/header_test.go

     1  package storage
     2  
import (
	"reflect"
	"runtime"
	"testing"

	"github.com/stretchr/testify/assert"
)
     8  
     9  func TestCopy(t *testing.T) {
    10  	// A longer than B
    11  	a := headerFromSlice([]int{0, 1, 2, 3, 4})
    12  	b := headerFromSlice([]int{10, 11})
    13  	copied := Copy(reflect.TypeOf(1), &a, &b)
    14  
    15  	assert.Equal(t, 2, copied)
    16  	assert.Equal(t, []int{10, 11, 2, 3, 4}, a.Ints())
    17  
    18  	// B longer than A
    19  	a = headerFromSlice([]int{10, 11})
    20  	b = headerFromSlice([]int{0, 1, 2, 3, 4})
    21  	copied = Copy(reflect.TypeOf(1), &a, &b)
    22  
    23  	assert.Equal(t, 2, copied)
    24  	assert.Equal(t, []int{0, 1}, a.Ints())
    25  
    26  	// A is empty
    27  	a = headerFromSlice([]int{})
    28  	b = headerFromSlice([]int{0, 1, 2, 3, 4})
    29  	copied = Copy(reflect.TypeOf(1), &a, &b)
    30  
    31  	assert.Equal(t, 0, copied)
    32  
    33  	// B is empty
    34  	a = headerFromSlice([]int{0, 1, 2, 3, 4})
    35  	b = headerFromSlice([]int{})
    36  	copied = Copy(reflect.TypeOf(1), &a, &b)
    37  
    38  	assert.Equal(t, 0, copied)
    39  	assert.Equal(t, []int{0, 1, 2, 3, 4}, a.Ints())
    40  }
    41  
    42  func TestFill(t *testing.T) {
    43  	// A longer than B
    44  	a := headerFromSlice([]int{0, 1, 2, 3, 4})
    45  	b := headerFromSlice([]int{10, 11})
    46  	copied := Fill(reflect.TypeOf(1), &a, &b)
    47  
    48  	assert.Equal(t, 5, copied)
    49  	assert.Equal(t, []int{10, 11, 10, 11, 10}, a.Ints())
    50  
    51  	// B longer than A
    52  	a = headerFromSlice([]int{10, 11})
    53  	b = headerFromSlice([]int{0, 1, 2, 3, 4})
    54  	copied = Fill(reflect.TypeOf(1), &a, &b)
    55  
    56  	assert.Equal(t, 2, copied)
    57  	assert.Equal(t, []int{0, 1}, a.Ints())
    58  }
    59  
    60  func headerFromSlice(x interface{}) Header {
    61  	xT := reflect.TypeOf(x)
    62  	if xT.Kind() != reflect.Slice {
    63  		panic("Expected a slice")
    64  	}
    65  	xV := reflect.ValueOf(x)
    66  	size := uintptr(xV.Len()) * xT.Elem().Size()
    67  	return Header{
    68  		Raw: FromMemory(xV.Pointer(), size),
    69  	}
    70  }