github.com/milvus-io/milvus-sdk-go/v2@v2.4.1/client/row.go (about)

     1  package client
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"reflect"
     7  
     8  	"github.com/cockroachdb/errors"
     9  
    10  	"github.com/golang/protobuf/proto"
    11  	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
    12  	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
    13  	"github.com/milvus-io/milvus-sdk-go/v2/entity"
    14  )
    15  
    16  // CreateCollectionByRow create collection by row
    17  func (c *GrpcClient) CreateCollectionByRow(ctx context.Context, row entity.Row, shardNum int32) error {
    18  	if c.Service == nil {
    19  		return ErrClientNotReady
    20  	}
    21  	// parse schema from row definition
    22  	sch, err := entity.ParseSchema(row)
    23  	if err != nil {
    24  		return err
    25  	}
    26  
    27  	// check collection already exists
    28  	has, err := c.HasCollection(ctx, sch.CollectionName)
    29  	if err != nil {
    30  		return err
    31  	}
    32  	// already exists collection with same name, return error
    33  	if has {
    34  		return fmt.Errorf("collection %s already exist", sch.CollectionName)
    35  	}
    36  	// marshal schema to bytes for message transfer
    37  	p := sch.ProtoMessage()
    38  	bs, err := proto.Marshal(p)
    39  	if err != nil {
    40  		return err
    41  	}
    42  	// compose request and invoke Service
    43  	req := &milvuspb.CreateCollectionRequest{
    44  		DbName:         "", // reserved fields, not used for now
    45  		CollectionName: sch.CollectionName,
    46  		Schema:         bs,
    47  		ShardsNum:      shardNum,
    48  	}
    49  	resp, err := c.Service.CreateCollection(ctx, req)
    50  	// handles response
    51  	if err != nil {
    52  		return err
    53  	}
    54  	err = handleRespStatus(resp)
    55  	if err != nil {
    56  		return nil
    57  	}
    58  	return nil
    59  }
    60  
    61  // InsertByRows insert by rows
    62  func (c *GrpcClient) InsertByRows(ctx context.Context, collName string, partitionName string,
    63  	rows []entity.Row) (entity.Column, error) {
    64  	anys := make([]interface{}, 0, len(rows))
    65  	for _, row := range rows {
    66  		anys = append(anys, row)
    67  	}
    68  
    69  	return c.InsertRows(ctx, collName, partitionName, anys)
    70  }
    71  
    72  // InsertRows allows insert with row based data
    73  // rows could be struct or map.
    74  func (c *GrpcClient) InsertRows(ctx context.Context, collName string, partitionName string,
    75  	rows []interface{}) (entity.Column, error) {
    76  	if c.Service == nil {
    77  		return nil, ErrClientNotReady
    78  	}
    79  	if len(rows) == 0 {
    80  		return nil, errors.New("empty rows provided")
    81  	}
    82  
    83  	if err := c.checkCollectionExists(ctx, collName); err != nil {
    84  		return nil, err
    85  	}
    86  	if partitionName != "" {
    87  		if err := c.checkPartitionExists(ctx, collName, partitionName); err != nil {
    88  			return nil, err
    89  		}
    90  	}
    91  	coll, err := c.DescribeCollection(ctx, collName)
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  	// 1. convert rows to columns
    96  	columns, err := entity.AnyToColumns(rows, coll.Schema)
    97  	if err != nil {
    98  		return nil, err
    99  	}
   100  	//fieldData
   101  	// 2. do insert request
   102  	req := &milvuspb.InsertRequest{
   103  		DbName:         "", // reserved
   104  		CollectionName: collName,
   105  		PartitionName:  partitionName,
   106  	}
   107  
   108  	req.NumRows = uint32(len(rows))
   109  	for _, column := range columns {
   110  		req.FieldsData = append(req.FieldsData, column.FieldData())
   111  	}
   112  	resp, err := c.Service.Insert(ctx, req)
   113  	if err != nil {
   114  		return nil, err
   115  	}
   116  	if err := handleRespStatus(resp.GetStatus()); err != nil {
   117  		return nil, err
   118  	}
   119  	MetaCache.setSessionTs(collName, resp.Timestamp)
   120  	// 3. parse id column
   121  	return entity.IDColumns(coll.Schema, resp.GetIDs(), 0, -1)
   122  }
   123  
// SearchResultByRows search result for row-based Search.
// SearchResultToRows produces one value of this type per query.
type SearchResultByRows struct {
	// ResultCount is the number of result entries for the query (its topk value).
	ResultCount int
	// Scores holds the scores of the result entries of the query.
	Scores []float32
	// Rows holds the result entries converted to rows.
	Rows []entity.Row
	// Err records an error hit while converting the entries of the query, if any.
	Err error
}
   131  
   132  // SearchResultToRows converts search result proto to rows
   133  func SearchResultToRows(sch *entity.Schema, results *schemapb.SearchResultData, t reflect.Type, _ map[string]struct{}) ([]SearchResultByRows, error) {
   134  	var err error
   135  	offset := 0
   136  	// new will have a pointer, so de-reference first if type is pointer to struct
   137  	if t.Kind() == reflect.Ptr {
   138  		t = t.Elem()
   139  	}
   140  	sr := make([]SearchResultByRows, 0, results.GetNumQueries())
   141  	fieldDataList := results.GetFieldsData()
   142  	nameFieldData := make(map[string]*schemapb.FieldData)
   143  	for _, fieldData := range fieldDataList {
   144  		nameFieldData[fieldData.FieldName] = fieldData
   145  	}
   146  	ids := results.GetIds()
   147  	for i := 0; i < int(results.GetNumQueries()); i++ {
   148  		rc := int(results.GetTopks()[i]) // result entry count for current query
   149  		entry := SearchResultByRows{
   150  			ResultCount: rc,
   151  			Rows:        make([]entity.Row, 0, rc),
   152  			Scores:      results.GetScores()[offset:rc],
   153  		}
   154  		for j := 0; j < rc; j++ {
   155  			p := reflect.New(t)
   156  			v := p.Elem()
   157  
   158  			// extract primary field logic
   159  			for _, field := range sch.Fields {
   160  				f := v.FieldByName(field.Name) // TODO silverxia field may be annotated by tags, which means the field name will not be the same
   161  				if !f.IsValid() {
   162  					continue
   163  				}
   164  				// Primary key has different field from search result
   165  				if field.PrimaryKey {
   166  					switch f.Kind() {
   167  					case reflect.Int64:
   168  						intIds := ids.GetIntId()
   169  						if intIds == nil {
   170  							entry.Err = fmt.Errorf("field %s is int64, but id column is not", field.Name)
   171  							break
   172  						}
   173  						f.SetInt(intIds.GetData()[offset+j])
   174  					case reflect.String:
   175  						strIds := ids.GetStrId()
   176  						if strIds == nil {
   177  							entry.Err = fmt.Errorf("field %s is string ,but id column is not", field.Name)
   178  							break
   179  						}
   180  						f.SetString(strIds.GetData()[offset+j])
   181  					default:
   182  						entry.Err = fmt.Errorf("field %s is not valid primary key", field.Name)
   183  					}
   184  					continue
   185  				}
   186  
   187  				// fieldDataList
   188  				fieldData, has := nameFieldData[field.Name]
   189  				if !has || fieldData == nil {
   190  					continue
   191  				}
   192  
   193  				// Set field value with offset+j-th item
   194  				err = SetFieldValue(field, f, fieldData, offset+j)
   195  				if err != nil {
   196  					entry.Err = err
   197  					break
   198  				}
   199  
   200  			}
   201  			r := p.Interface()
   202  			row, ok := r.(entity.Row)
   203  			if ok {
   204  				entry.Rows = append(entry.Rows, row)
   205  			}
   206  		}
   207  		sr = append(sr, entry)
   208  		// set offset after processed one result
   209  		offset += rc
   210  	}
   211  	return sr, nil
   212  }
   213  
   214  var (
   215  	// ErrFieldTypeNotMatch error for field type not match
   216  	ErrFieldTypeNotMatch = errors.New("field type not matched")
   217  )
   218  
   219  // SetFieldValue set row field value with reflection
   220  func SetFieldValue(field *entity.Field, f reflect.Value, fieldData *schemapb.FieldData, idx int) error {
   221  	scalars := fieldData.GetScalars()
   222  	vectors := fieldData.GetVectors()
   223  	// This switch part is messy
   224  	// Maybe this can be refactored later
   225  	switch field.DataType {
   226  	case entity.FieldTypeBool:
   227  		if f.Kind() != reflect.Bool {
   228  			return ErrFieldTypeNotMatch
   229  		}
   230  		if scalars == nil {
   231  			return ErrFieldTypeNotMatch
   232  		}
   233  		data := scalars.GetBoolData()
   234  		if data == nil {
   235  			return ErrFieldTypeNotMatch
   236  		}
   237  		f.SetBool(data.Data[idx])
   238  	case entity.FieldTypeInt8:
   239  		if f.Kind() != reflect.Int8 {
   240  			return ErrFieldTypeNotMatch
   241  		}
   242  		if scalars == nil {
   243  			return ErrFieldTypeNotMatch
   244  		}
   245  		data := scalars.GetIntData()
   246  		if data == nil {
   247  			return ErrFieldTypeNotMatch
   248  		}
   249  		f.SetInt(int64(data.Data[idx]))
   250  	case entity.FieldTypeInt16:
   251  		if f.Kind() != reflect.Int16 {
   252  			return ErrFieldTypeNotMatch
   253  		}
   254  		if scalars == nil {
   255  			return ErrFieldTypeNotMatch
   256  		}
   257  		data := scalars.GetIntData()
   258  		if data == nil {
   259  			return ErrFieldTypeNotMatch
   260  		}
   261  		f.SetInt(int64(data.Data[idx]))
   262  	case entity.FieldTypeInt32:
   263  		if f.Kind() != reflect.Int32 {
   264  			return ErrFieldTypeNotMatch
   265  		}
   266  		if scalars == nil {
   267  			return ErrFieldTypeNotMatch
   268  		}
   269  		data := scalars.GetIntData()
   270  		if data == nil {
   271  			return ErrFieldTypeNotMatch
   272  		}
   273  		f.SetInt(int64(data.Data[idx]))
   274  	case entity.FieldTypeInt64:
   275  		if f.Kind() != reflect.Int64 {
   276  			return ErrFieldTypeNotMatch
   277  		}
   278  		if scalars == nil {
   279  			return ErrFieldTypeNotMatch
   280  		}
   281  		data := scalars.GetLongData()
   282  		if data == nil {
   283  			return ErrFieldTypeNotMatch
   284  		}
   285  		f.SetInt(data.Data[idx])
   286  	case entity.FieldTypeFloat:
   287  		if f.Kind() != reflect.Float32 {
   288  			return ErrFieldTypeNotMatch
   289  		}
   290  		if scalars == nil {
   291  			return ErrFieldTypeNotMatch
   292  		}
   293  		data := scalars.GetFloatData()
   294  		if data == nil {
   295  			return ErrFieldTypeNotMatch
   296  		}
   297  
   298  		f.SetFloat(float64(data.Data[idx]))
   299  	case entity.FieldTypeDouble:
   300  		if f.Kind() != reflect.Float64 {
   301  			return ErrFieldTypeNotMatch
   302  		}
   303  		if scalars == nil {
   304  			return ErrFieldTypeNotMatch
   305  		}
   306  		data := scalars.GetDoubleData()
   307  		if data == nil {
   308  			return ErrFieldTypeNotMatch
   309  		}
   310  
   311  		f.SetFloat(data.Data[idx])
   312  	case entity.FieldTypeString:
   313  		if f.Kind() != reflect.String {
   314  			return ErrFieldTypeNotMatch
   315  		}
   316  		if scalars == nil {
   317  			return ErrFieldTypeNotMatch
   318  		}
   319  		data := scalars.GetStringData()
   320  		if data == nil {
   321  			return ErrFieldTypeNotMatch
   322  		}
   323  
   324  		f.SetString(data.Data[idx])
   325  
   326  	case entity.FieldTypeFloatVector:
   327  		if vectors == nil {
   328  			return ErrFieldTypeNotMatch
   329  		}
   330  		data := vectors.GetFloatVector()
   331  		if data == nil {
   332  
   333  			return ErrFieldTypeNotMatch
   334  		}
   335  		vector := data.Data[idx*int(vectors.Dim) : (idx+1)*int(vectors.Dim)]
   336  		switch f.Kind() {
   337  		case reflect.Slice:
   338  			f.Set(reflect.ValueOf(vector))
   339  		case reflect.Array:
   340  			arrType := reflect.ArrayOf(int(vectors.Dim), reflect.TypeOf(float32(0)))
   341  			arr := reflect.New(arrType).Elem()
   342  			for i := 0; i < int(vectors.Dim); i++ {
   343  				arr.Index(i).Set(reflect.ValueOf(vector[i]))
   344  			}
   345  			f.Set(arr)
   346  		default:
   347  			return ErrFieldTypeNotMatch
   348  		}
   349  	case entity.FieldTypeBinaryVector:
   350  		if vectors == nil {
   351  			return ErrFieldTypeNotMatch
   352  		}
   353  		data := vectors.GetBinaryVector()
   354  		if data == nil {
   355  			return ErrFieldTypeNotMatch
   356  		}
   357  		vector := data[idx*int(vectors.Dim/8) : (idx+1)*int(vectors.Dim/8)]
   358  		switch f.Kind() {
   359  		case reflect.Slice:
   360  			f.Set(reflect.ValueOf(vector))
   361  		case reflect.Array:
   362  			arrType := reflect.ArrayOf(int(vectors.Dim/8), reflect.TypeOf(byte(0)))
   363  			arr := reflect.New(arrType).Elem()
   364  			for i := 0; i < int(vectors.Dim/8); i++ {
   365  				arr.Index(i).Set(reflect.ValueOf(vector[i]))
   366  			}
   367  			f.Set(arr)
   368  		default:
   369  			return ErrFieldTypeNotMatch
   370  		}
   371  	default:
   372  		return ErrFieldTypeNotMatch
   373  	}
   374  	return nil
   375  }