github.com/sonda2208/golearn@v0.0.0-20230401025148-848c5a699337/knn/knn_bench_test.go (about)

     1  package knn
     2  
     3  import (
     4  	"fmt"
     5  	"testing"
     6  
     7  	"github.com/sjwhitworth/golearn/base"
     8  	"github.com/sjwhitworth/golearn/evaluation"
     9  )
    10  
    11  func readMnist() (*base.DenseInstances, *base.DenseInstances) {
    12  	// Create the class Attribute
    13  	classAttrs := make(map[int]base.Attribute)
    14  	classAttrs[0] = base.NewCategoricalAttribute()
    15  	classAttrs[0].SetName("label")
    16  	// Setup the class Attribute to be in its own group
    17  	classAttrGroups := make(map[string]string)
    18  	classAttrGroups["label"] = "ClassGroup"
    19  	// The rest can go in a default group
    20  	attrGroups := make(map[string]string)
    21  
    22  	inst1, err := base.ParseCSVToInstancesWithAttributeGroups(
    23  		"../examples/datasets/mnist_train.csv",
    24  		attrGroups,
    25  		classAttrGroups,
    26  		classAttrs,
    27  		true,
    28  	)
    29  	if err != nil {
    30  		panic(err)
    31  	}
    32  	inst2, err := base.ParseCSVToTemplatedInstances(
    33  		"../examples/datasets/mnist_test.csv",
    34  		true,
    35  		inst1,
    36  	)
    37  	if err != nil {
    38  		panic(err)
    39  	}
    40  	return inst1, inst2
    41  }
    42  
    43  func BenchmarkKNNWithOpts(b *testing.B) {
    44  	// Load
    45  	train, test := readMnist()
    46  	cls := NewKnnClassifier("euclidean", "linear", 1)
    47  	cls.AllowOptimisations = true
    48  	cls.Fit(train)
    49  	predictions, err := cls.Predict(test)
    50  	if err != nil {
    51  		b.Error(err)
    52  	}
    53  	c, err := evaluation.GetConfusionMatrix(test, predictions)
    54  	if err != nil {
    55  		panic(err)
    56  	}
    57  	fmt.Println(evaluation.GetSummary(c))
    58  	fmt.Println(evaluation.GetAccuracy(c))
    59  }
    60  
    61  func BenchmarkKNNWithNoOpts(b *testing.B) {
    62  	// Load
    63  	train, test := readMnist()
    64  	cls := NewKnnClassifier("euclidean", "linear", 1)
    65  	cls.AllowOptimisations = false
    66  	cls.Fit(train)
    67  	predictions, err := cls.Predict(test)
    68  	if err != nil {
    69  		b.Error(err)
    70  	}
    71  	c, err := evaluation.GetConfusionMatrix(test, predictions)
    72  	if err != nil {
    73  		panic(err)
    74  	}
    75  	fmt.Println(evaluation.GetSummary(c))
    76  	fmt.Println(evaluation.GetAccuracy(c))
    77  }