github.com/weaviate/weaviate@v1.24.6/usecases/objects/batch_references_add.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package objects 13 14 import ( 15 "context" 16 "fmt" 17 "strings" 18 "sync" 19 20 enterrors "github.com/weaviate/weaviate/entities/errors" 21 22 "github.com/go-openapi/strfmt" 23 "github.com/weaviate/weaviate/entities/additional" 24 "github.com/weaviate/weaviate/entities/models" 25 "github.com/weaviate/weaviate/entities/schema" 26 "github.com/weaviate/weaviate/entities/schema/crossref" 27 ) 28 29 // AddReferences Class Instances in batch to the connected DB 30 func (b *BatchManager) AddReferences(ctx context.Context, principal *models.Principal, 31 refs []*models.BatchReference, repl *additional.ReplicationProperties, 32 ) (BatchReferences, error) { 33 err := b.authorizer.Authorize(principal, "update", "batch/*") 34 if err != nil { 35 return nil, err 36 } 37 38 unlock, err := b.locks.LockSchema() 39 if err != nil { 40 return nil, NewErrInternal("could not acquire lock: %v", err) 41 } 42 defer unlock() 43 44 b.metrics.BatchRefInc() 45 defer b.metrics.BatchRefDec() 46 47 return b.addReferences(ctx, principal, refs, repl) 48 } 49 50 func (b *BatchManager) addReferences(ctx context.Context, principal *models.Principal, 51 refs []*models.BatchReference, repl *additional.ReplicationProperties, 52 ) (BatchReferences, error) { 53 if err := b.validateReferenceForm(refs); err != nil { 54 return nil, NewErrInvalidUserInput("invalid params: %v", err) 55 } 56 57 batchReferences := b.validateReferencesConcurrently(ctx, principal, refs) 58 59 if err := b.autodetectToClass(ctx, principal, batchReferences); err != nil { 60 return nil, err 61 } 62 63 // MT validation must be done after auto-detection as we cannot know the target class beforehand in all cases 64 for i, ref := range batchReferences { 65 if ref.Err == nil { 66 if shouldValidateMultiTenantRef(ref.Tenant, ref.From, ref.To) { 67 // can only validate multi-tenancy when everything above succeeds 68 err := validateReferenceMultiTenancy(ctx, principal, b.schemaManager, b.vectorRepo, ref.From, ref.To, ref.Tenant) 69 if err != nil { 70 batchReferences[i].Err = err 71 } 72 } 73 } 74 } 75 76 if res, err := b.vectorRepo.AddBatchReferences(ctx, batchReferences, repl); err != nil { 77 return nil, NewErrInternal("could not add batch request to connector: %v", err) 78 } else { 79 return res, nil 80 } 81 } 82 83 func (b *BatchManager) validateReferenceForm(refs []*models.BatchReference) error { 84 if len(refs) == 0 { 85 return fmt.Errorf("length cannot be 0, need at least one reference for batching") 86 } 87 88 return nil 89 } 90 91 func (b *BatchManager) validateReferencesConcurrently(ctx context.Context, 92 principal *models.Principal, refs []*models.BatchReference, 93 ) BatchReferences { 94 c := make(chan BatchReference, len(refs)) 95 wg := new(sync.WaitGroup) 96 97 // Generate a goroutine for each separate request 98 for i, ref := range refs { 99 i := i 100 ref := ref 101 wg.Add(1) 102 enterrors.GoWrapper(func() { b.validateReference(ctx, principal, wg, ref, i, &c) }, b.logger) 103 } 104 105 wg.Wait() 106 close(c) 107 108 return referencesChanToSlice(c) 109 } 110 111 // autodetectToClass gets the class name of the referenced class through the schema definition 112 func (b *BatchManager) autodetectToClass(ctx context.Context, 113 principal *models.Principal, batchReferences BatchReferences, 114 ) error { 115 classPropTarget := make(map[string]string) 116 scheme, err := b.schemaManager.GetSchema(principal) 117 if err != nil { 118 return NewErrInvalidUserInput("get schema: %v", err) 119 } 120 for i, ref := range batchReferences { 121 // get to class from property datatype 122 if ref.To.Class != "" || ref.Err != nil { 123 continue 124 } 125 className := string(ref.From.Class) 126 propName := schema.LowercaseFirstLetter(string(ref.From.Property)) 127 128 target, ok := classPropTarget[className+propName] 129 if !ok { 130 class := scheme.FindClassByName(ref.From.Class) 131 if class == nil { 132 batchReferences[i].Err = fmt.Errorf("class %s does not exist", className) 133 continue 134 } 135 136 prop, err := schema.GetPropertyByName(class, propName) 137 if err != nil { 138 batchReferences[i].Err = fmt.Errorf("property %s does not exist for class %s", propName, className) 139 continue 140 } 141 if len(prop.DataType) > 1 { 142 continue // can't auto-detect for multi-target 143 } 144 target = prop.DataType[0] // datatype is the name of the class that is referenced 145 classPropTarget[className+propName] = target 146 } 147 batchReferences[i].To.Class = target 148 } 149 return nil 150 } 151 152 func (b *BatchManager) validateReference(ctx context.Context, principal *models.Principal, 153 wg *sync.WaitGroup, ref *models.BatchReference, i int, resultsC *chan BatchReference, 154 ) { 155 defer wg.Done() 156 var validateErrors []error 157 source, err := crossref.ParseSource(string(ref.From)) 158 if err != nil { 159 validateErrors = append(validateErrors, err) 160 } else if !source.Local { 161 validateErrors = append(validateErrors, fmt.Errorf("source class must always point to the local peer, but got %s", 162 source.PeerName)) 163 } 164 165 target, err := crossref.Parse(string(ref.To)) 166 if err != nil { 167 validateErrors = append(validateErrors, err) 168 } else if !target.Local { 169 validateErrors = append(validateErrors, fmt.Errorf("importing network references in batch is not possible. "+ 170 "Please perform a regular non-batch import for network references, got peer %s", 171 target.PeerName)) 172 } 173 174 // target id must be lowercase 175 target.TargetID = strfmt.UUID(strings.ToLower(target.TargetID.String())) 176 177 if len(validateErrors) == 0 { 178 err = nil 179 } else { 180 err = joinErrors(validateErrors) 181 } 182 183 *resultsC <- BatchReference{ 184 From: source, 185 To: target, 186 Err: err, 187 OriginalIndex: i, 188 Tenant: ref.Tenant, 189 } 190 } 191 192 func validateReferenceMultiTenancy(ctx context.Context, 193 principal *models.Principal, schemaManager schemaManager, 194 repo VectorRepo, source *crossref.RefSource, target *crossref.Ref, 195 tenant string, 196 ) error { 197 if source == nil || target == nil { 198 return fmt.Errorf("can't validate multi-tenancy for nil refs") 199 } 200 201 sourceClass, targetClass, err := getReferenceClasses( 202 ctx, principal, schemaManager, source.Class.String(), source.Property.String(), target.Class) 203 if err != nil { 204 return err 205 } 206 207 sourceEnabled := schema.MultiTenancyEnabled(sourceClass) 208 targetEnabled := schema.MultiTenancyEnabled(targetClass) 209 210 if !sourceEnabled && targetEnabled { 211 return fmt.Errorf("invalid reference: cannot reference a multi-tenant " + 212 "enabled class from a non multi-tenant enabled class") 213 } 214 if sourceEnabled && !targetEnabled { 215 if err := validateTenantRefObject(ctx, repo, sourceClass, source.TargetID, tenant); err != nil { 216 return fmt.Errorf("source: %w", err) 217 } 218 if err := validateTenantRefObject(ctx, repo, targetClass, target.TargetID, ""); err != nil { 219 return fmt.Errorf("target: %w", err) 220 } 221 } 222 // if both classes have MT enabled but different tenant keys, 223 // no cross-tenant references can be made 224 if sourceEnabled && targetEnabled { 225 if err := validateTenantRefObject(ctx, repo, sourceClass, source.TargetID, tenant); err != nil { 226 return fmt.Errorf("source: %w", err) 227 } 228 if err := validateTenantRefObject(ctx, repo, targetClass, target.TargetID, tenant); err != nil { 229 return fmt.Errorf("target: %w", err) 230 } 231 } 232 233 return nil 234 } 235 236 func getReferenceClasses(ctx context.Context, 237 principal *models.Principal, schemaManager schemaManager, 238 classFrom, fromProperty, classTo string, 239 ) (sourceClass *models.Class, targetClass *models.Class, err error) { 240 if classFrom == "" { 241 err = fmt.Errorf("references involving a multi-tenancy enabled class " + 242 "requires class name in the source beacon url") 243 return 244 } 245 246 sourceClass, err = schemaManager.GetClass(ctx, principal, classFrom) 247 if err != nil { 248 err = fmt.Errorf("get source class %q: %w", classFrom, err) 249 return 250 } 251 if sourceClass == nil { 252 err = fmt.Errorf("source class %q not found in schema", classFrom) 253 return 254 } 255 // we can auto-detect the to class from the schema if it is a single target reference 256 if classTo == "" { 257 refProp, err2 := schema.GetPropertyByName(sourceClass, fromProperty) 258 if err2 != nil { 259 err = fmt.Errorf("get source refprop %q: %w", classFrom, err2) 260 return 261 } 262 263 if len(refProp.DataType) != 1 { 264 err = fmt.Errorf("multi-target references require the class name in the target beacon url") 265 return 266 } 267 classTo = refProp.DataType[0] 268 } 269 270 targetClass, err = schemaManager.GetClass(ctx, principal, classTo) 271 if err != nil { 272 err = fmt.Errorf("get target class %q: %w", classTo, err) 273 return 274 } 275 if targetClass == nil { 276 err = fmt.Errorf("target class %q not found in schema", classTo) 277 return 278 } 279 return 280 } 281 282 // validateTenantRefObject ensures that object exist for the given tenant key. 283 // This asserts that no cross-tenant references can occur, 284 // as a class+id which belongs to a different 285 // tenant will not be found in the searched tenant shard 286 func validateTenantRefObject(ctx context.Context, repo VectorRepo, 287 class *models.Class, ID strfmt.UUID, tenant string, 288 ) error { 289 exists, err := repo.Exists(ctx, class.Class, ID, nil, tenant) 290 if err != nil { 291 return fmt.Errorf("get object %s/%s: %w", class.Class, ID, err) 292 } 293 if !exists { 294 return fmt.Errorf("object %s/%s not found for tenant %q", class.Class, ID, tenant) 295 } 296 return nil 297 } 298 299 func referencesChanToSlice(c chan BatchReference) BatchReferences { 300 result := make([]BatchReference, len(c)) 301 for reference := range c { 302 result[reference.OriginalIndex] = reference 303 } 304 305 return result 306 } 307 308 func joinErrors(errors []error) error { 309 errorStrings := []string{} 310 for _, err := range errors { 311 if err != nil { 312 errorStrings = append(errorStrings, err.Error()) 313 } 314 } 315 316 if len(errorStrings) == 0 { 317 return nil 318 } 319 320 return fmt.Errorf(strings.Join(errorStrings, ", ")) 321 }