github.com/weaviate/weaviate@v1.24.6/modules/generative-openai/config/class_settings_test.go

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package config
    13  
    14  import (
    15  	"testing"
    16  
    17  	"github.com/pkg/errors"
    18  	"github.com/stretchr/testify/assert"
    19  	"github.com/weaviate/weaviate/entities/moduletools"
    20  )
    21  
    22  func Test_classSettings_Validate(t *testing.T) {
    23  	tests := []struct {
    24  		name                 string
    25  		cfg                  moduletools.ClassConfig
    26  		wantModel            string
    27  		wantMaxTokens        float64
    28  		wantTemperature      float64
    29  		wantTopP             float64
    30  		wantFrequencyPenalty float64
    31  		wantPresencePenalty  float64
    32  		wantResourceName     string
    33  		wantDeploymentID     string
    34  		wantIsAzure          bool
    35  		wantErr              error
    36  		wantBaseURL          string
    37  		wantApiVersion       string
    38  	}{
    39  		{
    40  			name: "Happy flow",
    41  			cfg: fakeClassConfig{
    42  				classConfig: map[string]interface{}{},
    43  			},
    44  			wantModel:            "gpt-3.5-turbo",
    45  			wantMaxTokens:        4097,
    46  			wantTemperature:      0.0,
    47  			wantTopP:             1,
    48  			wantFrequencyPenalty: 0.0,
    49  			wantPresencePenalty:  0.0,
    50  			wantErr:              nil,
    51  			wantBaseURL:          "https://api.openai.com",
    52  			wantApiVersion:       "2023-05-15",
    53  		},
    54  		{
    55  			name: "Everything non default configured",
    56  			cfg: fakeClassConfig{
    57  				classConfig: map[string]interface{}{
    58  					"model":            "gpt-3.5-turbo",
    59  					"maxTokens":        4097,
    60  					"temperature":      0.5,
    61  					"topP":             3,
    62  					"frequencyPenalty": 0.1,
    63  					"presencePenalty":  0.9,
    64  				},
    65  			},
    66  			wantModel:            "gpt-3.5-turbo",
    67  			wantMaxTokens:        4097,
    68  			wantTemperature:      0.5,
    69  			wantTopP:             3,
    70  			wantFrequencyPenalty: 0.1,
    71  			wantPresencePenalty:  0.9,
    72  			wantErr:              nil,
    73  			wantBaseURL:          "https://api.openai.com",
    74  			wantApiVersion:       "2023-05-15",
    75  		},
    76  		{
    77  			name: "OpenAI Proxy",
    78  			cfg: fakeClassConfig{
    79  				classConfig: map[string]interface{}{
    80  					"model":            "gpt-3.5-turbo",
    81  					"maxTokens":        4097,
    82  					"temperature":      0.5,
    83  					"topP":             3,
    84  					"frequencyPenalty": 0.1,
    85  					"presencePenalty":  0.9,
    86  					"baseURL":          "https://proxy.weaviate.dev/",
    87  				},
    88  			},
    89  			wantBaseURL:          "https://proxy.weaviate.dev/",
    90  			wantApiVersion:       "2023-05-15",
    91  			wantModel:            "gpt-3.5-turbo",
    92  			wantMaxTokens:        4097,
    93  			wantTemperature:      0.5,
    94  			wantTopP:             3,
    95  			wantFrequencyPenalty: 0.1,
    96  			wantPresencePenalty:  0.9,
    97  			wantErr:              nil,
    98  		},
    99  		{
   100  			name: "Legacy config",
   101  			cfg: fakeClassConfig{
   102  				classConfig: map[string]interface{}{
   103  					"model":            "text-davinci-003",
   104  					"maxTokens":        1200,
   105  					"temperature":      0.5,
   106  					"topP":             3,
   107  					"frequencyPenalty": 0.1,
   108  					"presencePenalty":  0.9,
   109  				},
   110  			},
   111  			wantModel:            "text-davinci-003",
   112  			wantMaxTokens:        1200,
   113  			wantTemperature:      0.5,
   114  			wantTopP:             3,
   115  			wantFrequencyPenalty: 0.1,
   116  			wantPresencePenalty:  0.9,
   117  			wantErr:              nil,
   118  			wantBaseURL:          "https://api.openai.com",
   119  			wantApiVersion:       "2023-05-15",
   120  		},
   121  		{
   122  			name: "Azure OpenAI config",
   123  			cfg: fakeClassConfig{
   124  				classConfig: map[string]interface{}{
   125  					"resourceName":     "weaviate",
   126  					"deploymentId":     "gpt-3.5-turbo",
   127  					"maxTokens":        4097,
   128  					"temperature":      0.5,
   129  					"topP":             3,
   130  					"frequencyPenalty": 0.1,
   131  					"presencePenalty":  0.9,
   132  				},
   133  			},
   134  			wantResourceName:     "weaviate",
   135  			wantDeploymentID:     "gpt-3.5-turbo",
   136  			wantIsAzure:          true,
   137  			wantModel:            "gpt-3.5-turbo",
   138  			wantMaxTokens:        4097,
   139  			wantTemperature:      0.5,
   140  			wantTopP:             3,
   141  			wantFrequencyPenalty: 0.1,
   142  			wantPresencePenalty:  0.9,
   143  			wantErr:              nil,
   144  			wantBaseURL:          "https://api.openai.com",
   145  			wantApiVersion:       "2023-05-15",
   146  		},
   147  		{
   148  			name: "Azure OpenAI config with baseURL",
   149  			cfg: fakeClassConfig{
   150  				classConfig: map[string]interface{}{
   151  					"baseURL":          "some-base-url",
   152  					"resourceName":     "weaviate",
   153  					"deploymentId":     "gpt-3.5-turbo",
   154  					"maxTokens":        4097,
   155  					"temperature":      0.5,
   156  					"topP":             3,
   157  					"frequencyPenalty": 0.1,
   158  					"presencePenalty":  0.9,
   159  				},
   160  			},
   161  			wantResourceName:     "weaviate",
   162  			wantDeploymentID:     "gpt-3.5-turbo",
   163  			wantIsAzure:          true,
   164  			wantModel:            "gpt-3.5-turbo",
   165  			wantMaxTokens:        4097,
   166  			wantTemperature:      0.5,
   167  			wantTopP:             3,
   168  			wantFrequencyPenalty: 0.1,
   169  			wantPresencePenalty:  0.9,
   170  			wantErr:              nil,
   171  			wantBaseURL:          "some-base-url",
   172  			wantApiVersion:       "2023-05-15",
   173  		},
   174  		{
   175  			name: "With gpt-3.5-turbo-16k model",
   176  			cfg: fakeClassConfig{
   177  				classConfig: map[string]interface{}{
   178  					"model":            "gpt-3.5-turbo-16k",
   179  					"maxTokens":        4097,
   180  					"temperature":      0.5,
   181  					"topP":             3,
   182  					"frequencyPenalty": 0.1,
   183  					"presencePenalty":  0.9,
   184  				},
   185  			},
   186  			wantModel:            "gpt-3.5-turbo-16k",
   187  			wantMaxTokens:        4097,
   188  			wantTemperature:      0.5,
   189  			wantTopP:             3,
   190  			wantFrequencyPenalty: 0.1,
   191  			wantPresencePenalty:  0.9,
   192  			wantErr:              nil,
   193  			wantBaseURL:          "https://api.openai.com",
   194  			wantApiVersion:       "2023-05-15",
   195  		},
   196  		{
   197  			name: "Wrong maxTokens configured",
   198  			cfg: fakeClassConfig{
   199  				classConfig: map[string]interface{}{
   200  					"maxTokens": true,
   201  				},
   202  			},
   203  			wantErr: errors.Errorf("Wrong maxTokens configuration, values are should have a minimal value of 1 and max is dependant on the model used"),
   204  		},
   205  		{
   206  			name: "Wrong temperature configured",
   207  			cfg: fakeClassConfig{
   208  				classConfig: map[string]interface{}{
   209  					"temperature": true,
   210  				},
   211  			},
   212  			wantErr: errors.Errorf("Wrong temperature configuration, values are between 0.0 and 1.0"),
   213  		},
   214  		{
   215  			name: "Wrong frequencyPenalty configured",
   216  			cfg: fakeClassConfig{
   217  				classConfig: map[string]interface{}{
   218  					"frequencyPenalty": true,
   219  				},
   220  			},
   221  			wantErr: errors.Errorf("Wrong frequencyPenalty configuration, values are between 0.0 and 1.0"),
   222  		},
   223  		{
   224  			name: "Wrong presencePenalty configured",
   225  			cfg: fakeClassConfig{
   226  				classConfig: map[string]interface{}{
   227  					"presencePenalty": true,
   228  				},
   229  			},
   230  			wantErr: errors.Errorf("Wrong presencePenalty configuration, values are between 0.0 and 1.0"),
   231  		},
   232  		{
   233  			name: "Wrong topP configured",
   234  			cfg: fakeClassConfig{
   235  				classConfig: map[string]interface{}{
   236  					"topP": true,
   237  				},
   238  			},
   239  			wantErr: errors.Errorf("Wrong topP configuration, values are should have a minimal value of 1 and max of 5"),
   240  		},
   241  		{
   242  			name: "Wrong Azure config - empty deploymentId",
   243  			cfg: fakeClassConfig{
   244  				classConfig: map[string]interface{}{
   245  					"resourceName": "resource-name",
   246  				},
   247  			},
   248  			wantErr: errors.Errorf("both resourceName and deploymentId must be provided"),
   249  		},
   250  		{
   251  			name: "Wrong Azure config - empty resourceName",
   252  			cfg: fakeClassConfig{
   253  				classConfig: map[string]interface{}{
   254  					"deploymentId": "deployment-name",
   255  				},
   256  			},
   257  			wantErr: errors.Errorf("both resourceName and deploymentId must be provided"),
   258  		},
   259  		{
   260  			name: "Wrong Azure config - wrong api version",
   261  			cfg: fakeClassConfig{
   262  				classConfig: map[string]interface{}{
   263  					"apiVersion": "wrong-api-version",
   264  				},
   265  			},
   266  			wantErr: errors.Errorf("wrong Azure OpenAI apiVersion, available api versions are: " +
   267  				"[2022-12-01 2023-03-15-preview 2023-05-15 2023-06-01-preview 2023-07-01-preview " +
   268  				"2023-08-01-preview 2023-09-01-preview 2023-12-01-preview]"),
   269  		},
   270  	}
   271  	for _, tt := range tests {
   272  		t.Run(tt.name, func(t *testing.T) {
   273  			ic := NewClassSettings(tt.cfg)
   274  			if tt.wantErr != nil {
   275  				assert.EqualError(t, tt.wantErr, ic.Validate(nil).Error())
   276  			} else {
   277  				assert.Equal(t, tt.wantModel, ic.Model())
   278  				assert.Equal(t, tt.wantMaxTokens, ic.MaxTokens())
   279  				assert.Equal(t, tt.wantTemperature, ic.Temperature())
   280  				assert.Equal(t, tt.wantTopP, ic.TopP())
   281  				assert.Equal(t, tt.wantFrequencyPenalty, ic.FrequencyPenalty())
   282  				assert.Equal(t, tt.wantPresencePenalty, ic.PresencePenalty())
   283  				assert.Equal(t, tt.wantResourceName, ic.ResourceName())
   284  				assert.Equal(t, tt.wantDeploymentID, ic.DeploymentID())
   285  				assert.Equal(t, tt.wantIsAzure, ic.IsAzure())
   286  				assert.Equal(t, tt.wantBaseURL, ic.BaseURL())
   287  				assert.Equal(t, tt.wantApiVersion, ic.ApiVersion())
   288  			}
   289  		})
   290  	}
   291  }
   292  
   293  type fakeClassConfig struct {
   294  	classConfig map[string]interface{}
   295  }
   296  
   297  func (f fakeClassConfig) Class() map[string]interface{} {
   298  	return f.classConfig
   299  }
   300  
   301  func (f fakeClassConfig) Tenant() string {
   302  	return ""
   303  }
   304  
   305  func (f fakeClassConfig) ClassByModuleName(moduleName string) map[string]interface{} {
   306  	return f.classConfig
   307  }
   308  
   309  func (f fakeClassConfig) Property(propName string) map[string]interface{} {
   310  	return nil
   311  }
   312  
   313  func (f fakeClassConfig) TargetVector() string {
   314  	return ""
   315  }