gorgonia.org/tensor@v0.9.24/type_test.go (about) 1 package tensor 2 3 import ( 4 "reflect" 5 "testing" 6 ) 7 8 type Float16 uint16 9 10 func TestRegisterType(t *testing.T) { 11 dt := Dtype{reflect.TypeOf(Float16(0))} 12 RegisterFloat(dt) 13 14 if err := typeclassCheck(dt, floatTypes); err != nil { 15 t.Errorf("Expected %v to be in floatTypes: %v", dt, err) 16 } 17 if err := typeclassCheck(dt, numberTypes); err != nil { 18 t.Errorf("Expected %v to be in numberTypes: %v", dt, err) 19 } 20 if err := typeclassCheck(dt, ordTypes); err != nil { 21 t.Errorf("Expected %v to be in ordTypes: %v", dt, err) 22 } 23 if err := typeclassCheck(dt, eqTypes); err != nil { 24 t.Errorf("Expected %v to be in eqTypes: %v", dt, err) 25 } 26 27 } 28 29 func TestDtypeConversions(t *testing.T) { 30 for k, v := range reverseNumpyDtypes { 31 if npdt, err := v.numpyDtype(); npdt != k { 32 t.Errorf("Expected %v to return numpy dtype of %q. Got %q instead", v, k, npdt) 33 } else if err != nil { 34 t.Errorf("Error: %v", err) 35 } 36 } 37 dt := Dtype{reflect.TypeOf(Float16(0))} 38 if _, err := dt.numpyDtype(); err == nil { 39 t.Errorf("Expected an error when passing in type unknown to np") 40 } 41 42 for k, v := range numpyDtypes { 43 if dt, err := fromNumpyDtype(v); dt != k { 44 // special cases 45 if Int.Size() == 4 && v == "i4" && dt == Int { 46 continue 47 } 48 if Int.Size() == 8 && v == "i8" && dt == Int { 49 continue 50 } 51 52 if Uint.Size() == 4 && v == "u4" && dt == Uint { 53 continue 54 } 55 if Uint.Size() == 8 && v == "u8" && dt == Uint { 56 continue 57 } 58 t.Errorf("Expected %q to return %v. Got %v instead", v, k, dt) 59 } else if err != nil { 60 t.Errorf("Error: %v", err) 61 } 62 } 63 if _, err := fromNumpyDtype("EDIUH"); err == nil { 64 t.Error("Expected error when nonsense is passed into fromNumpyDtype") 65 } 66 }