github.com/weaviate/weaviate@v1.24.6/usecases/classification/validation.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package classification
    13  
    14  import (
    15  	"fmt"
    16  
    17  	"github.com/weaviate/weaviate/entities/errorcompounder"
    18  	"github.com/weaviate/weaviate/entities/models"
    19  	"github.com/weaviate/weaviate/entities/schema"
    20  	schemaUC "github.com/weaviate/weaviate/usecases/schema"
    21  )
    22  
    23  const (
    24  	TypeKNN        = "knn"
    25  	TypeContextual = "text2vec-contextionary-contextual"
    26  	TypeZeroShot   = "zeroshot"
    27  )
    28  
    29  type Validator struct {
    30  	schema  schema.Schema
    31  	errors  *errorcompounder.SafeErrorCompounder
    32  	subject models.Classification
    33  }
    34  
    35  func NewValidator(sg schemaUC.SchemaGetter, subject models.Classification) *Validator {
    36  	schema := sg.GetSchemaSkipAuth()
    37  	return &Validator{
    38  		schema:  schema,
    39  		errors:  &errorcompounder.SafeErrorCompounder{},
    40  		subject: subject,
    41  	}
    42  }
    43  
    44  func (v *Validator) Do() error {
    45  	v.validate()
    46  
    47  	err := v.errors.ToError()
    48  	if err != nil {
    49  		return fmt.Errorf("invalid classification: %v", err)
    50  	}
    51  
    52  	return nil
    53  }
    54  
    55  func (v *Validator) validate() {
    56  	if v.subject.Class == "" {
    57  		v.errors.Add(fmt.Errorf("class must be set"))
    58  		return
    59  	}
    60  
    61  	class := v.schema.FindClassByName(schema.ClassName(v.subject.Class))
    62  	if class == nil {
    63  		v.errors.Addf("class '%s' not found in schema", v.subject.Class)
    64  		return
    65  	}
    66  
    67  	v.contextualTypeFeasibility()
    68  	v.knnTypeFeasibility()
    69  	v.basedOnProperties(class)
    70  	v.classifyProperties(class)
    71  }
    72  
    73  func (v *Validator) contextualTypeFeasibility() {
    74  	if !v.typeText2vecContextionaryContextual() {
    75  		return
    76  	}
    77  
    78  	if v.subject.Filters != nil && v.subject.Filters.TrainingSetWhere != nil {
    79  		v.errors.Addf("type is 'text2vec-contextionary-contextual', but 'trainingSetWhere' filter is set, for 'text2vec-contextionary-contextual' there is no training data, instead limit possible target data directly through setting 'targetWhere'")
    80  	}
    81  }
    82  
    83  func (v *Validator) knnTypeFeasibility() {
    84  	if !v.typeKNN() {
    85  		return
    86  	}
    87  
    88  	if v.subject.Filters != nil && v.subject.Filters.TargetWhere != nil {
    89  		v.errors.Addf("type is 'knn', but 'targetWhere' filter is set, for 'knn' you cannot limit target data directly, instead limit training data through setting 'trainingSetWhere'")
    90  	}
    91  }
    92  
    93  func (v *Validator) basedOnProperties(class *models.Class) {
    94  	if v.subject.BasedOnProperties == nil || len(v.subject.BasedOnProperties) == 0 {
    95  		v.errors.Addf("basedOnProperties must have at least one property")
    96  		return
    97  	}
    98  
    99  	if len(v.subject.BasedOnProperties) > 1 {
   100  		v.errors.Addf("only a single property in basedOnProperties supported at the moment, got %v",
   101  			v.subject.BasedOnProperties)
   102  		return
   103  	}
   104  
   105  	for _, prop := range v.subject.BasedOnProperties {
   106  		v.basedOnProperty(class, prop)
   107  	}
   108  }
   109  
   110  func (v *Validator) basedOnProperty(class *models.Class, propName string) {
   111  	prop, ok := v.propertyByName(class, propName)
   112  	if !ok {
   113  		v.errors.Addf("basedOnProperties: property '%s' does not exist", propName)
   114  		return
   115  	}
   116  
   117  	dt, err := v.schema.FindPropertyDataType(prop.DataType)
   118  	if err != nil {
   119  		v.errors.Addf("basedOnProperties: %v", err)
   120  		return
   121  	}
   122  
   123  	if !dt.IsPrimitive() {
   124  		v.errors.Addf("basedOnProperties: property '%s' must be of type 'text'", propName)
   125  		return
   126  	}
   127  
   128  	if dt.AsPrimitive() != schema.DataTypeText {
   129  		v.errors.Addf("basedOnProperties: property '%s' must be of type 'text'", propName)
   130  		return
   131  	}
   132  }
   133  
   134  func (v *Validator) classifyProperties(class *models.Class) {
   135  	if v.subject.ClassifyProperties == nil || len(v.subject.ClassifyProperties) == 0 {
   136  		v.errors.Addf("classifyProperties must have at least one property")
   137  		return
   138  	}
   139  
   140  	for _, prop := range v.subject.ClassifyProperties {
   141  		v.classifyProperty(class, prop)
   142  	}
   143  }
   144  
   145  func (v *Validator) classifyProperty(class *models.Class, propName string) {
   146  	prop, ok := v.propertyByName(class, propName)
   147  	if !ok {
   148  		v.errors.Addf("classifyProperties: property '%s' does not exist", propName)
   149  		return
   150  	}
   151  
   152  	dt, err := v.schema.FindPropertyDataType(prop.DataType)
   153  	if err != nil {
   154  		v.errors.Addf("classifyProperties: %v", err)
   155  		return
   156  	}
   157  
   158  	if !dt.IsReference() {
   159  		v.errors.Addf("classifyProperties: property '%s' must be of reference type (cref)", propName)
   160  		return
   161  	}
   162  
   163  	if v.typeText2vecContextionaryContextual() {
   164  		if len(dt.Classes()) > 1 {
   165  			v.errors.Addf("classifyProperties: property '%s'"+
   166  				" has more than one target class, classification of type 'text2vec-contextionary-contextual' requires exactly one target class", propName)
   167  			return
   168  		}
   169  	}
   170  }
   171  
   172  func (v *Validator) propertyByName(class *models.Class, propName string) (*models.Property, bool) {
   173  	for _, prop := range class.Properties {
   174  		if prop.Name == propName {
   175  			return prop, true
   176  		}
   177  	}
   178  
   179  	return nil, false
   180  }
   181  
   182  func (v *Validator) typeText2vecContextionaryContextual() bool {
   183  	if v.subject.Type == "" {
   184  		return false
   185  	}
   186  
   187  	return v.subject.Type == TypeContextual
   188  }
   189  
   190  func (v *Validator) typeKNN() bool {
   191  	if v.subject.Type == "" {
   192  		return true
   193  	}
   194  
   195  	return v.subject.Type == TypeKNN
   196  }