github.com/weaviate/weaviate@v1.24.6/modules/generative-aws/config/class_settings_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package config 13 14 import ( 15 "testing" 16 17 "github.com/pkg/errors" 18 "github.com/stretchr/testify/assert" 19 "github.com/weaviate/weaviate/entities/moduletools" 20 ) 21 22 func Test_classSettings_Validate(t *testing.T) { 23 t.Skip("Skipping this test for now") 24 tests := []struct { 25 name string 26 cfg moduletools.ClassConfig 27 wantService string 28 wantRegion string 29 wantModel string 30 wantEndpoint string 31 wantTargetModel string 32 wantTargetVariant string 33 wantMaxTokenCount int 34 wantStopSequences []string 35 wantTemperature float64 36 wantTopP int 37 wantErr error 38 }{ 39 { 40 name: "happy flow - Bedrock", 41 cfg: fakeClassConfig{ 42 classConfig: map[string]interface{}{ 43 "service": "bedrock", 44 "region": "us-east-1", 45 "model": "amazon.titan-tg1-large", 46 }, 47 }, 48 wantService: "bedrock", 49 wantRegion: "us-east-1", 50 wantModel: "amazon.titan-tg1-large", 51 wantMaxTokenCount: 8192, 52 wantStopSequences: []string{}, 53 wantTemperature: 0, 54 wantTopP: 1, 55 }, 56 { 57 name: "happy flow - Sagemaker", 58 cfg: fakeClassConfig{ 59 classConfig: map[string]interface{}{ 60 "service": "sagemaker", 61 "region": "us-east-1", 62 "endpoint": "my-endpoint-deployment", 63 "targetModel": "model", 64 "targetVariant": "variant-1", 65 }, 66 }, 67 wantService: "sagemaker", 68 wantRegion: "us-east-1", 69 wantEndpoint: "my-endpoint-deployment", 70 wantTargetModel: "model", 71 wantTargetVariant: "variant-1", 72 }, 73 { 74 name: "custom values - Bedrock", 75 cfg: fakeClassConfig{ 76 classConfig: map[string]interface{}{ 77 "service": "bedrock", 78 "region": "us-east-1", 79 "model": "amazon.titan-tg1-large", 80 "maxTokenCount": 1, 81 "stopSequences": []string{"test", "test2"}, 82 "temperature": 0.2, 83 "topP": 0, 84 }, 85 }, 86 wantService: "bedrock", 87 wantRegion: "us-east-1", 88 wantModel: "amazon.titan-tg1-large", 89 wantMaxTokenCount: 1, 90 wantStopSequences: []string{"test", "test2"}, 91 wantTemperature: 0.2, 92 wantTopP: 0, 93 }, 94 { 95 name: "custom values - Sagemaker", 96 cfg: fakeClassConfig{ 97 classConfig: map[string]interface{}{ 98 "service": "sagemaker", 99 "region": "us-east-1", 100 "endpoint": "this-is-my-endpoint", 101 "targetModel": "my-target-model", 102 "targetVariant": "my-target¬variant", 103 }, 104 }, 105 wantService: "sagemaker", 106 wantRegion: "us-east-1", 107 wantEndpoint: "this-is-my-endpoint", 108 wantTargetModel: "my-target-model", 109 wantTargetVariant: "my-target¬variant", 110 }, 111 { 112 name: "wrong temperature", 113 cfg: fakeClassConfig{ 114 classConfig: map[string]interface{}{ 115 "service": "bedrock", 116 "region": "us-east-1", 117 "model": "amazon.titan-tg1-large", 118 "temperature": 2, 119 }, 120 }, 121 wantErr: errors.Errorf("temperature has to be float value between 0 and 1"), 122 }, 123 { 124 name: "wrong maxTokenCount", 125 cfg: fakeClassConfig{ 126 classConfig: map[string]interface{}{ 127 "service": "bedrock", 128 "region": "us-east-1", 129 "model": "amazon.titan-tg1-large", 130 "maxTokenCount": 9000, 131 }, 132 }, 133 wantErr: errors.Errorf("maxTokenCount has to be an integer value between 1 and 8096"), 134 }, 135 { 136 name: "wrong topP", 137 cfg: fakeClassConfig{ 138 classConfig: map[string]interface{}{ 139 "service": "bedrock", 140 "region": "us-east-1", 141 "model": "amazon.titan-tg1-large", 142 "topP": 2000, 143 }, 144 }, 145 wantErr: errors.Errorf("topP has to be an integer value between 0 and 1"), 146 }, 147 { 148 name: "wrong all", 149 cfg: fakeClassConfig{ 150 classConfig: map[string]interface{}{ 151 "maxTokenCount": 9000, 152 "temperature": 2, 153 "topP": 3, 154 }, 155 }, 156 wantErr: errors.Errorf("wrong service, " + 157 "available services are: [bedrock sagemaker], " + 158 "region cannot be empty", 159 ), 160 }, 161 } 162 for _, tt := range tests { 163 t.Run(tt.name, func(t *testing.T) { 164 ic := NewClassSettings(tt.cfg) 165 if tt.wantErr != nil { 166 assert.EqualError(t, ic.Validate(nil), tt.wantErr.Error()) 167 } else { 168 assert.Equal(t, tt.wantService, ic.Service()) 169 assert.Equal(t, tt.wantRegion, ic.Region()) 170 assert.Equal(t, tt.wantModel, ic.Model()) 171 assert.Equal(t, tt.wantEndpoint, ic.Endpoint()) 172 assert.Equal(t, tt.wantTargetModel, ic.TargetModel()) 173 assert.Equal(t, tt.wantTargetVariant, ic.TargetVariant()) 174 if ic.Temperature() != nil { 175 assert.Equal(t, tt.wantTemperature, *ic.Temperature()) 176 } 177 assert.Equal(t, tt.wantStopSequences, ic.StopSequences()) 178 if ic.TopP() != nil { 179 assert.Equal(t, tt.wantTopP, *ic.TopP()) 180 } 181 } 182 }) 183 } 184 } 185 186 type fakeClassConfig struct { 187 classConfig map[string]interface{} 188 } 189 190 func (f fakeClassConfig) Class() map[string]interface{} { 191 return f.classConfig 192 } 193 194 func (f fakeClassConfig) Tenant() string { 195 return "" 196 } 197 198 func (f fakeClassConfig) ClassByModuleName(moduleName string) map[string]interface{} { 199 return f.classConfig 200 } 201 202 func (f fakeClassConfig) Property(propName string) map[string]interface{} { 203 return nil 204 } 205 206 func (f fakeClassConfig) TargetVector() string { 207 return "" 208 }