github.com/milvus-io/milvus-sdk-go/v2@v2.4.1/examples/multivectors/main.go (about)

     1  package main
     2  
     3  import (
     4  	"context"
     5  	"log"
     6  	"math/rand"
     7  	"time"
     8  
     9  	"github.com/milvus-io/milvus-sdk-go/v2/client"
    10  	"github.com/milvus-io/milvus-sdk-go/v2/entity"
    11  )
    12  
    13  const (
    14  	milvusAddr     = `localhost:19530`
    15  	nEntities, dim = 10000, 128
    16  	collectionName = "hello_multi_vectors"
    17  
    18  	idCol, keyCol, embeddingCol1, embeddingCol2 = "ID", "key", "vector1", "vector2"
    19  	topK                                        = 3
    20  )
    21  
    22  func main() {
    23  	ctx := context.Background()
    24  
    25  	log.Println("start connecting to Milvus")
    26  	c, err := client.NewClient(ctx, client.Config{
    27  		Address: milvusAddr,
    28  	})
    29  	if err != nil {
    30  		log.Fatalf("failed to connect to milvus, err: %v", err)
    31  	}
    32  	defer c.Close()
    33  
    34  	// delete collection if exists
    35  	has, err := c.HasCollection(ctx, collectionName)
    36  	if err != nil {
    37  		log.Fatalf("failed to check collection exists, err: %v", err)
    38  	}
    39  	if has {
    40  		c.DropCollection(ctx, collectionName)
    41  	}
    42  
    43  	// create collection
    44  	log.Printf("create collection `%s`\n", collectionName)
    45  	schema := entity.NewSchema().WithName(collectionName).WithDescription("hello_partition_key is the a demo to introduce the partition key related APIs").
    46  		WithField(entity.NewField().WithName(idCol).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true).WithIsAutoID(true)).
    47  		WithField(entity.NewField().WithName(keyCol).WithDataType(entity.FieldTypeInt64)).
    48  		WithField(entity.NewField().WithName(embeddingCol1).WithDataType(entity.FieldTypeFloatVector).WithDim(dim)).
    49  		WithField(entity.NewField().WithName(embeddingCol2).WithDataType(entity.FieldTypeFloatVector).WithDim(dim))
    50  
    51  	if err := c.CreateCollection(ctx, schema, entity.DefaultShardNumber); err != nil { // use default shard number
    52  		log.Fatalf("create collection failed, err: %v", err)
    53  	}
    54  
    55  	var keyList []int64
    56  	var embeddingList [][]float32
    57  	keyList = make([]int64, 0, nEntities)
    58  	embeddingList = make([][]float32, 0, nEntities)
    59  	for i := 0; i < nEntities; i++ {
    60  		keyList = append(keyList, rand.Int63()%512)
    61  	}
    62  	for i := 0; i < nEntities; i++ {
    63  		vec := make([]float32, 0, dim)
    64  		for j := 0; j < dim; j++ {
    65  			vec = append(vec, rand.Float32())
    66  		}
    67  		embeddingList = append(embeddingList, vec)
    68  	}
    69  	keyColData := entity.NewColumnInt64(keyCol, keyList)
    70  	embeddingColData1 := entity.NewColumnFloatVector(embeddingCol1, dim, embeddingList)
    71  	embeddingColData2 := entity.NewColumnFloatVector(embeddingCol2, dim, embeddingList)
    72  
    73  	log.Println("start to insert data into collection")
    74  
    75  	if _, err := c.Insert(ctx, collectionName, "", keyColData, embeddingColData1, embeddingColData2); err != nil {
    76  		log.Fatalf("failed to insert random data into `%s`, err: %v", collectionName, err)
    77  	}
    78  
    79  	log.Println("insert data done, start to flush")
    80  
    81  	if err := c.Flush(ctx, collectionName, false); err != nil {
    82  		log.Fatalf("failed to flush data, err: %v", err)
    83  	}
    84  	log.Println("flush data done")
    85  
    86  	// build index
    87  	log.Println("start creating index HNSW")
    88  	idx, err := entity.NewIndexHNSW(entity.L2, 16, 256)
    89  	if err != nil {
    90  		log.Fatalf("failed to create ivf flat index, err: %v", err)
    91  	}
    92  	if err := c.CreateIndex(ctx, collectionName, embeddingCol1, idx, false); err != nil {
    93  		log.Fatalf("failed to create index, err: %v", err)
    94  	}
    95  	if err := c.CreateIndex(ctx, collectionName, embeddingCol2, idx, false); err != nil {
    96  		log.Fatalf("failed to create index, err: %v", err)
    97  	}
    98  
    99  	log.Printf("build HNSW index done for collection `%s`\n", collectionName)
   100  	log.Printf("start to load collection `%s`\n", collectionName)
   101  
   102  	// load collection
   103  	if err := c.LoadCollection(ctx, collectionName, false); err != nil {
   104  		log.Fatalf("failed to load collection, err: %v", err)
   105  	}
   106  
   107  	log.Println("load collection done")
   108  
   109  	// currently only nq =1 is supported
   110  	vec2search1 := []entity.Vector{
   111  		entity.FloatVector(embeddingList[len(embeddingList)-2]),
   112  	}
   113  	vec2search2 := []entity.Vector{
   114  		entity.FloatVector(embeddingList[len(embeddingList)-1]),
   115  	}
   116  
   117  	begin := time.Now()
   118  	sp, _ := entity.NewIndexHNSWSearchParam(30)
   119  
   120  	log.Println("start to search vector field 1")
   121  	result, err := c.Search(ctx, collectionName, nil, "", []string{keyCol, embeddingCol1, embeddingCol2}, vec2search1,
   122  		embeddingCol1, entity.L2, topK, sp)
   123  	if err != nil {
   124  		log.Fatalf("failed to search collection, err: %v", err)
   125  	}
   126  
   127  	log.Printf("search `%s` done, latency %v\n", collectionName, time.Since(begin))
   128  	for _, rs := range result {
   129  		for i := 0; i < rs.ResultCount; i++ {
   130  			id, _ := rs.IDs.GetAsInt64(i)
   131  			score := rs.Scores[i]
   132  			embedding, _ := rs.Fields.GetColumn(embeddingCol1).Get(i)
   133  
   134  			log.Printf("ID: %d, score %f, embedding: %v\n", id, score, embedding)
   135  		}
   136  	}
   137  
   138  	log.Println("start to execute hybrid search")
   139  
   140  	result, err = c.HybridSearch(ctx, collectionName, nil, topK, []string{keyCol, embeddingCol1, embeddingCol2},
   141  		client.NewRRFReranker(), []*client.ANNSearchRequest{
   142  			client.NewANNSearchRequest(embeddingCol1, entity.L2, "", vec2search1, sp, topK),
   143  			client.NewANNSearchRequest(embeddingCol2, entity.L2, "", vec2search2, sp, topK),
   144  		})
   145  	if err != nil {
   146  		log.Fatalf("failed to search collection, err: %v", err)
   147  	}
   148  
   149  	log.Printf("hybrid search `%s` done, latency %v\n", collectionName, time.Since(begin))
   150  	for _, rs := range result {
   151  		for i := 0; i < rs.ResultCount; i++ {
   152  			id, _ := rs.IDs.GetAsInt64(i)
   153  			score := rs.Scores[i]
   154  			embedding1, _ := rs.Fields.GetColumn(embeddingCol1).Get(i)
   155  			embedding2, _ := rs.Fields.GetColumn(embeddingCol1).Get(i)
   156  			log.Printf("ID: %d, score %f, embedding1: %v, embedding2: %v\n", id, score, embedding1, embedding2)
   157  		}
   158  	}
   159  
   160  	c.DropCollection(ctx, collectionName)
   161  }