gorgonia.org/gorgonia@v0.9.17/op_yolo_test.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"fmt"
     5  	"os"
     6  	"testing"
     7  
     8  	"gorgonia.org/tensor"
     9  )
    10  
    11  func TestYolo(t *testing.T) {
    12  
    13  	inputSize := 416
    14  	numClasses := 80
    15  	testAnchors := [][]float32{
    16  		[]float32{10, 13, 16, 30, 33, 23},
    17  		[]float32{30, 61, 62, 45, 59, 119},
    18  		[]float32{116, 90, 156, 198, 373, 326},
    19  	}
    20  
    21  	numpyInputs := []string{
    22  		"./examples/tiny-yolo-v3-coco/data/test_yolo_op/1input.[(10, 13), (16, 30), (33, 23)].npy",
    23  		"./examples/tiny-yolo-v3-coco/data/test_yolo_op/1input.[(30, 61), (62, 45), (59, 119)].npy",
    24  		"./examples/tiny-yolo-v3-coco/data/test_yolo_op/1input.[(116, 90), (156, 198), (373, 326)].npy",
    25  	}
    26  
    27  	numpyExpectedOutputs := []string{
    28  		"./examples/tiny-yolo-v3-coco/data/test_yolo_op/1output.[(10, 13), (16, 30), (33, 23)].npy",
    29  		"./examples/tiny-yolo-v3-coco/data/test_yolo_op/1output.[(30, 61), (62, 45), (59, 119)].npy",
    30  		"./examples/tiny-yolo-v3-coco/data/test_yolo_op/1output.[(116, 90), (156, 198), (373, 326)].npy",
    31  	}
    32  
    33  	for i := range testAnchors {
    34  		// Read input values from numpy format
    35  		input := tensor.New(tensor.Of(tensor.Float32))
    36  		r, err := os.Open(numpyInputs[i])
    37  		if err != nil {
    38  			t.Error(err)
    39  			return
    40  		}
    41  		err = input.ReadNpy(r)
    42  		if err != nil {
    43  			t.Error(err)
    44  			return
    45  		}
    46  
    47  		// Read expected values from numpy format
    48  		expected := tensor.New(tensor.Of(tensor.Float32))
    49  		r, err = os.Open(numpyExpectedOutputs[i])
    50  		if err != nil {
    51  			t.Error(err)
    52  			return
    53  		}
    54  		err = expected.ReadNpy(r)
    55  		if err != nil {
    56  			t.Error(err)
    57  			return
    58  		}
    59  
    60  		// Load graph
    61  		g := NewGraph()
    62  		inputTensor := NewTensor(g, tensor.Float32, 4, WithShape(input.Shape()...), WithName("yolo"))
    63  		// Prepare YOLOv3 node
    64  		outNode, err := YOLOv3(inputTensor, testAnchors[i], []int{0, 1, 2}, inputSize, numClasses, 0.7)
    65  		if err != nil {
    66  			t.Error(err)
    67  			return
    68  		}
    69  		// Run operation
    70  		vm := NewTapeMachine(g)
    71  		if err := Let(inputTensor, input); err != nil {
    72  			t.Error(err)
    73  			return
    74  		}
    75  		vm.RunAll()
    76  		vm.Close()
    77  
    78  		if !floatsEqual32(outNode.Value().Data().([]float32), expected.Data().([]float32)) {
    79  			t.Error(fmt.Sprintf("Test Anchor %d: %v\nGot: \n%v\nExpected: \n%v", i, testAnchors[i], outNode.Value(), expected))
    80  		}
    81  	}
    82  }