github.com/weaviate/weaviate@v1.24.6/modules/generative-cohere/config/class_settings_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package config
    13  
    14  import (
    15  	"testing"
    16  
    17  	"github.com/pkg/errors"
    18  	"github.com/stretchr/testify/assert"
    19  	"github.com/weaviate/weaviate/entities/moduletools"
    20  )
    21  
    22  func Test_classSettings_Validate(t *testing.T) {
    23  	tests := []struct {
    24  		name                  string
    25  		cfg                   moduletools.ClassConfig
    26  		wantModel             string
    27  		wantMaxTokens         int
    28  		wantTemperature       int
    29  		wantK                 int
    30  		wantStopSequences     []string
    31  		wantReturnLikelihoods string
    32  		wantBaseURL           string
    33  		wantErr               error
    34  	}{
    35  		{
    36  			name: "default settings",
    37  			cfg: fakeClassConfig{
    38  				classConfig: map[string]interface{}{},
    39  			},
    40  			wantModel:             "command-nightly",
    41  			wantMaxTokens:         2048,
    42  			wantTemperature:       0,
    43  			wantK:                 0,
    44  			wantStopSequences:     []string{},
    45  			wantReturnLikelihoods: "NONE",
    46  			wantBaseURL:           "https://api.cohere.ai",
    47  			wantErr:               nil,
    48  		},
    49  		{
    50  			name: "everything non default configured",
    51  			cfg: fakeClassConfig{
    52  				classConfig: map[string]interface{}{
    53  					"model":             "command-xlarge",
    54  					"maxTokens":         2048,
    55  					"temperature":       1,
    56  					"k":                 2,
    57  					"stopSequences":     []string{"stop1", "stop2"},
    58  					"returnLikelihoods": "NONE",
    59  				},
    60  			},
    61  			wantModel:             "command-xlarge",
    62  			wantMaxTokens:         2048,
    63  			wantTemperature:       1,
    64  			wantK:                 2,
    65  			wantStopSequences:     []string{"stop1", "stop2"},
    66  			wantReturnLikelihoods: "NONE",
    67  			wantBaseURL:           "https://api.cohere.ai",
    68  			wantErr:               nil,
    69  		},
    70  		{
    71  			name: "wrong model configured",
    72  			cfg: fakeClassConfig{
    73  				classConfig: map[string]interface{}{
    74  					"model": "wrong-model",
    75  				},
    76  			},
    77  			wantErr: errors.Errorf("wrong Cohere model name, available model names are: " +
    78  				"[command-xlarge-beta command-xlarge command-medium command-xlarge-nightly " +
    79  				"command-medium-nightly xlarge medium command command-light command-nightly command-light-nightly base base-light]"),
    80  		},
    81  		{
    82  			name: "default settings with command-light-nightly",
    83  			cfg: fakeClassConfig{
    84  				classConfig: map[string]interface{}{
    85  					"model": "command-light-nightly",
    86  				},
    87  			},
    88  			wantModel:             "command-light-nightly",
    89  			wantMaxTokens:         2048,
    90  			wantTemperature:       0,
    91  			wantK:                 0,
    92  			wantStopSequences:     []string{},
    93  			wantReturnLikelihoods: "NONE",
    94  			wantBaseURL:           "https://api.cohere.ai",
    95  			wantErr:               nil,
    96  		},
    97  		{
    98  			name: "default settings with command-light-nightly and baseURL",
    99  			cfg: fakeClassConfig{
   100  				classConfig: map[string]interface{}{
   101  					"model":   "command-light-nightly",
   102  					"baseURL": "http://custom-url.com",
   103  				},
   104  			},
   105  			wantModel:             "command-light-nightly",
   106  			wantMaxTokens:         2048,
   107  			wantTemperature:       0,
   108  			wantK:                 0,
   109  			wantStopSequences:     []string{},
   110  			wantReturnLikelihoods: "NONE",
   111  			wantBaseURL:           "http://custom-url.com",
   112  			wantErr:               nil,
   113  		},
   114  	}
   115  	for _, tt := range tests {
   116  		t.Run(tt.name, func(t *testing.T) {
   117  			ic := NewClassSettings(tt.cfg)
   118  			if tt.wantErr != nil {
   119  				assert.Equal(t, tt.wantErr.Error(), ic.Validate(nil).Error())
   120  			} else {
   121  				assert.Equal(t, tt.wantModel, ic.Model())
   122  				assert.Equal(t, tt.wantMaxTokens, ic.MaxTokens())
   123  				assert.Equal(t, tt.wantTemperature, ic.Temperature())
   124  				assert.Equal(t, tt.wantK, ic.K())
   125  				assert.Equal(t, tt.wantStopSequences, ic.StopSequences())
   126  				assert.Equal(t, tt.wantReturnLikelihoods, ic.ReturnLikelihoods())
   127  			}
   128  		})
   129  	}
   130  }
   131  
   132  type fakeClassConfig struct {
   133  	classConfig map[string]interface{}
   134  }
   135  
   136  func (f fakeClassConfig) Class() map[string]interface{} {
   137  	return f.classConfig
   138  }
   139  
   140  func (f fakeClassConfig) Tenant() string {
   141  	return ""
   142  }
   143  
   144  func (f fakeClassConfig) ClassByModuleName(moduleName string) map[string]interface{} {
   145  	return f.classConfig
   146  }
   147  
   148  func (f fakeClassConfig) Property(propName string) map[string]interface{} {
   149  	return nil
   150  }
   151  
   152  func (f fakeClassConfig) TargetVector() string {
   153  	return ""
   154  }