github.com/weaviate/weaviate@v1.24.6/entities/vectorindex/hnsw/pq_config.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package hnsw
    13  
    14  import (
    15  	"fmt"
    16  
    17  	"github.com/weaviate/weaviate/entities/vectorindex/common"
    18  )
    19  
    20  const (
    21  	PQEncoderTypeKMeans            = "kmeans"
    22  	PQEncoderTypeTile              = "tile"
    23  	PQEncoderDistributionLogNormal = "log-normal"
    24  	PQEncoderDistributionNormal    = "normal"
    25  )
    26  
    27  const (
    28  	DefaultPQEnabled             = false
    29  	DefaultPQBitCompression      = false
    30  	DefaultPQSegments            = 0
    31  	DefaultPQEncoderType         = PQEncoderTypeKMeans
    32  	DefaultPQEncoderDistribution = PQEncoderDistributionLogNormal
    33  	DefaultPQCentroids           = 256
    34  	DefaultPQTrainingLimit       = 100000
    35  )
    36  
    37  // Product Quantization encoder configuration
    38  type PQEncoder struct {
    39  	Type         string `json:"type"`
    40  	Distribution string `json:"distribution,omitempty"`
    41  }
    42  
    43  // Product Quantization configuration
    44  type PQConfig struct {
    45  	Enabled        bool      `json:"enabled"`
    46  	BitCompression bool      `json:"bitCompression"`
    47  	Segments       int       `json:"segments"`
    48  	Centroids      int       `json:"centroids"`
    49  	TrainingLimit  int       `json:"trainingLimit"`
    50  	Encoder        PQEncoder `json:"encoder"`
    51  }
    52  
    53  func validEncoder(v string) error {
    54  	switch v {
    55  	case PQEncoderTypeKMeans:
    56  	case PQEncoderTypeTile:
    57  	default:
    58  		return fmt.Errorf("invalid encoder type %s", v)
    59  	}
    60  
    61  	return nil
    62  }
    63  
    64  func validEncoderDistribution(v string) error {
    65  	switch v {
    66  	case PQEncoderDistributionLogNormal:
    67  	case PQEncoderDistributionNormal:
    68  	default:
    69  		return fmt.Errorf("invalid encoder distribution %s", v)
    70  	}
    71  
    72  	return nil
    73  }
    74  
    75  func ValidatePQConfig(cfg PQConfig) error {
    76  	if !cfg.Enabled {
    77  		return nil
    78  	}
    79  	err := validEncoder(cfg.Encoder.Type)
    80  	if err != nil {
    81  		return err
    82  	}
    83  
    84  	err = validEncoderDistribution(cfg.Encoder.Distribution)
    85  	if err != nil {
    86  		return err
    87  	}
    88  
    89  	return nil
    90  }
    91  
    92  func encoderFromMap(in map[string]interface{}, setFn func(v string)) error {
    93  	value, ok := in["type"]
    94  	if !ok {
    95  		return nil
    96  	}
    97  
    98  	asString, ok := value.(string)
    99  	if !ok {
   100  		return nil
   101  	}
   102  
   103  	err := validEncoder(asString)
   104  	if err != nil {
   105  		return err
   106  	}
   107  
   108  	setFn(asString)
   109  	return nil
   110  }
   111  
   112  func encoderDistributionFromMap(in map[string]interface{}, setFn func(v string)) error {
   113  	value, ok := in["distribution"]
   114  	if !ok {
   115  		return nil
   116  	}
   117  
   118  	asString, ok := value.(string)
   119  	if !ok {
   120  		return nil
   121  	}
   122  
   123  	err := validEncoderDistribution(asString)
   124  	if err != nil {
   125  		return err
   126  	}
   127  
   128  	setFn(asString)
   129  	return nil
   130  }
   131  
   132  func parsePQMap(in map[string]interface{}, pq *PQConfig) error {
   133  	pqConfigValue, ok := in["pq"]
   134  	if !ok {
   135  		return nil
   136  	}
   137  
   138  	pqConfigMap, ok := pqConfigValue.(map[string]interface{})
   139  	if !ok {
   140  		return nil
   141  	}
   142  
   143  	if err := common.OptionalBoolFromMap(pqConfigMap, "enabled", func(v bool) {
   144  		pq.Enabled = v
   145  	}); err != nil {
   146  		return err
   147  	}
   148  
   149  	if err := common.OptionalBoolFromMap(pqConfigMap, "bitCompression", func(v bool) {
   150  		pq.BitCompression = v
   151  	}); err != nil {
   152  		return err
   153  	}
   154  
   155  	if err := common.OptionalIntFromMap(pqConfigMap, "segments", func(v int) {
   156  		pq.Segments = v
   157  	}); err != nil {
   158  		return err
   159  	}
   160  
   161  	if err := common.OptionalIntFromMap(pqConfigMap, "centroids", func(v int) {
   162  		pq.Centroids = v
   163  	}); err != nil {
   164  		return err
   165  	}
   166  
   167  	if err := common.OptionalIntFromMap(pqConfigMap, "trainingLimit", func(v int) {
   168  		pq.TrainingLimit = v
   169  	}); err != nil {
   170  		return err
   171  	}
   172  
   173  	pqEncoderValue, ok := pqConfigMap["encoder"]
   174  	if !ok {
   175  		return nil
   176  	}
   177  
   178  	pqEncoderMap, ok := pqEncoderValue.(map[string]interface{})
   179  	if !ok {
   180  		return nil
   181  	}
   182  
   183  	if err := encoderFromMap(pqEncoderMap, func(v string) {
   184  		pq.Encoder.Type = v
   185  	}); err != nil {
   186  		return err
   187  	}
   188  
   189  	if err := encoderDistributionFromMap(pqEncoderMap, func(v string) {
   190  		pq.Encoder.Distribution = v
   191  	}); err != nil {
   192  		return err
   193  	}
   194  
   195  	return nil
   196  }