github.com/weaviate/weaviate@v1.24.6/modules/generative-aws/config/class_settings_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package config
    13  
    14  import (
    15  	"testing"
    16  
    17  	"github.com/pkg/errors"
    18  	"github.com/stretchr/testify/assert"
    19  	"github.com/weaviate/weaviate/entities/moduletools"
    20  )
    21  
    22  func Test_classSettings_Validate(t *testing.T) {
    23  	t.Skip("Skipping this test for now")
    24  	tests := []struct {
    25  		name              string
    26  		cfg               moduletools.ClassConfig
    27  		wantService       string
    28  		wantRegion        string
    29  		wantModel         string
    30  		wantEndpoint      string
    31  		wantTargetModel   string
    32  		wantTargetVariant string
    33  		wantMaxTokenCount int
    34  		wantStopSequences []string
    35  		wantTemperature   float64
    36  		wantTopP          int
    37  		wantErr           error
    38  	}{
    39  		{
    40  			name: "happy flow - Bedrock",
    41  			cfg: fakeClassConfig{
    42  				classConfig: map[string]interface{}{
    43  					"service": "bedrock",
    44  					"region":  "us-east-1",
    45  					"model":   "amazon.titan-tg1-large",
    46  				},
    47  			},
    48  			wantService:       "bedrock",
    49  			wantRegion:        "us-east-1",
    50  			wantModel:         "amazon.titan-tg1-large",
    51  			wantMaxTokenCount: 8192,
    52  			wantStopSequences: []string{},
    53  			wantTemperature:   0,
    54  			wantTopP:          1,
    55  		},
    56  		{
    57  			name: "happy flow - Sagemaker",
    58  			cfg: fakeClassConfig{
    59  				classConfig: map[string]interface{}{
    60  					"service":       "sagemaker",
    61  					"region":        "us-east-1",
    62  					"endpoint":      "my-endpoint-deployment",
    63  					"targetModel":   "model",
    64  					"targetVariant": "variant-1",
    65  				},
    66  			},
    67  			wantService:       "sagemaker",
    68  			wantRegion:        "us-east-1",
    69  			wantEndpoint:      "my-endpoint-deployment",
    70  			wantTargetModel:   "model",
    71  			wantTargetVariant: "variant-1",
    72  		},
    73  		{
    74  			name: "custom values - Bedrock",
    75  			cfg: fakeClassConfig{
    76  				classConfig: map[string]interface{}{
    77  					"service":       "bedrock",
    78  					"region":        "us-east-1",
    79  					"model":         "amazon.titan-tg1-large",
    80  					"maxTokenCount": 1,
    81  					"stopSequences": []string{"test", "test2"},
    82  					"temperature":   0.2,
    83  					"topP":          0,
    84  				},
    85  			},
    86  			wantService:       "bedrock",
    87  			wantRegion:        "us-east-1",
    88  			wantModel:         "amazon.titan-tg1-large",
    89  			wantMaxTokenCount: 1,
    90  			wantStopSequences: []string{"test", "test2"},
    91  			wantTemperature:   0.2,
    92  			wantTopP:          0,
    93  		},
    94  		{
    95  			name: "custom values - Sagemaker",
    96  			cfg: fakeClassConfig{
    97  				classConfig: map[string]interface{}{
    98  					"service":       "sagemaker",
    99  					"region":        "us-east-1",
   100  					"endpoint":      "this-is-my-endpoint",
   101  					"targetModel":   "my-target-model",
   102  					"targetVariant": "my-target¬variant",
   103  				},
   104  			},
   105  			wantService:       "sagemaker",
   106  			wantRegion:        "us-east-1",
   107  			wantEndpoint:      "this-is-my-endpoint",
   108  			wantTargetModel:   "my-target-model",
   109  			wantTargetVariant: "my-target¬variant",
   110  		},
   111  		{
   112  			name: "wrong temperature",
   113  			cfg: fakeClassConfig{
   114  				classConfig: map[string]interface{}{
   115  					"service":     "bedrock",
   116  					"region":      "us-east-1",
   117  					"model":       "amazon.titan-tg1-large",
   118  					"temperature": 2,
   119  				},
   120  			},
   121  			wantErr: errors.Errorf("temperature has to be float value between 0 and 1"),
   122  		},
   123  		{
   124  			name: "wrong maxTokenCount",
   125  			cfg: fakeClassConfig{
   126  				classConfig: map[string]interface{}{
   127  					"service":       "bedrock",
   128  					"region":        "us-east-1",
   129  					"model":         "amazon.titan-tg1-large",
   130  					"maxTokenCount": 9000,
   131  				},
   132  			},
   133  			wantErr: errors.Errorf("maxTokenCount has to be an integer value between 1 and 8096"),
   134  		},
   135  		{
   136  			name: "wrong topP",
   137  			cfg: fakeClassConfig{
   138  				classConfig: map[string]interface{}{
   139  					"service": "bedrock",
   140  					"region":  "us-east-1",
   141  					"model":   "amazon.titan-tg1-large",
   142  					"topP":    2000,
   143  				},
   144  			},
   145  			wantErr: errors.Errorf("topP has to be an integer value between 0 and 1"),
   146  		},
   147  		{
   148  			name: "wrong all",
   149  			cfg: fakeClassConfig{
   150  				classConfig: map[string]interface{}{
   151  					"maxTokenCount": 9000,
   152  					"temperature":   2,
   153  					"topP":          3,
   154  				},
   155  			},
   156  			wantErr: errors.Errorf("wrong service, " +
   157  				"available services are: [bedrock sagemaker], " +
   158  				"region cannot be empty",
   159  			),
   160  		},
   161  	}
   162  	for _, tt := range tests {
   163  		t.Run(tt.name, func(t *testing.T) {
   164  			ic := NewClassSettings(tt.cfg)
   165  			if tt.wantErr != nil {
   166  				assert.EqualError(t, ic.Validate(nil), tt.wantErr.Error())
   167  			} else {
   168  				assert.Equal(t, tt.wantService, ic.Service())
   169  				assert.Equal(t, tt.wantRegion, ic.Region())
   170  				assert.Equal(t, tt.wantModel, ic.Model())
   171  				assert.Equal(t, tt.wantEndpoint, ic.Endpoint())
   172  				assert.Equal(t, tt.wantTargetModel, ic.TargetModel())
   173  				assert.Equal(t, tt.wantTargetVariant, ic.TargetVariant())
   174  				if ic.Temperature() != nil {
   175  					assert.Equal(t, tt.wantTemperature, *ic.Temperature())
   176  				}
   177  				assert.Equal(t, tt.wantStopSequences, ic.StopSequences())
   178  				if ic.TopP() != nil {
   179  					assert.Equal(t, tt.wantTopP, *ic.TopP())
   180  				}
   181  			}
   182  		})
   183  	}
   184  }
   185  
   186  type fakeClassConfig struct {
   187  	classConfig map[string]interface{}
   188  }
   189  
   190  func (f fakeClassConfig) Class() map[string]interface{} {
   191  	return f.classConfig
   192  }
   193  
   194  func (f fakeClassConfig) Tenant() string {
   195  	return ""
   196  }
   197  
   198  func (f fakeClassConfig) ClassByModuleName(moduleName string) map[string]interface{} {
   199  	return f.classConfig
   200  }
   201  
   202  func (f fakeClassConfig) Property(propName string) map[string]interface{} {
   203  	return nil
   204  }
   205  
   206  func (f fakeClassConfig) TargetVector() string {
   207  	return ""
   208  }