github.com/sonda2208/golearn@v0.0.0-20230401025148-848c5a699337/knn/knn_bench_test.go (about) 1 package knn 2 3 import ( 4 "fmt" 5 "testing" 6 7 "github.com/sjwhitworth/golearn/base" 8 "github.com/sjwhitworth/golearn/evaluation" 9 ) 10 11 func readMnist() (*base.DenseInstances, *base.DenseInstances) { 12 // Create the class Attribute 13 classAttrs := make(map[int]base.Attribute) 14 classAttrs[0] = base.NewCategoricalAttribute() 15 classAttrs[0].SetName("label") 16 // Setup the class Attribute to be in its own group 17 classAttrGroups := make(map[string]string) 18 classAttrGroups["label"] = "ClassGroup" 19 // The rest can go in a default group 20 attrGroups := make(map[string]string) 21 22 inst1, err := base.ParseCSVToInstancesWithAttributeGroups( 23 "../examples/datasets/mnist_train.csv", 24 attrGroups, 25 classAttrGroups, 26 classAttrs, 27 true, 28 ) 29 if err != nil { 30 panic(err) 31 } 32 inst2, err := base.ParseCSVToTemplatedInstances( 33 "../examples/datasets/mnist_test.csv", 34 true, 35 inst1, 36 ) 37 if err != nil { 38 panic(err) 39 } 40 return inst1, inst2 41 } 42 43 func BenchmarkKNNWithOpts(b *testing.B) { 44 // Load 45 train, test := readMnist() 46 cls := NewKnnClassifier("euclidean", "linear", 1) 47 cls.AllowOptimisations = true 48 cls.Fit(train) 49 predictions, err := cls.Predict(test) 50 if err != nil { 51 b.Error(err) 52 } 53 c, err := evaluation.GetConfusionMatrix(test, predictions) 54 if err != nil { 55 panic(err) 56 } 57 fmt.Println(evaluation.GetSummary(c)) 58 fmt.Println(evaluation.GetAccuracy(c)) 59 } 60 61 func BenchmarkKNNWithNoOpts(b *testing.B) { 62 // Load 63 train, test := readMnist() 64 cls := NewKnnClassifier("euclidean", "linear", 1) 65 cls.AllowOptimisations = false 66 cls.Fit(train) 67 predictions, err := cls.Predict(test) 68 if err != nil { 69 b.Error(err) 70 } 71 c, err := evaluation.GetConfusionMatrix(test, predictions) 72 if err != nil { 73 panic(err) 74 } 75 fmt.Println(evaluation.GetSummary(c)) 76 fmt.Println(evaluation.GetAccuracy(c)) 77 }