github.com/lomereiter/go-set-similarity-search@v0.0.0-20220827150533-7db29b22ebbc/all_pairs_containment_benchmark_test.go (about)

     1  package SetSimilaritySearch
     2  
     3  import (
     4  	"bufio"
     5  	"compress/gzip"
     6  	"encoding/csv"
     7  	"fmt"
     8  	"log"
     9  	"os"
    10  	"strconv"
    11  	"strings"
    12  	"testing"
    13  	"time"
    14  )
    15  
    16  var (
    17  	// Download from https://github.com/ekzhu/set-similarity-search-benchmarks
    18  	allPairsContainmentBenchmarkFilename  = "canada_us_uk_opendata.inp.gz"
    19  	allPairsContainmentBenchmarkResult    = "canada_us_uk_opendata_all_pairs_containment.csv"
    20  	allPairsContainmentBenchmarkThreshold = 0.9
    21  	allPairsContainmentBenchmarkMinSize   = 10
    22  )
    23  
    24  // Read set similarity search benchmark files from
    25  // https://github.com/ekzhu/set-similarity-search-benchmarks
    26  func readGzippedTransformedSets(filename string,
    27  	firstLineInfo bool, minSize int) (sets [][]int) {
    28  	file, err := os.Open(filename)
    29  	if err != nil {
    30  		panic(err)
    31  	}
    32  	defer file.Close()
    33  	gz, err := gzip.NewReader(file)
    34  	if err != nil {
    35  		panic(err)
    36  	}
    37  	defer gz.Close()
    38  	sets = make([][]int, 0)
    39  	scanner := bufio.NewScanner(gz)
    40  	scanner.Buffer(nil, 1024*1024*1024*4)
    41  	for scanner.Scan() {
    42  		line := strings.Trim(scanner.Text(), "\n")
    43  		if firstLineInfo && len(sets) == 0 {
    44  			// Initialize the sets using the info given by the first line
    45  			count, err := strconv.Atoi(strings.Split(line, " ")[0])
    46  			if err != nil {
    47  				panic(err)
    48  			}
    49  			sets = make([][]int, 0, count)
    50  			firstLineInfo = false
    51  			continue
    52  		}
    53  		raw := strings.Split(strings.Split(line, "\t")[1], ",")
    54  		if len(raw) < minSize {
    55  			continue
    56  		}
    57  		set := make([]int, len(raw))
    58  		for i := range set {
    59  			set[i], err = strconv.Atoi(raw[i])
    60  			if err != nil {
    61  				panic(err)
    62  			}
    63  		}
    64  		sets = append(sets, set)
    65  		if len(sets)%100 == 0 {
    66  			fmt.Printf("\rRead %d sets so far", len(sets))
    67  		}
    68  	}
    69  	fmt.Println()
    70  	if err := scanner.Err(); err != nil {
    71  		panic(err)
    72  	}
    73  	return sets
    74  }
    75  
    76  func BenchmarkOpenDataAllPairContainment(b *testing.B) {
    77  	log.Printf("Reading transformed sets from %s",
    78  		allPairsContainmentBenchmarkFilename)
    79  	start := time.Now()
    80  	sets := readGzippedTransformedSets(allPairsContainmentBenchmarkFilename,
    81  		/*firstLineInfo=*/ true,
    82  		allPairsContainmentBenchmarkMinSize)
    83  	log.Printf("Finished reading %d transformed sets in %s", len(sets),
    84  		time.Now().Sub(start).String())
    85  	log.Printf("Building search index")
    86  	start = time.Now()
    87  	searchIndex, err := NewSearchIndex(sets, "containment",
    88  		allPairsContainmentBenchmarkThreshold)
    89  	if err != nil {
    90  		b.Fatal(err)
    91  	}
    92  	log.Printf("Finished building search index in %s",
    93  		time.Now().Sub(start).String())
    94  	out, err := os.Create(allPairsContainmentBenchmarkResult)
    95  	if err != nil {
    96  		b.Fatal(err)
    97  	}
    98  	defer out.Close()
    99  	w := csv.NewWriter(out)
   100  	log.Printf("Begin querying")
   101  	start = time.Now()
   102  	var count int
   103  	for i, set := range sets {
   104  		results := searchIndex.Query(set)
   105  		for _, result := range results {
   106  			if result.X == i {
   107  				continue
   108  			}
   109  			w.Write([]string{
   110  				strconv.Itoa(i),
   111  				strconv.Itoa(result.X),
   112  				strconv.FormatFloat(result.Similarity, 'f', 4, 64),
   113  			})
   114  		}
   115  		count++
   116  		if count%100 == 0 {
   117  			fmt.Printf("\rQueried %d sets so far", count)
   118  		}
   119  	}
   120  	fmt.Println()
   121  	log.Printf("Finished querying in %s", time.Now().Sub(start).String())
   122  	w.Flush()
   123  	if err := w.Error(); err != nil {
   124  		b.Fatal(err)
   125  	}
   126  	log.Printf("Results written to %s", allPairsContainmentBenchmarkResult)
   127  }