github.com/weaviate/weaviate@v1.24.6/usecases/objects/batch_references_add.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package objects
    13  
    14  import (
    15  	"context"
    16  	"fmt"
    17  	"strings"
    18  	"sync"
    19  
    20  	enterrors "github.com/weaviate/weaviate/entities/errors"
    21  
    22  	"github.com/go-openapi/strfmt"
    23  	"github.com/weaviate/weaviate/entities/additional"
    24  	"github.com/weaviate/weaviate/entities/models"
    25  	"github.com/weaviate/weaviate/entities/schema"
    26  	"github.com/weaviate/weaviate/entities/schema/crossref"
    27  )
    28  
    29  // AddReferences Class Instances in batch to the connected DB
    30  func (b *BatchManager) AddReferences(ctx context.Context, principal *models.Principal,
    31  	refs []*models.BatchReference, repl *additional.ReplicationProperties,
    32  ) (BatchReferences, error) {
    33  	err := b.authorizer.Authorize(principal, "update", "batch/*")
    34  	if err != nil {
    35  		return nil, err
    36  	}
    37  
    38  	unlock, err := b.locks.LockSchema()
    39  	if err != nil {
    40  		return nil, NewErrInternal("could not acquire lock: %v", err)
    41  	}
    42  	defer unlock()
    43  
    44  	b.metrics.BatchRefInc()
    45  	defer b.metrics.BatchRefDec()
    46  
    47  	return b.addReferences(ctx, principal, refs, repl)
    48  }
    49  
    50  func (b *BatchManager) addReferences(ctx context.Context, principal *models.Principal,
    51  	refs []*models.BatchReference, repl *additional.ReplicationProperties,
    52  ) (BatchReferences, error) {
    53  	if err := b.validateReferenceForm(refs); err != nil {
    54  		return nil, NewErrInvalidUserInput("invalid params: %v", err)
    55  	}
    56  
    57  	batchReferences := b.validateReferencesConcurrently(ctx, principal, refs)
    58  
    59  	if err := b.autodetectToClass(ctx, principal, batchReferences); err != nil {
    60  		return nil, err
    61  	}
    62  
    63  	// MT validation must be done after auto-detection as we cannot know the target class beforehand in all cases
    64  	for i, ref := range batchReferences {
    65  		if ref.Err == nil {
    66  			if shouldValidateMultiTenantRef(ref.Tenant, ref.From, ref.To) {
    67  				// can only validate multi-tenancy when everything above succeeds
    68  				err := validateReferenceMultiTenancy(ctx, principal, b.schemaManager, b.vectorRepo, ref.From, ref.To, ref.Tenant)
    69  				if err != nil {
    70  					batchReferences[i].Err = err
    71  				}
    72  			}
    73  		}
    74  	}
    75  
    76  	if res, err := b.vectorRepo.AddBatchReferences(ctx, batchReferences, repl); err != nil {
    77  		return nil, NewErrInternal("could not add batch request to connector: %v", err)
    78  	} else {
    79  		return res, nil
    80  	}
    81  }
    82  
    83  func (b *BatchManager) validateReferenceForm(refs []*models.BatchReference) error {
    84  	if len(refs) == 0 {
    85  		return fmt.Errorf("length cannot be 0, need at least one reference for batching")
    86  	}
    87  
    88  	return nil
    89  }
    90  
    91  func (b *BatchManager) validateReferencesConcurrently(ctx context.Context,
    92  	principal *models.Principal, refs []*models.BatchReference,
    93  ) BatchReferences {
    94  	c := make(chan BatchReference, len(refs))
    95  	wg := new(sync.WaitGroup)
    96  
    97  	// Generate a goroutine for each separate request
    98  	for i, ref := range refs {
    99  		i := i
   100  		ref := ref
   101  		wg.Add(1)
   102  		enterrors.GoWrapper(func() { b.validateReference(ctx, principal, wg, ref, i, &c) }, b.logger)
   103  	}
   104  
   105  	wg.Wait()
   106  	close(c)
   107  
   108  	return referencesChanToSlice(c)
   109  }
   110  
   111  // autodetectToClass gets the class name of the referenced class through the schema definition
   112  func (b *BatchManager) autodetectToClass(ctx context.Context,
   113  	principal *models.Principal, batchReferences BatchReferences,
   114  ) error {
   115  	classPropTarget := make(map[string]string)
   116  	scheme, err := b.schemaManager.GetSchema(principal)
   117  	if err != nil {
   118  		return NewErrInvalidUserInput("get schema: %v", err)
   119  	}
   120  	for i, ref := range batchReferences {
   121  		// get to class from property datatype
   122  		if ref.To.Class != "" || ref.Err != nil {
   123  			continue
   124  		}
   125  		className := string(ref.From.Class)
   126  		propName := schema.LowercaseFirstLetter(string(ref.From.Property))
   127  
   128  		target, ok := classPropTarget[className+propName]
   129  		if !ok {
   130  			class := scheme.FindClassByName(ref.From.Class)
   131  			if class == nil {
   132  				batchReferences[i].Err = fmt.Errorf("class %s does not exist", className)
   133  				continue
   134  			}
   135  
   136  			prop, err := schema.GetPropertyByName(class, propName)
   137  			if err != nil {
   138  				batchReferences[i].Err = fmt.Errorf("property %s does not exist for class %s", propName, className)
   139  				continue
   140  			}
   141  			if len(prop.DataType) > 1 {
   142  				continue // can't auto-detect for multi-target
   143  			}
   144  			target = prop.DataType[0] // datatype is the name of the class that is referenced
   145  			classPropTarget[className+propName] = target
   146  		}
   147  		batchReferences[i].To.Class = target
   148  	}
   149  	return nil
   150  }
   151  
   152  func (b *BatchManager) validateReference(ctx context.Context, principal *models.Principal,
   153  	wg *sync.WaitGroup, ref *models.BatchReference, i int, resultsC *chan BatchReference,
   154  ) {
   155  	defer wg.Done()
   156  	var validateErrors []error
   157  	source, err := crossref.ParseSource(string(ref.From))
   158  	if err != nil {
   159  		validateErrors = append(validateErrors, err)
   160  	} else if !source.Local {
   161  		validateErrors = append(validateErrors, fmt.Errorf("source class must always point to the local peer, but got %s",
   162  			source.PeerName))
   163  	}
   164  
   165  	target, err := crossref.Parse(string(ref.To))
   166  	if err != nil {
   167  		validateErrors = append(validateErrors, err)
   168  	} else if !target.Local {
   169  		validateErrors = append(validateErrors, fmt.Errorf("importing network references in batch is not possible. "+
   170  			"Please perform a regular non-batch import for network references, got peer %s",
   171  			target.PeerName))
   172  	}
   173  
   174  	// target id must be lowercase
   175  	target.TargetID = strfmt.UUID(strings.ToLower(target.TargetID.String()))
   176  
   177  	if len(validateErrors) == 0 {
   178  		err = nil
   179  	} else {
   180  		err = joinErrors(validateErrors)
   181  	}
   182  
   183  	*resultsC <- BatchReference{
   184  		From:          source,
   185  		To:            target,
   186  		Err:           err,
   187  		OriginalIndex: i,
   188  		Tenant:        ref.Tenant,
   189  	}
   190  }
   191  
   192  func validateReferenceMultiTenancy(ctx context.Context,
   193  	principal *models.Principal, schemaManager schemaManager,
   194  	repo VectorRepo, source *crossref.RefSource, target *crossref.Ref,
   195  	tenant string,
   196  ) error {
   197  	if source == nil || target == nil {
   198  		return fmt.Errorf("can't validate multi-tenancy for nil refs")
   199  	}
   200  
   201  	sourceClass, targetClass, err := getReferenceClasses(
   202  		ctx, principal, schemaManager, source.Class.String(), source.Property.String(), target.Class)
   203  	if err != nil {
   204  		return err
   205  	}
   206  
   207  	sourceEnabled := schema.MultiTenancyEnabled(sourceClass)
   208  	targetEnabled := schema.MultiTenancyEnabled(targetClass)
   209  
   210  	if !sourceEnabled && targetEnabled {
   211  		return fmt.Errorf("invalid reference: cannot reference a multi-tenant " +
   212  			"enabled class from a non multi-tenant enabled class")
   213  	}
   214  	if sourceEnabled && !targetEnabled {
   215  		if err := validateTenantRefObject(ctx, repo, sourceClass, source.TargetID, tenant); err != nil {
   216  			return fmt.Errorf("source: %w", err)
   217  		}
   218  		if err := validateTenantRefObject(ctx, repo, targetClass, target.TargetID, ""); err != nil {
   219  			return fmt.Errorf("target: %w", err)
   220  		}
   221  	}
   222  	// if both classes have MT enabled but different tenant keys,
   223  	// no cross-tenant references can be made
   224  	if sourceEnabled && targetEnabled {
   225  		if err := validateTenantRefObject(ctx, repo, sourceClass, source.TargetID, tenant); err != nil {
   226  			return fmt.Errorf("source: %w", err)
   227  		}
   228  		if err := validateTenantRefObject(ctx, repo, targetClass, target.TargetID, tenant); err != nil {
   229  			return fmt.Errorf("target: %w", err)
   230  		}
   231  	}
   232  
   233  	return nil
   234  }
   235  
   236  func getReferenceClasses(ctx context.Context,
   237  	principal *models.Principal, schemaManager schemaManager,
   238  	classFrom, fromProperty, classTo string,
   239  ) (sourceClass *models.Class, targetClass *models.Class, err error) {
   240  	if classFrom == "" {
   241  		err = fmt.Errorf("references involving a multi-tenancy enabled class " +
   242  			"requires class name in the source beacon url")
   243  		return
   244  	}
   245  
   246  	sourceClass, err = schemaManager.GetClass(ctx, principal, classFrom)
   247  	if err != nil {
   248  		err = fmt.Errorf("get source class %q: %w", classFrom, err)
   249  		return
   250  	}
   251  	if sourceClass == nil {
   252  		err = fmt.Errorf("source class %q not found in schema", classFrom)
   253  		return
   254  	}
   255  	// we can auto-detect the to class from the schema if it is a single target reference
   256  	if classTo == "" {
   257  		refProp, err2 := schema.GetPropertyByName(sourceClass, fromProperty)
   258  		if err2 != nil {
   259  			err = fmt.Errorf("get source refprop %q: %w", classFrom, err2)
   260  			return
   261  		}
   262  
   263  		if len(refProp.DataType) != 1 {
   264  			err = fmt.Errorf("multi-target references require the class name in the target beacon url")
   265  			return
   266  		}
   267  		classTo = refProp.DataType[0]
   268  	}
   269  
   270  	targetClass, err = schemaManager.GetClass(ctx, principal, classTo)
   271  	if err != nil {
   272  		err = fmt.Errorf("get target class %q: %w", classTo, err)
   273  		return
   274  	}
   275  	if targetClass == nil {
   276  		err = fmt.Errorf("target class %q not found in schema", classTo)
   277  		return
   278  	}
   279  	return
   280  }
   281  
   282  // validateTenantRefObject ensures that object exist for the given tenant key.
   283  // This asserts that no cross-tenant references can occur,
   284  // as a class+id which belongs to a different
   285  // tenant will not be found in the searched tenant shard
   286  func validateTenantRefObject(ctx context.Context, repo VectorRepo,
   287  	class *models.Class, ID strfmt.UUID, tenant string,
   288  ) error {
   289  	exists, err := repo.Exists(ctx, class.Class, ID, nil, tenant)
   290  	if err != nil {
   291  		return fmt.Errorf("get object %s/%s: %w", class.Class, ID, err)
   292  	}
   293  	if !exists {
   294  		return fmt.Errorf("object %s/%s not found for tenant %q", class.Class, ID, tenant)
   295  	}
   296  	return nil
   297  }
   298  
   299  func referencesChanToSlice(c chan BatchReference) BatchReferences {
   300  	result := make([]BatchReference, len(c))
   301  	for reference := range c {
   302  		result[reference.OriginalIndex] = reference
   303  	}
   304  
   305  	return result
   306  }
   307  
   308  func joinErrors(errors []error) error {
   309  	errorStrings := []string{}
   310  	for _, err := range errors {
   311  		if err != nil {
   312  			errorStrings = append(errorStrings, err.Error())
   313  		}
   314  	}
   315  
   316  	if len(errorStrings) == 0 {
   317  		return nil
   318  	}
   319  
   320  	return fmt.Errorf(strings.Join(errorStrings, ", "))
   321  }