github.com/weaviate/weaviate@v1.24.6/modules/generative-mistral/config/class_settings_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package config 13 14 import ( 15 "testing" 16 17 "github.com/pkg/errors" 18 "github.com/stretchr/testify/assert" 19 "github.com/weaviate/weaviate/entities/moduletools" 20 ) 21 22 func Test_classSettings_Validate(t *testing.T) { 23 tests := []struct { 24 name string 25 cfg moduletools.ClassConfig 26 wantModel string 27 wantMaxTokens int 28 wantTemperature int 29 wantBaseURL string 30 wantErr error 31 }{ 32 { 33 name: "default settings", 34 cfg: fakeClassConfig{ 35 classConfig: map[string]interface{}{}, 36 }, 37 wantModel: "open-mistral-7b", 38 wantMaxTokens: 2048, 39 wantTemperature: 0, 40 wantBaseURL: "https://api.mistral.ai", 41 wantErr: nil, 42 }, 43 { 44 name: "everything non default configured", 45 cfg: fakeClassConfig{ 46 classConfig: map[string]interface{}{ 47 "model": "mistral-medium", 48 "maxTokens": 50, 49 "temperature": 1, 50 }, 51 }, 52 wantModel: "mistral-medium", 53 wantMaxTokens: 50, 54 wantTemperature: 1, 55 wantBaseURL: "https://api.mistral.ai", 56 wantErr: nil, 57 }, 58 { 59 name: "wrong model configured", 60 cfg: fakeClassConfig{ 61 classConfig: map[string]interface{}{ 62 "model": "wrong-model", 63 }, 64 }, 65 wantErr: errors.Errorf("wrong Mistral model name, available model names are: " + 66 "[open-mistral-7b mistral-tiny-2312 mistral-tiny open-mixtral-8x7b mistral-small-2312 mistral-small mistral-small-2402 mistral-small-latest mistral-medium-latest mistral-medium-2312 mistral-medium mistral-large-latest mistral-large-2402]"), 67 }, 68 { 69 name: "default settings with command-light-nightly", 70 cfg: fakeClassConfig{ 71 classConfig: map[string]interface{}{ 72 "model": "command-light-nightly", 73 }, 74 }, 75 wantModel: "command-light-nightly", 76 wantMaxTokens: 2048, 77 wantTemperature: 0, 78 wantBaseURL: "https://api.mistral.ai", 79 wantErr: nil, 80 }, 81 { 82 name: "default settings with mistral-medium and baseURL", 83 cfg: fakeClassConfig{ 84 classConfig: map[string]interface{}{ 85 "model": "mistral-medium", 86 "baseURL": "http://custom-url.com", 87 }, 88 }, 89 wantModel: "mistral-medium", 90 wantMaxTokens: 2048, 91 wantTemperature: 0, 92 wantBaseURL: "http://custom-url.com", 93 wantErr: nil, 94 }, 95 } 96 for _, tt := range tests { 97 t.Run(tt.name, func(t *testing.T) { 98 ic := NewClassSettings(tt.cfg) 99 if tt.wantErr != nil { 100 assert.Equal(t, tt.wantErr.Error(), ic.Validate(nil).Error()) 101 } else { 102 assert.Equal(t, tt.wantModel, ic.Model()) 103 assert.Equal(t, tt.wantMaxTokens, ic.MaxTokens()) 104 assert.Equal(t, tt.wantTemperature, ic.Temperature()) 105 } 106 }) 107 } 108 } 109 110 type fakeClassConfig struct { 111 classConfig map[string]interface{} 112 } 113 114 func (f fakeClassConfig) Class() map[string]interface{} { 115 return f.classConfig 116 } 117 118 func (f fakeClassConfig) Tenant() string { 119 return "" 120 } 121 122 func (f fakeClassConfig) ClassByModuleName(moduleName string) map[string]interface{} { 123 return f.classConfig 124 } 125 126 func (f fakeClassConfig) Property(propName string) map[string]interface{} { 127 return nil 128 } 129 130 func (f fakeClassConfig) TargetVector() string { 131 return "" 132 }