github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/plan/build_vector_index_util.go (about)

     1  // Copyright 2023 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package plan
    16  
import (
	"fmt"

	"github.com/matrixorigin/matrixone/pkg/catalog"
	"github.com/matrixorigin/matrixone/pkg/container/types"
	"github.com/matrixorigin/matrixone/pkg/pb/plan"
)
    22  
    23  var (
    24  	bigIntType  = types.T_int64.ToType()
    25  	varCharType = types.T_varchar.ToType()
    26  
    27  	opTypeToDistanceFunc = map[string]string{
    28  		"vector_l2_ops":     "l2_distance",
    29  		"vector_ip_ops":     "inner_product",
    30  		"vector_cosine_ops": "cosine_distance",
    31  	}
    32  )
    33  
    34  func makeIvfFlatIndexTblScan(builder *QueryBuilder, bindCtx *BindContext,
    35  	indexTableDefs []*TableDef, idxRefs []*ObjectRef, idxTableId int32) (int32, []*Expr) {
    36  	scanNodeProjections := make([]*Expr, len(indexTableDefs[idxTableId].Cols))
    37  	for colIdx, column := range indexTableDefs[idxTableId].Cols {
    38  		scanNodeProjections[colIdx] = &plan.Expr{
    39  			Typ: column.Typ,
    40  			Expr: &plan.Expr_Col{
    41  				Col: &plan.ColRef{
    42  					ColPos: int32(colIdx),
    43  					Name:   column.Name,
    44  				},
    45  			},
    46  		}
    47  	}
    48  	centroidsScanId := builder.appendNode(&Node{
    49  		NodeType:    plan.Node_TABLE_SCAN,
    50  		ObjRef:      idxRefs[idxTableId],
    51  		TableDef:    indexTableDefs[idxTableId],
    52  		ProjectList: scanNodeProjections,
    53  	}, bindCtx)
    54  	return centroidsScanId, scanNodeProjections
    55  }
    56  
    57  func makeMetaTblScanWhereKeyEqVersion(builder *QueryBuilder, bindCtx *BindContext, indexTableDefs []*TableDef, idxRefs []*ObjectRef) (int32, error) {
    58  	metaTableScanId, scanCols := makeIvfFlatIndexTblScan(builder, bindCtx, indexTableDefs, idxRefs, 0)
    59  
    60  	whereKeyEqVersion, err := BindFuncExprImplByPlanExpr(builder.GetContext(), "=", []*Expr{
    61  		DeepCopyExpr(scanCols[0]),
    62  		MakePlan2StringConstExprWithType("version"),
    63  	})
    64  	if err != nil {
    65  		return -1, err
    66  	}
    67  	builder.qry.Nodes[metaTableScanId].FilterList = []*Expr{whereKeyEqVersion}
    68  	return metaTableScanId, nil
    69  }
    70  
    71  func makeTableProjectionIncludingNormalizeL2(builder *QueryBuilder, bindCtx *BindContext, tableScanId int32,
    72  	tableDef *TableDef,
    73  	typeOriginPk Type, posOriginPk int,
    74  	typeOriginVecColumn Type, posOriginVecColumn int) (int32, error) {
    75  
    76  	normalizeL2, err := BindFuncExprImplByPlanExpr(builder.GetContext(), "normalize_l2", []*Expr{
    77  		{ // tbl.embedding
    78  			Typ: typeOriginVecColumn,
    79  			Expr: &plan.Expr_Col{
    80  				Col: &plan.ColRef{
    81  					RelPos: 0,
    82  					ColPos: int32(posOriginVecColumn),
    83  					Name:   tableDef.Cols[posOriginVecColumn].Name,
    84  				},
    85  			},
    86  		},
    87  	})
    88  	if err != nil {
    89  		return -1, err
    90  	}
    91  
    92  	// id, normalize_l2(embedding)
    93  	tableProjectId := builder.appendNode(&plan.Node{
    94  		NodeType: plan.Node_PROJECT,
    95  		Children: []int32{tableScanId},
    96  
    97  		ProjectList: []*Expr{
    98  
    99  			{ // tbl.pk
   100  				Typ: typeOriginPk,
   101  				Expr: &plan.Expr_Col{
   102  					Col: &plan.ColRef{
   103  						RelPos: 0,
   104  						ColPos: int32(posOriginPk),
   105  						Name:   tableDef.Cols[posOriginPk].Name,
   106  					},
   107  				},
   108  			},
   109  
   110  			// tbl.normalize_l2(embedding)
   111  			normalizeL2,
   112  		},
   113  	}, bindCtx)
   114  	return tableProjectId, nil
   115  }
   116  
   117  func makeCrossJoinCentroidsMetaForCurrVersion(builder *QueryBuilder, bindCtx *BindContext, indexTableDefs []*TableDef, idxRefs []*ObjectRef, metaTableScanId int32) (int32, error) {
   118  	centroidsScanId, _ := makeIvfFlatIndexTblScan(builder, bindCtx, indexTableDefs, idxRefs, 1)
   119  
   120  	metaProjection := getProjectionByLastNode(builder, metaTableScanId)
   121  	metaProjectValueCol := DeepCopyExpr(metaProjection[1])
   122  	metaProjectValueCol.Expr.(*plan.Expr_Col).Col.RelPos = 1
   123  	prevMetaScanCastValAsBigInt, err := makePlan2CastExpr(builder.GetContext(), metaProjectValueCol, makePlan2Type(&bigIntType))
   124  	if err != nil {
   125  		return -1, err
   126  	}
   127  	// 0: centroids.version
   128  	// 1: centroids.centroid_id
   129  	// 2: centroids.centroid
   130  	prevCentroidScanProjection := getProjectionByLastNode(builder, centroidsScanId)[:3]
   131  	whereCentroidVersionEqCurrVersion, err := BindFuncExprImplByPlanExpr(builder.GetContext(), "=", []*Expr{
   132  		prevCentroidScanProjection[0],
   133  		prevMetaScanCastValAsBigInt,
   134  	})
   135  	if err != nil {
   136  		return -1, err
   137  	}
   138  
   139  	joinMetaAndCentroidsId := builder.appendNode(&plan.Node{
   140  		NodeType:    plan.Node_JOIN,
   141  		JoinType:    plan.Node_INNER,
   142  		Children:    []int32{centroidsScanId, metaTableScanId},
   143  		ProjectList: prevCentroidScanProjection,
   144  		OnList:      []*Expr{whereCentroidVersionEqCurrVersion},
   145  	}, bindCtx)
   146  
   147  	return joinMetaAndCentroidsId, nil
   148  }
   149  
   150  func makeCrossJoinTblAndCentroids(builder *QueryBuilder, bindCtx *BindContext, tableDef *TableDef,
   151  	leftChildTblId int32, rightChildCentroidsId int32,
   152  	typeOriginPk Type, posOriginPk int,
   153  	typeOriginVecColumn Type) int32 {
   154  
   155  	crossJoinTblAndCentroidsId := builder.appendNode(&plan.Node{
   156  		NodeType: plan.Node_JOIN,
   157  		JoinType: plan.Node_INNER, // since there is no OnList, it is a cross join
   158  		Children: []int32{leftChildTblId, rightChildCentroidsId},
   159  		ProjectList: []*Expr{
   160  			{ // centroids.version
   161  				Typ: makePlan2TypeValue(&bigIntType),
   162  				Expr: &plan.Expr_Col{
   163  					Col: &plan.ColRef{
   164  						RelPos: 1,
   165  						ColPos: 0,
   166  						Name:   catalog.SystemSI_IVFFLAT_TblCol_Centroids_version,
   167  					},
   168  				},
   169  			},
   170  			{ // centroids.centroid_id
   171  				Typ: makePlan2TypeValue(&bigIntType),
   172  				Expr: &plan.Expr_Col{
   173  					Col: &plan.ColRef{
   174  						RelPos: 1,
   175  						ColPos: 1,
   176  						Name:   catalog.SystemSI_IVFFLAT_TblCol_Centroids_id,
   177  					},
   178  				},
   179  			},
   180  			{ // tbl.pk
   181  				Typ: typeOriginPk,
   182  				Expr: &plan.Expr_Col{
   183  					Col: &plan.ColRef{
   184  						RelPos: 0,
   185  						ColPos: 0,
   186  						Name:   tableDef.Cols[posOriginPk].Name,
   187  					},
   188  				},
   189  			},
   190  			{ // centroids.centroid
   191  				Typ: typeOriginVecColumn,
   192  				Expr: &plan.Expr_Col{
   193  					Col: &plan.ColRef{
   194  						RelPos: 1,
   195  						ColPos: 2,
   196  						Name:   catalog.SystemSI_IVFFLAT_TblCol_Centroids_centroid,
   197  					},
   198  				},
   199  			},
   200  			{ // tbl.normalize_l2(embedding)
   201  				Typ: typeOriginVecColumn,
   202  				Expr: &plan.Expr_Col{
   203  					Col: &plan.ColRef{
   204  						RelPos: 0,
   205  						ColPos: 1,
   206  					},
   207  				},
   208  			},
   209  		},
   210  	}, bindCtx)
   211  
   212  	return crossJoinTblAndCentroidsId
   213  }
   214  
   215  func makeMinCentroidIdAndCpKey(builder *QueryBuilder, bindCtx *BindContext,
   216  	crossJoinTblAndCentroidsID int32, multiTableIndex *MultiTableIndex) (int32, error) {
   217  
   218  	lastNodeProjections := getProjectionByLastNode(builder, crossJoinTblAndCentroidsID)
   219  	centroidsVersion := lastNodeProjections[0]
   220  	centroidsId := lastNodeProjections[1]
   221  	tblPk := lastNodeProjections[2]
   222  	centroidsCentroid := lastNodeProjections[3]
   223  	tblNormalizeL2Embedding := lastNodeProjections[4]
   224  
   225  	// 1.a Group By
   226  	groupByList := []*plan.Expr{
   227  		DeepCopyExpr(centroidsVersion), // centroids.version
   228  		DeepCopyExpr(tblPk),            // tbl.pk
   229  	}
   230  
   231  	// 1.b Agg Functions
   232  	entriesParams := multiTableIndex.IndexDefs[catalog.SystemSI_IVFFLAT_TblType_Entries].IndexAlgoParams
   233  	paramMap, err := catalog.IndexParamsStringToMap(entriesParams)
   234  	if err != nil {
   235  		return -1, err
   236  	}
   237  	vectorOps := paramMap[catalog.IndexAlgoParamOpType]
   238  	distFn := opTypeToDistanceFunc[vectorOps]
   239  	l2Distance, err := BindFuncExprImplByPlanExpr(builder.GetContext(), distFn, []*plan.Expr{
   240  		DeepCopyExpr(centroidsCentroid),       // centroids.centroid
   241  		DeepCopyExpr(tblNormalizeL2Embedding), // tbl.normalize_l2(embedding)
   242  	})
   243  	if err != nil {
   244  		return -1, err
   245  	}
   246  
   247  	serialL2DistanceCentroidId, err := BindFuncExprImplByPlanExpr(builder.GetContext(), "serial_full", []*plan.Expr{
   248  		l2Distance,
   249  		DeepCopyExpr(centroidsId), // centroids.centroid_id
   250  	})
   251  	if err != nil {
   252  		return -1, err
   253  	}
   254  
   255  	minSerialFullL2DistanceAndCentroidId, err := BindFuncExprImplByPlanExpr(builder.GetContext(), "min", []*plan.Expr{
   256  		serialL2DistanceCentroidId,
   257  	})
   258  	if err != nil {
   259  		return -1, err
   260  	}
   261  
   262  	aggList := []*plan.Expr{
   263  		minSerialFullL2DistanceAndCentroidId,
   264  	}
   265  
   266  	// 1.c Project List
   267  	centroidsIdOfMinimumL2Distance, err := BindFuncExprImplByPlanExpr(builder.GetContext(), "serial_extract", []*plan.Expr{
   268  		{
   269  			Typ: makePlan2TypeValue(&varCharType),
   270  			Expr: &plan.Expr_Col{
   271  				Col: &plan.ColRef{
   272  					RelPos: -2,                          // -1 is group by, -2 is agg function
   273  					ColPos: int32(0 + len(groupByList)), // agg function is the one after `group by`
   274  				},
   275  			},
   276  		},
   277  		makePlan2Int64ConstExprWithType(1),
   278  		{
   279  			Typ: makePlan2TypeValue(&bigIntType),
   280  			Expr: &plan.Expr_T{
   281  				T: &plan.TargetType{},
   282  			},
   283  		},
   284  	})
   285  	if err != nil {
   286  		return -1, err
   287  	}
   288  
   289  	centroidsVersionProj := DeepCopyExpr(centroidsVersion)
   290  	centroidsVersionProj.Expr.(*plan.Expr_Col).Col.RelPos = -1
   291  	centroidsVersionProj.Expr.(*plan.Expr_Col).Col.ColPos = 0
   292  
   293  	tblPkProj := DeepCopyExpr(tblPk)
   294  	tblPkProj.Expr.(*plan.Expr_Col).Col.RelPos = -1
   295  	tblPkProj.Expr.(*plan.Expr_Col).Col.ColPos = 1
   296  
   297  	// 1.d Create a new AGG node
   298  	// NOTE: Don't add
   299  	// serial(centroidsVersionProj, "centroidsIdOfMinimumL2Distance", tblPkProj) in here as you will be computing
   300  	// the same value multiple times. Instead, add serial(...) in the next PROJECT node.
   301  	projectionList := []*plan.Expr{
   302  		centroidsVersionProj,           // centroids.version
   303  		centroidsIdOfMinimumL2Distance, // centroids.centroid_id
   304  		tblPkProj,                      // tbl.pk
   305  	}
   306  
   307  	newNodeID := builder.appendNode(
   308  		&plan.Node{
   309  			NodeType:    plan.Node_AGG,
   310  			Children:    []int32{crossJoinTblAndCentroidsID},
   311  			ProjectList: projectionList,
   312  			AggList:     aggList,
   313  			GroupBy:     groupByList,
   314  		},
   315  		bindCtx)
   316  
   317  	// 2.a Project List
   318  	lastProjections := getProjectionByLastNode(builder, newNodeID)
   319  
   320  	// 2.b Create a serial(...) expression
   321  	cpKey, err := BindFuncExprImplByPlanExpr(builder.GetContext(), "serial", []*plan.Expr{
   322  		DeepCopyExpr(lastProjections[0]),
   323  		DeepCopyExpr(lastProjections[1]),
   324  		DeepCopyExpr(lastProjections[2]),
   325  	})
   326  	if err != nil {
   327  		return -1, err
   328  	}
   329  
   330  	// 2.c Create a new PROJECT node
   331  	project := builder.appendNode(&plan.Node{
   332  		NodeType: plan.Node_PROJECT,
   333  		Children: []int32{newNodeID},
   334  		ProjectList: []*Expr{
   335  			lastProjections[0],
   336  			lastProjections[1],
   337  			lastProjections[2],
   338  			cpKey,
   339  		},
   340  	}, bindCtx)
   341  
   342  	return project, nil
   343  }
   344  
   345  func makeFinalProjectWithTblEmbedding(builder *QueryBuilder, bindCtx *BindContext,
   346  	lastNodeId, minCentroidIdNode int32,
   347  	tableDef *TableDef,
   348  	typeOriginPk Type, posOriginPk int,
   349  	typeOriginVecColumn Type, posOriginVecColumn int) (int32, error) {
   350  
   351  	condExpr, err := BindFuncExprImplByPlanExpr(builder.GetContext(), "=", []*Expr{
   352  		{ // tbl.pk
   353  			Typ: typeOriginPk,
   354  			Expr: &plan.Expr_Col{
   355  				Col: &plan.ColRef{
   356  					RelPos: 0,
   357  					ColPos: int32(posOriginPk),
   358  					Name:   tableDef.Cols[posOriginPk].Name,
   359  				},
   360  			},
   361  		},
   362  		{ // join.pk
   363  			Typ: typeOriginPk,
   364  			Expr: &plan.Expr_Col{
   365  				Col: &plan.ColRef{
   366  					RelPos: 1,
   367  					ColPos: 2,
   368  					Name:   tableDef.Cols[posOriginPk].Name,
   369  				},
   370  			},
   371  		},
   372  	})
   373  	if err != nil {
   374  		return -1, err
   375  	}
   376  
   377  	// 0: centroids.version,
   378  	// 1: centroids.centroid_id,
   379  	// 2: tbl.pk,
   380  	// 3: tbl.embedding,
   381  	var rProjections = getProjectionByLastNode(builder, minCentroidIdNode)
   382  
   383  	rCentroidsVersion := DeepCopyExpr(rProjections[0])
   384  	rCentroidsCentroidId := DeepCopyExpr(rProjections[1])
   385  	rTblPk := DeepCopyExpr(rProjections[2])
   386  	rCpKey := DeepCopyExpr(rProjections[3])
   387  
   388  	rCentroidsVersion.Expr.(*plan.Expr_Col).Col.RelPos = 1
   389  	rCentroidsCentroidId.Expr.(*plan.Expr_Col).Col.RelPos = 1
   390  	rTblPk.Expr.(*plan.Expr_Col).Col.RelPos = 1
   391  	rCpKey.Expr.(*plan.Expr_Col).Col.RelPos = 1
   392  
   393  	finalProjectId := builder.appendNode(&plan.Node{
   394  		NodeType: plan.Node_JOIN,
   395  		JoinType: plan.Node_INNER,
   396  		Children: []int32{lastNodeId, minCentroidIdNode},
   397  		// version, centroid_id, pk, serial(version,pk)
   398  		ProjectList: []*Expr{
   399  			DeepCopyExpr(rCentroidsVersion),
   400  			DeepCopyExpr(rCentroidsCentroidId),
   401  			DeepCopyExpr(rTblPk),
   402  			{ // tbl.pk
   403  				Typ: typeOriginVecColumn,
   404  				Expr: &plan.Expr_Col{
   405  					Col: &plan.ColRef{
   406  						RelPos: 0,
   407  						ColPos: int32(posOriginVecColumn),
   408  						Name:   tableDef.Cols[posOriginVecColumn].Name,
   409  					},
   410  				},
   411  			},
   412  			rCpKey,
   413  		},
   414  		OnList: []*Expr{condExpr},
   415  	}, bindCtx)
   416  
   417  	return finalProjectId, nil
   418  }