github.com/weaviate/weaviate@v1.24.6/modules/generative-palm/config/class_settings_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package config
    13  
    14  import (
    15  	"fmt"
    16  	"testing"
    17  
    18  	"github.com/pkg/errors"
    19  	"github.com/stretchr/testify/assert"
    20  	"github.com/weaviate/weaviate/entities/moduletools"
    21  )
    22  
    23  func Test_classSettings_Validate(t *testing.T) {
    24  	tests := []struct {
    25  		name            string
    26  		cfg             moduletools.ClassConfig
    27  		wantApiEndpoint string
    28  		wantProjectID   string
    29  		wantModelID     string
    30  		wantTemperature float64
    31  		wantTokenLimit  int
    32  		wantTopK        int
    33  		wantTopP        float64
    34  		wantErr         error
    35  	}{
    36  		{
    37  			name: "happy flow",
    38  			cfg: fakeClassConfig{
    39  				classConfig: map[string]interface{}{
    40  					"projectId": "projectId",
    41  				},
    42  			},
    43  			wantApiEndpoint: "us-central1-aiplatform.googleapis.com",
    44  			wantProjectID:   "projectId",
    45  			wantModelID:     "chat-bison",
    46  			wantTemperature: 0.2,
    47  			wantTokenLimit:  256,
    48  			wantTopK:        40,
    49  			wantTopP:        0.95,
    50  			wantErr:         nil,
    51  		},
    52  		{
    53  			name: "custom values",
    54  			cfg: fakeClassConfig{
    55  				classConfig: map[string]interface{}{
    56  					"apiEndpoint": "google.com",
    57  					"projectId":   "cloud-project",
    58  					"modelId":     "model-id",
    59  					"temperature": 0.25,
    60  					"tokenLimit":  254,
    61  					"topK":        30,
    62  					"topP":        0.97,
    63  				},
    64  			},
    65  			wantApiEndpoint: "google.com",
    66  			wantProjectID:   "cloud-project",
    67  			wantModelID:     "model-id",
    68  			wantTemperature: 0.25,
    69  			wantTokenLimit:  254,
    70  			wantTopK:        30,
    71  			wantTopP:        0.97,
    72  			wantErr:         nil,
    73  		},
    74  		{
    75  			name: "wrong temperature",
    76  			cfg: fakeClassConfig{
    77  				classConfig: map[string]interface{}{
    78  					"projectId":   "cloud-project",
    79  					"temperature": 2,
    80  				},
    81  			},
    82  			wantErr: errors.Errorf("temperature has to be float value between 0 and 1"),
    83  		},
    84  		{
    85  			name: "wrong tokenLimit",
    86  			cfg: fakeClassConfig{
    87  				classConfig: map[string]interface{}{
    88  					"projectId":  "cloud-project",
    89  					"tokenLimit": 2000,
    90  				},
    91  			},
    92  			wantErr: errors.Errorf("tokenLimit has to be an integer value between 1 and 1024"),
    93  		},
    94  		{
    95  			name: "wrong topK",
    96  			cfg: fakeClassConfig{
    97  				classConfig: map[string]interface{}{
    98  					"projectId": "cloud-project",
    99  					"topK":      2000,
   100  				},
   101  			},
   102  			wantErr: errors.Errorf("topK has to be an integer value between 1 and 40"),
   103  		},
   104  		{
   105  			name: "wrong topP",
   106  			cfg: fakeClassConfig{
   107  				classConfig: map[string]interface{}{
   108  					"projectId": "cloud-project",
   109  					"topP":      3,
   110  				},
   111  			},
   112  			wantErr: errors.Errorf("topP has to be float value between 0 and 1"),
   113  		},
   114  		{
   115  			name: "wrong all",
   116  			cfg: fakeClassConfig{
   117  				classConfig: map[string]interface{}{
   118  					"projectId":   "",
   119  					"temperature": 2,
   120  					"tokenLimit":  2000,
   121  					"topK":        2000,
   122  					"topP":        3,
   123  				},
   124  			},
   125  			wantErr: errors.Errorf("projectId cannot be empty, " +
   126  				"temperature has to be float value between 0 and 1, " +
   127  				"tokenLimit has to be an integer value between 1 and 1024, " +
   128  				"topK has to be an integer value between 1 and 40, " +
   129  				"topP has to be float value between 0 and 1"),
   130  		},
   131  		{
   132  			name: "Generative AI",
   133  			cfg: fakeClassConfig{
   134  				classConfig: map[string]interface{}{
   135  					"apiEndpoint": "generativelanguage.googleapis.com",
   136  				},
   137  			},
   138  			wantApiEndpoint: "generativelanguage.googleapis.com",
   139  			wantProjectID:   "",
   140  			wantModelID:     "chat-bison-001",
   141  			wantTemperature: 0.2,
   142  			wantTokenLimit:  256,
   143  			wantTopK:        40,
   144  			wantTopP:        0.95,
   145  			wantErr:         nil,
   146  		},
   147  		{
   148  			name: "Generative AI with model",
   149  			cfg: fakeClassConfig{
   150  				classConfig: map[string]interface{}{
   151  					"apiEndpoint": "generativelanguage.googleapis.com",
   152  					"modelId":     "chat-bison-001",
   153  				},
   154  			},
   155  			wantApiEndpoint: "generativelanguage.googleapis.com",
   156  			wantProjectID:   "",
   157  			wantModelID:     "chat-bison-001",
   158  			wantTemperature: 0.2,
   159  			wantTokenLimit:  256,
   160  			wantTopK:        40,
   161  			wantTopP:        0.95,
   162  			wantErr:         nil,
   163  		},
   164  		{
   165  			name: "Generative AI with gemini-ultra model",
   166  			cfg: fakeClassConfig{
   167  				classConfig: map[string]interface{}{
   168  					"apiEndpoint": "generativelanguage.googleapis.com",
   169  					"modelId":     "gemini-ultra",
   170  				},
   171  			},
   172  			wantApiEndpoint: "generativelanguage.googleapis.com",
   173  			wantProjectID:   "",
   174  			wantModelID:     "gemini-ultra",
   175  			wantTemperature: 0.2,
   176  			wantTokenLimit:  256,
   177  			wantTopK:        40,
   178  			wantTopP:        0.95,
   179  			wantErr:         nil,
   180  		},
   181  		{
   182  			name: "Generative AI with not supported model",
   183  			cfg: fakeClassConfig{
   184  				classConfig: map[string]interface{}{
   185  					"apiEndpoint": "generativelanguage.googleapis.com",
   186  					"modelId":     "unsupported-model",
   187  				},
   188  			},
   189  			wantErr: fmt.Errorf("unsupported-model is not supported available models are: [chat-bison-001 gemini-pro gemini-pro-vision gemini-ultra]"),
   190  		},
   191  	}
   192  	for _, tt := range tests {
   193  		t.Run(tt.name, func(t *testing.T) {
   194  			ic := NewClassSettings(tt.cfg)
   195  			if tt.wantErr != nil {
   196  				assert.EqualError(t, ic.Validate(nil), tt.wantErr.Error())
   197  			} else {
   198  				assert.Equal(t, tt.wantApiEndpoint, ic.ApiEndpoint())
   199  				assert.Equal(t, tt.wantProjectID, ic.ProjectID())
   200  				assert.Equal(t, tt.wantModelID, ic.ModelID())
   201  				assert.Equal(t, tt.wantTemperature, ic.Temperature())
   202  				assert.Equal(t, tt.wantTokenLimit, ic.TokenLimit())
   203  				assert.Equal(t, tt.wantTopK, ic.TopK())
   204  				assert.Equal(t, tt.wantTopP, ic.TopP())
   205  			}
   206  		})
   207  	}
   208  }
   209  
   210  type fakeClassConfig struct {
   211  	classConfig map[string]interface{}
   212  }
   213  
   214  func (f fakeClassConfig) Class() map[string]interface{} {
   215  	return f.classConfig
   216  }
   217  
   218  func (f fakeClassConfig) Tenant() string {
   219  	return ""
   220  }
   221  
   222  func (f fakeClassConfig) ClassByModuleName(moduleName string) map[string]interface{} {
   223  	return f.classConfig
   224  }
   225  
   226  func (f fakeClassConfig) Property(propName string) map[string]interface{} {
   227  	return nil
   228  }
   229  
   230  func (f fakeClassConfig) TargetVector() string {
   231  	return ""
   232  }