gorgonia.org/gorgonia@v0.9.17/op_yolo_test.go (about) 1 package gorgonia 2 3 import ( 4 "fmt" 5 "os" 6 "testing" 7 8 "gorgonia.org/tensor" 9 ) 10 11 func TestYolo(t *testing.T) { 12 13 inputSize := 416 14 numClasses := 80 15 testAnchors := [][]float32{ 16 []float32{10, 13, 16, 30, 33, 23}, 17 []float32{30, 61, 62, 45, 59, 119}, 18 []float32{116, 90, 156, 198, 373, 326}, 19 } 20 21 numpyInputs := []string{ 22 "./examples/tiny-yolo-v3-coco/data/test_yolo_op/1input.[(10, 13), (16, 30), (33, 23)].npy", 23 "./examples/tiny-yolo-v3-coco/data/test_yolo_op/1input.[(30, 61), (62, 45), (59, 119)].npy", 24 "./examples/tiny-yolo-v3-coco/data/test_yolo_op/1input.[(116, 90), (156, 198), (373, 326)].npy", 25 } 26 27 numpyExpectedOutputs := []string{ 28 "./examples/tiny-yolo-v3-coco/data/test_yolo_op/1output.[(10, 13), (16, 30), (33, 23)].npy", 29 "./examples/tiny-yolo-v3-coco/data/test_yolo_op/1output.[(30, 61), (62, 45), (59, 119)].npy", 30 "./examples/tiny-yolo-v3-coco/data/test_yolo_op/1output.[(116, 90), (156, 198), (373, 326)].npy", 31 } 32 33 for i := range testAnchors { 34 // Read input values from numpy format 35 input := tensor.New(tensor.Of(tensor.Float32)) 36 r, err := os.Open(numpyInputs[i]) 37 if err != nil { 38 t.Error(err) 39 return 40 } 41 err = input.ReadNpy(r) 42 if err != nil { 43 t.Error(err) 44 return 45 } 46 47 // Read expected values from numpy format 48 expected := tensor.New(tensor.Of(tensor.Float32)) 49 r, err = os.Open(numpyExpectedOutputs[i]) 50 if err != nil { 51 t.Error(err) 52 return 53 } 54 err = expected.ReadNpy(r) 55 if err != nil { 56 t.Error(err) 57 return 58 } 59 60 // Load graph 61 g := NewGraph() 62 inputTensor := NewTensor(g, tensor.Float32, 4, WithShape(input.Shape()...), WithName("yolo")) 63 // Prepare YOLOv3 node 64 outNode, err := YOLOv3(inputTensor, testAnchors[i], []int{0, 1, 2}, inputSize, numClasses, 0.7) 65 if err != nil { 66 t.Error(err) 67 return 68 } 69 // Run operation 70 vm := NewTapeMachine(g) 71 if err := Let(inputTensor, input); err != nil { 72 t.Error(err) 73 return 74 } 75 vm.RunAll() 76 vm.Close() 77 78 if !floatsEqual32(outNode.Value().Data().([]float32), expected.Data().([]float32)) { 79 t.Error(fmt.Sprintf("Test Anchor %d: %v\nGot: \n%v\nExpected: \n%v", i, testAnchors[i], outNode.Value(), expected)) 80 } 81 } 82 }