gorgonia.org/tensor@v0.9.24/type_test.go (about)

     1  package tensor
     2  
     3  import (
     4  	"reflect"
     5  	"testing"
     6  )
     7  
     8  type Float16 uint16
     9  
    10  func TestRegisterType(t *testing.T) {
    11  	dt := Dtype{reflect.TypeOf(Float16(0))}
    12  	RegisterFloat(dt)
    13  
    14  	if err := typeclassCheck(dt, floatTypes); err != nil {
    15  		t.Errorf("Expected %v to be in floatTypes: %v", dt, err)
    16  	}
    17  	if err := typeclassCheck(dt, numberTypes); err != nil {
    18  		t.Errorf("Expected %v to be in numberTypes: %v", dt, err)
    19  	}
    20  	if err := typeclassCheck(dt, ordTypes); err != nil {
    21  		t.Errorf("Expected %v to be in ordTypes: %v", dt, err)
    22  	}
    23  	if err := typeclassCheck(dt, eqTypes); err != nil {
    24  		t.Errorf("Expected %v to be in eqTypes: %v", dt, err)
    25  	}
    26  
    27  }
    28  
    29  func TestDtypeConversions(t *testing.T) {
    30  	for k, v := range reverseNumpyDtypes {
    31  		if npdt, err := v.numpyDtype(); npdt != k {
    32  			t.Errorf("Expected %v to return numpy dtype of %q. Got %q instead", v, k, npdt)
    33  		} else if err != nil {
    34  			t.Errorf("Error: %v", err)
    35  		}
    36  	}
    37  	dt := Dtype{reflect.TypeOf(Float16(0))}
    38  	if _, err := dt.numpyDtype(); err == nil {
    39  		t.Errorf("Expected an error when passing in type unknown to np")
    40  	}
    41  
    42  	for k, v := range numpyDtypes {
    43  		if dt, err := fromNumpyDtype(v); dt != k {
    44  			// special cases
    45  			if Int.Size() == 4 && v == "i4" && dt == Int {
    46  				continue
    47  			}
    48  			if Int.Size() == 8 && v == "i8" && dt == Int {
    49  				continue
    50  			}
    51  
    52  			if Uint.Size() == 4 && v == "u4" && dt == Uint {
    53  				continue
    54  			}
    55  			if Uint.Size() == 8 && v == "u8" && dt == Uint {
    56  				continue
    57  			}
    58  			t.Errorf("Expected %q to return %v. Got %v instead", v, k, dt)
    59  		} else if err != nil {
    60  			t.Errorf("Error: %v", err)
    61  		}
    62  	}
    63  	if _, err := fromNumpyDtype("EDIUH"); err == nil {
    64  		t.Error("Expected error when nonsense is passed into fromNumpyDtype")
    65  	}
    66  }