github.com/Sterrenhemel/go-vector-similarity@v0.0.0-20221003101551-ef67cc113420/all_pairs_containment_benchmark_test.go (about)

     1  package SetSimilaritySearch
     2  
     3  import (
     4  	"bufio"
     5  	"compress/gzip"
     6  	"encoding/csv"
     7  	"fmt"
     8  	"log"
     9  	"os"
    10  	"strconv"
    11  	"strings"
    12  	"testing"
    13  	"time"
    14  )
    15  
    16  var (
    17  	// Download from https://github.com/ekzhu/set-similarity-search-benchmarks
    18  	allPairsContainmentBenchmarkFilename  = "canada_us_uk_opendata.inp.gz"
    19  	allPairsContainmentBenchmarkResult    = "canada_us_uk_opendata_all_pairs_containment.csv"
    20  	allPairsContainmentBenchmarkThreshold = 0.9
    21  	allPairsContainmentBenchmarkMinSize   = 10
    22  )
    23  
    24  // Read set similarity search benchmark files from
    25  // https://github.com/ekzhu/set-similarity-search-benchmarks
    26  func readGzippedTransformedSets(filename string,
    27  	firstLineInfo bool, minSize int) (sets [][]float32) {
    28  	file, err := os.Open(filename)
    29  	if err != nil {
    30  		panic(err)
    31  	}
    32  	defer file.Close()
    33  	gz, err := gzip.NewReader(file)
    34  	if err != nil {
    35  		panic(err)
    36  	}
    37  	defer gz.Close()
    38  	sets = make([][]float32, 0)
    39  	scanner := bufio.NewScanner(gz)
    40  	scanner.Buffer(nil, 1024*1024*1024*4)
    41  	for scanner.Scan() {
    42  		line := strings.Trim(scanner.Text(), "\n")
    43  		if firstLineInfo && len(sets) == 0 {
    44  			// Initialize the sets using the info given by the first line
    45  			count, err := strconv.Atoi(strings.Split(line, " ")[0])
    46  			if err != nil {
    47  				panic(err)
    48  			}
    49  			sets = make([][]float32, 0, count)
    50  			firstLineInfo = false
    51  			continue
    52  		}
    53  		raw := strings.Split(strings.Split(line, "\t")[1], ",")
    54  		if len(raw) < minSize {
    55  			continue
    56  		}
    57  		set := make([]float32, len(raw))
    58  		for i := range set {
    59  			f, err := strconv.ParseFloat(raw[i], 32)
    60  			if err != nil {
    61  				panic(err)
    62  			}
    63  			set[i] = float32(f)
    64  		}
    65  		sets = append(sets, set)
    66  		if len(sets)%100 == 0 {
    67  			fmt.Printf("\rRead %d sets so far", len(sets))
    68  		}
    69  	}
    70  	fmt.Println()
    71  	if err := scanner.Err(); err != nil {
    72  		panic(err)
    73  	}
    74  	return sets
    75  }
    76  
    77  func BenchmarkOpenDataAllPairContainment(b *testing.B) {
    78  	log.Printf("Reading transformed sets from %s",
    79  		allPairsContainmentBenchmarkFilename)
    80  	start := time.Now()
    81  	sets := readGzippedTransformedSets(allPairsContainmentBenchmarkFilename,
    82  		/*firstLineInfo=*/ true,
    83  		allPairsContainmentBenchmarkMinSize)
    84  	log.Printf("Finished reading %d transformed sets in %s", len(sets),
    85  		time.Now().Sub(start).String())
    86  	log.Printf("Building search index")
    87  	start = time.Now()
    88  	searchIndex, err := NewSearchIndex(sets, "containment",
    89  		allPairsContainmentBenchmarkThreshold)
    90  	if err != nil {
    91  		b.Fatal(err)
    92  	}
    93  	log.Printf("Finished building search index in %s",
    94  		time.Now().Sub(start).String())
    95  	out, err := os.Create(allPairsContainmentBenchmarkResult)
    96  	if err != nil {
    97  		b.Fatal(err)
    98  	}
    99  	defer out.Close()
   100  	w := csv.NewWriter(out)
   101  	log.Printf("Begin querying")
   102  	start = time.Now()
   103  	var count int
   104  	for i, set := range sets {
   105  		results := searchIndex.Query(set)
   106  		for _, result := range results {
   107  			if result.X == i {
   108  				continue
   109  			}
   110  			w.Write([]string{
   111  				strconv.Itoa(i),
   112  				strconv.Itoa(result.X),
   113  				strconv.FormatFloat(result.Similarity, 'f', 4, 64),
   114  			})
   115  		}
   116  		count++
   117  		if count%100 == 0 {
   118  			fmt.Printf("\rQueried %d sets so far", count)
   119  		}
   120  	}
   121  	fmt.Println()
   122  	log.Printf("Finished querying in %s", time.Now().Sub(start).String())
   123  	w.Flush()
   124  	if err := w.Error(); err != nil {
   125  		b.Fatal(err)
   126  	}
   127  	log.Printf("Results written to %s", allPairsContainmentBenchmarkResult)
   128  }