github.com/weaviate/weaviate@v1.24.6/modules/generative-cohere/config/class_settings_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package config 13 14 import ( 15 "testing" 16 17 "github.com/pkg/errors" 18 "github.com/stretchr/testify/assert" 19 "github.com/weaviate/weaviate/entities/moduletools" 20 ) 21 22 func Test_classSettings_Validate(t *testing.T) { 23 tests := []struct { 24 name string 25 cfg moduletools.ClassConfig 26 wantModel string 27 wantMaxTokens int 28 wantTemperature int 29 wantK int 30 wantStopSequences []string 31 wantReturnLikelihoods string 32 wantBaseURL string 33 wantErr error 34 }{ 35 { 36 name: "default settings", 37 cfg: fakeClassConfig{ 38 classConfig: map[string]interface{}{}, 39 }, 40 wantModel: "command-nightly", 41 wantMaxTokens: 2048, 42 wantTemperature: 0, 43 wantK: 0, 44 wantStopSequences: []string{}, 45 wantReturnLikelihoods: "NONE", 46 wantBaseURL: "https://api.cohere.ai", 47 wantErr: nil, 48 }, 49 { 50 name: "everything non default configured", 51 cfg: fakeClassConfig{ 52 classConfig: map[string]interface{}{ 53 "model": "command-xlarge", 54 "maxTokens": 2048, 55 "temperature": 1, 56 "k": 2, 57 "stopSequences": []string{"stop1", "stop2"}, 58 "returnLikelihoods": "NONE", 59 }, 60 }, 61 wantModel: "command-xlarge", 62 wantMaxTokens: 2048, 63 wantTemperature: 1, 64 wantK: 2, 65 wantStopSequences: []string{"stop1", "stop2"}, 66 wantReturnLikelihoods: "NONE", 67 wantBaseURL: "https://api.cohere.ai", 68 wantErr: nil, 69 }, 70 { 71 name: "wrong model configured", 72 cfg: fakeClassConfig{ 73 classConfig: map[string]interface{}{ 74 "model": "wrong-model", 75 }, 76 }, 77 wantErr: errors.Errorf("wrong Cohere model name, available model names are: " + 78 "[command-xlarge-beta command-xlarge command-medium command-xlarge-nightly " + 79 "command-medium-nightly xlarge medium command command-light command-nightly command-light-nightly base base-light]"), 80 }, 81 { 82 name: "default settings with command-light-nightly", 83 cfg: fakeClassConfig{ 84 classConfig: map[string]interface{}{ 85 "model": "command-light-nightly", 86 }, 87 }, 88 wantModel: "command-light-nightly", 89 wantMaxTokens: 2048, 90 wantTemperature: 0, 91 wantK: 0, 92 wantStopSequences: []string{}, 93 wantReturnLikelihoods: "NONE", 94 wantBaseURL: "https://api.cohere.ai", 95 wantErr: nil, 96 }, 97 { 98 name: "default settings with command-light-nightly and baseURL", 99 cfg: fakeClassConfig{ 100 classConfig: map[string]interface{}{ 101 "model": "command-light-nightly", 102 "baseURL": "http://custom-url.com", 103 }, 104 }, 105 wantModel: "command-light-nightly", 106 wantMaxTokens: 2048, 107 wantTemperature: 0, 108 wantK: 0, 109 wantStopSequences: []string{}, 110 wantReturnLikelihoods: "NONE", 111 wantBaseURL: "http://custom-url.com", 112 wantErr: nil, 113 }, 114 } 115 for _, tt := range tests { 116 t.Run(tt.name, func(t *testing.T) { 117 ic := NewClassSettings(tt.cfg) 118 if tt.wantErr != nil { 119 assert.Equal(t, tt.wantErr.Error(), ic.Validate(nil).Error()) 120 } else { 121 assert.Equal(t, tt.wantModel, ic.Model()) 122 assert.Equal(t, tt.wantMaxTokens, ic.MaxTokens()) 123 assert.Equal(t, tt.wantTemperature, ic.Temperature()) 124 assert.Equal(t, tt.wantK, ic.K()) 125 assert.Equal(t, tt.wantStopSequences, ic.StopSequences()) 126 assert.Equal(t, tt.wantReturnLikelihoods, ic.ReturnLikelihoods()) 127 } 128 }) 129 } 130 } 131 132 type fakeClassConfig struct { 133 classConfig map[string]interface{} 134 } 135 136 func (f fakeClassConfig) Class() map[string]interface{} { 137 return f.classConfig 138 } 139 140 func (f fakeClassConfig) Tenant() string { 141 return "" 142 } 143 144 func (f fakeClassConfig) ClassByModuleName(moduleName string) map[string]interface{} { 145 return f.classConfig 146 } 147 148 func (f fakeClassConfig) Property(propName string) map[string]interface{} { 149 return nil 150 } 151 152 func (f fakeClassConfig) TargetVector() string { 153 return "" 154 }