github.com/weaviate/weaviate@v1.24.6/modules/qna-openai/config/class_settings_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package config
    13  
    14  import (
    15  	"testing"
    16  
    17  	"github.com/pkg/errors"
    18  	"github.com/stretchr/testify/assert"
    19  	"github.com/weaviate/weaviate/entities/moduletools"
    20  )
    21  
    22  func Test_classSettings_Validate(t *testing.T) {
    23  	tests := []struct {
    24  		name                 string
    25  		cfg                  moduletools.ClassConfig
    26  		wantModel            string
    27  		wantMaxTokens        float64
    28  		wantTemperature      float64
    29  		wantTopP             float64
    30  		wantFrequencyPenalty float64
    31  		wantPresencePenalty  float64
    32  		wantResourceName     string
    33  		wantDeploymentID     string
    34  		wantIsAzure          bool
    35  		wantErr              error
    36  		wantBaseURL          string
    37  	}{
    38  		{
    39  			name: "Happy flow",
    40  			cfg: fakeClassConfig{
    41  				classConfig: map[string]interface{}{},
    42  			},
    43  			wantModel:            "text-ada-001",
    44  			wantMaxTokens:        16,
    45  			wantTemperature:      0.0,
    46  			wantTopP:             1,
    47  			wantFrequencyPenalty: 0.0,
    48  			wantPresencePenalty:  0.0,
    49  			wantErr:              nil,
    50  			wantBaseURL:          "https://api.openai.com",
    51  		},
    52  		{
    53  			name: "Everything non default configured",
    54  			cfg: fakeClassConfig{
    55  				classConfig: map[string]interface{}{
    56  					"model":            "text-babbage-001",
    57  					"maxTokens":        100,
    58  					"temperature":      0.5,
    59  					"topP":             3,
    60  					"frequencyPenalty": 0.1,
    61  					"presencePenalty":  0.9,
    62  					"baseURL":          "https://openai.proxy.dev",
    63  				},
    64  			},
    65  			wantModel:            "text-babbage-001",
    66  			wantMaxTokens:        100,
    67  			wantTemperature:      0.5,
    68  			wantTopP:             3,
    69  			wantFrequencyPenalty: 0.1,
    70  			wantPresencePenalty:  0.9,
    71  			wantBaseURL:          "https://openai.proxy.dev",
    72  			wantErr:              nil,
    73  		},
    74  		{
    75  			name: "Azure OpenAI config",
    76  			cfg: fakeClassConfig{
    77  				classConfig: map[string]interface{}{
    78  					"resourceName": "weaviate",
    79  					"deploymentId": "text-ada-001",
    80  				},
    81  			},
    82  			wantModel:            "text-ada-001",
    83  			wantResourceName:     "weaviate",
    84  			wantDeploymentID:     "text-ada-001",
    85  			wantIsAzure:          true,
    86  			wantMaxTokens:        16,
    87  			wantTemperature:      0.0,
    88  			wantTopP:             1,
    89  			wantFrequencyPenalty: 0.0,
    90  			wantPresencePenalty:  0.0,
    91  			wantErr:              nil,
    92  			wantBaseURL:          "https://api.openai.com",
    93  		},
    94  		{
    95  			name: "Wrong model data type configured",
    96  			cfg: fakeClassConfig{
    97  				classConfig: map[string]interface{}{
    98  					"model": true,
    99  				},
   100  			},
   101  			wantErr: errors.Errorf("wrong OpenAI model name, available model names are: %v", availableOpenAIModels),
   102  		},
   103  		{
   104  			name: "Wrong model data type configured",
   105  			cfg: fakeClassConfig{
   106  				classConfig: map[string]interface{}{
   107  					"model": "this-is-a-non-existing-model",
   108  				},
   109  			},
   110  			wantErr: errors.Errorf("wrong OpenAI model name, available model names are: %v", availableOpenAIModels),
   111  		},
   112  		{
   113  			name: "Wrong maxTokens configured",
   114  			cfg: fakeClassConfig{
   115  				classConfig: map[string]interface{}{
   116  					"maxTokens": true,
   117  				},
   118  			},
   119  			wantErr: errors.Errorf("Wrong maxTokens configuration, values are should have a minimal value of 1 and max is dependant on the model used"),
   120  		},
   121  		{
   122  			name: "Wrong temperature configured",
   123  			cfg: fakeClassConfig{
   124  				classConfig: map[string]interface{}{
   125  					"temperature": true,
   126  				},
   127  			},
   128  			wantErr: errors.Errorf("Wrong temperature configuration, values are between 0.0 and 1.0"),
   129  		},
   130  		{
   131  			name: "Wrong frequencyPenalty configured",
   132  			cfg: fakeClassConfig{
   133  				classConfig: map[string]interface{}{
   134  					"frequencyPenalty": true,
   135  				},
   136  			},
   137  			wantErr: errors.Errorf("Wrong frequencyPenalty configuration, values are between 0.0 and 1.0"),
   138  		},
   139  		{
   140  			name: "Wrong presencePenalty configured",
   141  			cfg: fakeClassConfig{
   142  				classConfig: map[string]interface{}{
   143  					"presencePenalty": true,
   144  				},
   145  			},
   146  			wantErr: errors.Errorf("Wrong presencePenalty configuration, values are between 0.0 and 1.0"),
   147  		},
   148  		{
   149  			name: "Wrong topP configured",
   150  			cfg: fakeClassConfig{
   151  				classConfig: map[string]interface{}{
   152  					"topP": true,
   153  				},
   154  			},
   155  			wantErr: errors.Errorf("Wrong topP configuration, values are should have a minimal value of 1 and max of 5"),
   156  		},
   157  		{
   158  			name: "Wrong Azure OpenAI config - empty deploymentId",
   159  			cfg: fakeClassConfig{
   160  				classConfig: map[string]interface{}{
   161  					"resourceName": "resource-name",
   162  				},
   163  			},
   164  			wantErr: errors.Errorf("both resourceName and deploymentId must be provided"),
   165  		},
   166  		{
   167  			name: "Wrong Azure OpenAI config - empty resourceName",
   168  			cfg: fakeClassConfig{
   169  				classConfig: map[string]interface{}{
   170  					"deploymentId": "ada",
   171  				},
   172  			},
   173  			wantErr: errors.Errorf("both resourceName and deploymentId must be provided"),
   174  		},
   175  	}
   176  	for _, tt := range tests {
   177  		t.Run(tt.name, func(t *testing.T) {
   178  			ic := NewClassSettings(tt.cfg)
   179  			if tt.wantErr != nil {
   180  				assert.EqualError(t, tt.wantErr, ic.Validate(nil).Error())
   181  			} else {
   182  				assert.Equal(t, tt.wantModel, ic.Model())
   183  				assert.Equal(t, tt.wantMaxTokens, ic.MaxTokens())
   184  				assert.Equal(t, tt.wantTemperature, ic.Temperature())
   185  				assert.Equal(t, tt.wantTopP, ic.TopP())
   186  				assert.Equal(t, tt.wantFrequencyPenalty, ic.FrequencyPenalty())
   187  				assert.Equal(t, tt.wantPresencePenalty, ic.PresencePenalty())
   188  				assert.Equal(t, tt.wantResourceName, ic.ResourceName())
   189  				assert.Equal(t, tt.wantDeploymentID, ic.DeploymentID())
   190  				assert.Equal(t, tt.wantIsAzure, ic.IsAzure())
   191  				assert.Equal(t, tt.wantBaseURL, ic.BaseURL())
   192  			}
   193  		})
   194  	}
   195  }
   196  
   197  type fakeClassConfig struct {
   198  	classConfig map[string]interface{}
   199  }
   200  
   201  func (f fakeClassConfig) Class() map[string]interface{} {
   202  	return f.classConfig
   203  }
   204  
   205  func (f fakeClassConfig) Tenant() string {
   206  	return ""
   207  }
   208  
   209  func (f fakeClassConfig) ClassByModuleName(moduleName string) map[string]interface{} {
   210  	return f.classConfig
   211  }
   212  
   213  func (f fakeClassConfig) Property(propName string) map[string]interface{} {
   214  	return nil
   215  }
   216  
   217  func (f fakeClassConfig) TargetVector() string {
   218  	return ""
   219  }