github.com/weaviate/weaviate@v1.24.6/modules/generative-palm/config/class_settings.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package config 13 14 import ( 15 "fmt" 16 "strings" 17 18 "github.com/pkg/errors" 19 "github.com/weaviate/weaviate/entities/models" 20 "github.com/weaviate/weaviate/entities/moduletools" 21 basesettings "github.com/weaviate/weaviate/usecases/modulecomponents/settings" 22 ) 23 24 const ( 25 apiEndpointProperty = "apiEndpoint" 26 projectIDProperty = "projectId" 27 endpointIDProperty = "endpointId" 28 modelIDProperty = "modelId" 29 temperatureProperty = "temperature" 30 tokenLimitProperty = "tokenLimit" 31 topPProperty = "topP" 32 topKProperty = "topK" 33 ) 34 35 var ( 36 DefaultPaLMApiEndpoint = "us-central1-aiplatform.googleapis.com" 37 DefaultPaLMModel = "chat-bison" 38 DefaultPaLMTemperature = 0.2 39 DefaultTokenLimit = 256 40 DefaultPaLMTopP = 0.95 41 DefaultPaLMTopK = 40 42 DefaulGenerativeAIApiEndpoint = "generativelanguage.googleapis.com" 43 DefaulGenerativeAIModelID = "chat-bison-001" 44 ) 45 46 var supportedGenerativeAIModels = []string{ 47 DefaulGenerativeAIModelID, 48 "gemini-pro", 49 "gemini-pro-vision", 50 "gemini-ultra", 51 } 52 53 type ClassSettings interface { 54 Validate(class *models.Class) error 55 // Module settings 56 ApiEndpoint() string 57 ProjectID() string 58 EndpointID() string 59 ModelID() string 60 61 // parameters 62 // 0.0 - 1.0 63 Temperature() float64 64 // 1 - 1024 65 TokenLimit() int 66 // 1 - 40 67 TopK() int 68 // 0.0 - 1.0 69 TopP() float64 70 } 71 72 type classSettings struct { 73 cfg moduletools.ClassConfig 74 propertyValuesHelper basesettings.PropertyValuesHelper 75 } 76 77 func NewClassSettings(cfg moduletools.ClassConfig) ClassSettings { 78 return &classSettings{cfg: cfg, propertyValuesHelper: basesettings.NewPropertyValuesHelper("generative-palm")} 79 } 80 81 func (ic *classSettings) Validate(class *models.Class) error { 82 if ic.cfg == nil { 83 // we would receive a nil-config on cross-class requests, such as Explore{} 84 return errors.New("empty config") 85 } 86 87 var errorMessages []string 88 89 apiEndpoint := ic.ApiEndpoint() 90 projectID := ic.ProjectID() 91 if apiEndpoint != DefaulGenerativeAIApiEndpoint && projectID == "" { 92 errorMessages = append(errorMessages, fmt.Sprintf("%s cannot be empty", projectIDProperty)) 93 } 94 temperature := ic.Temperature() 95 if temperature < 0 || temperature > 1 { 96 errorMessages = append(errorMessages, fmt.Sprintf("%s has to be float value between 0 and 1", temperatureProperty)) 97 } 98 tokenLimit := ic.TokenLimit() 99 if tokenLimit < 1 || tokenLimit > 1024 { 100 errorMessages = append(errorMessages, fmt.Sprintf("%s has to be an integer value between 1 and 1024", tokenLimitProperty)) 101 } 102 topK := ic.TopK() 103 if topK < 1 || topK > 40 { 104 errorMessages = append(errorMessages, fmt.Sprintf("%s has to be an integer value between 1 and 40", topKProperty)) 105 } 106 topP := ic.TopP() 107 if topP < 0 || topP > 1 { 108 errorMessages = append(errorMessages, fmt.Sprintf("%s has to be float value between 0 and 1", topPProperty)) 109 } 110 // Google MakerSuite 111 model := ic.ModelID() 112 if apiEndpoint == DefaulGenerativeAIApiEndpoint && !contains[string](supportedGenerativeAIModels, model) { 113 errorMessages = append(errorMessages, fmt.Sprintf("%s is not supported available models are: %+v", model, supportedGenerativeAIModels)) 114 } 115 116 if len(errorMessages) > 0 { 117 return fmt.Errorf("%s", strings.Join(errorMessages, ", ")) 118 } 119 120 return nil 121 } 122 123 func (ic *classSettings) getStringProperty(name, defaultValue string) string { 124 return ic.propertyValuesHelper.GetPropertyAsString(ic.cfg, name, defaultValue) 125 } 126 127 func (ic *classSettings) getFloatProperty(name string, defaultValue float64) float64 { 128 asFloat64 := ic.propertyValuesHelper.GetPropertyAsFloat64(ic.cfg, name, &defaultValue) 129 return *asFloat64 130 } 131 132 func (ic *classSettings) getIntProperty(name string, defaultValue int) int { 133 asInt := ic.propertyValuesHelper.GetPropertyAsInt(ic.cfg, name, &defaultValue) 134 return *asInt 135 } 136 137 func (ic *classSettings) getDefaultModel(apiEndpoint string) string { 138 if apiEndpoint == DefaulGenerativeAIApiEndpoint { 139 return DefaulGenerativeAIModelID 140 } 141 return DefaultPaLMModel 142 } 143 144 // PaLM params 145 func (ic *classSettings) ApiEndpoint() string { 146 return ic.getStringProperty(apiEndpointProperty, DefaultPaLMApiEndpoint) 147 } 148 149 func (ic *classSettings) ProjectID() string { 150 return ic.getStringProperty(projectIDProperty, "") 151 } 152 153 func (ic *classSettings) EndpointID() string { 154 return ic.getStringProperty(endpointIDProperty, "") 155 } 156 157 func (ic *classSettings) ModelID() string { 158 return ic.getStringProperty(modelIDProperty, ic.getDefaultModel(ic.ApiEndpoint())) 159 } 160 161 // parameters 162 163 // 0.0 - 1.0 164 func (ic *classSettings) Temperature() float64 { 165 return ic.getFloatProperty(temperatureProperty, DefaultPaLMTemperature) 166 } 167 168 // 1 - 1024 169 func (ic *classSettings) TokenLimit() int { 170 return ic.getIntProperty(tokenLimitProperty, DefaultTokenLimit) 171 } 172 173 // 1 - 40 174 func (ic *classSettings) TopK() int { 175 return ic.getIntProperty(topKProperty, DefaultPaLMTopK) 176 } 177 178 // 0.0 - 1.0 179 func (ic *classSettings) TopP() float64 { 180 return ic.getFloatProperty(topPProperty, DefaultPaLMTopP) 181 } 182 183 func contains[T comparable](s []T, e T) bool { 184 for _, v := range s { 185 if v == e { 186 return true 187 } 188 } 189 return false 190 }