github.com/weaviate/weaviate@v1.24.6/modules/text2vec-aws/vectorizer/class_settings_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package vectorizer
    13  
    14  import (
    15  	"testing"
    16  
    17  	"github.com/pkg/errors"
    18  	"github.com/stretchr/testify/assert"
    19  	"github.com/weaviate/weaviate/entities/moduletools"
    20  )
    21  
    22  func Test_classSettings_Validate(t *testing.T) {
    23  	tests := []struct {
    24  		name              string
    25  		cfg               moduletools.ClassConfig
    26  		wantService       string
    27  		wantRegion        string
    28  		wantModel         string
    29  		wantEndpoint      string
    30  		wantTargetModel   string
    31  		wantTargetVariant string
    32  		wantErr           error
    33  	}{
    34  		{
    35  			name: "happy flow - Bedrock",
    36  			cfg: fakeClassConfig{
    37  				classConfig: map[string]interface{}{
    38  					"service": "bedrock",
    39  					"region":  "us-east-1",
    40  					"model":   "amazon.titan-embed-text-v1",
    41  				},
    42  			},
    43  			wantService: "bedrock",
    44  			wantRegion:  "us-east-1",
    45  			wantModel:   "amazon.titan-embed-text-v1",
    46  			wantErr:     nil,
    47  		},
    48  		{
    49  			name: "happy flow - SageMaker",
    50  			cfg: fakeClassConfig{
    51  				classConfig: map[string]interface{}{
    52  					"service":  "sagemaker",
    53  					"region":   "us-east-1",
    54  					"endpoint": "my-sagemaker",
    55  				},
    56  			},
    57  			wantService:  "sagemaker",
    58  			wantRegion:   "us-east-1",
    59  			wantEndpoint: "my-sagemaker",
    60  			wantErr:      nil,
    61  		},
    62  		{
    63  			name: "empty service",
    64  			cfg: fakeClassConfig{
    65  				classConfig: map[string]interface{}{
    66  					"region": "us-east-1",
    67  					"model":  "amazon.titan-embed-text-v1",
    68  				},
    69  			},
    70  			wantService: "bedrock",
    71  			wantRegion:  "us-east-1",
    72  			wantModel:   "amazon.titan-embed-text-v1",
    73  		},
    74  		{
    75  			name: "empty region - Bedrock",
    76  			cfg: fakeClassConfig{
    77  				classConfig: map[string]interface{}{
    78  					"service": "bedrock",
    79  					"model":   "amazon.titan-embed-text-v1",
    80  				},
    81  			},
    82  			wantErr: errors.Errorf("region cannot be empty"),
    83  		},
    84  		{
    85  			name: "wrong model",
    86  			cfg: fakeClassConfig{
    87  				classConfig: map[string]interface{}{
    88  					"service": "bedrock",
    89  					"region":  "us-west-1",
    90  					"model":   "wrong-model",
    91  				},
    92  			},
    93  			wantErr: errors.Errorf("wrong model, available models are: [amazon.titan-embed-text-v1 cohere.embed-english-v3 cohere.embed-multilingual-v3]"),
    94  		},
    95  		{
    96  			name: "all wrong",
    97  			cfg: fakeClassConfig{
    98  				classConfig: map[string]interface{}{
    99  					"service": "",
   100  					"region":  "",
   101  					"model":   "",
   102  				},
   103  			},
   104  			wantErr: errors.Errorf("wrong service, available services are: [bedrock sagemaker], " +
   105  				"region cannot be empty"),
   106  		},
   107  		{
   108  			name: "wrong properties",
   109  			cfg: fakeClassConfig{
   110  				classConfig: map[string]interface{}{
   111  					"service":    "bedrock",
   112  					"region":     "us-west-1",
   113  					"model":      "cohere.embed-multilingual-v3",
   114  					"properties": []interface{}{"prop1", 1111},
   115  				},
   116  			},
   117  			wantErr: errors.Errorf("properties field value: 1111 must be a string"),
   118  		},
   119  	}
   120  	for _, tt := range tests {
   121  		t.Run(tt.name, func(t *testing.T) {
   122  			ic := NewClassSettings(tt.cfg)
   123  			if tt.wantErr != nil {
   124  				assert.EqualError(t, ic.Validate(nil), tt.wantErr.Error())
   125  			} else {
   126  				assert.Equal(t, tt.wantService, ic.Service())
   127  				assert.Equal(t, tt.wantRegion, ic.Region())
   128  				assert.Equal(t, tt.wantModel, ic.Model())
   129  				assert.Equal(t, tt.wantEndpoint, ic.Endpoint())
   130  				assert.Equal(t, tt.wantTargetModel, ic.TargetModel())
   131  				assert.Equal(t, tt.wantTargetVariant, ic.TargetVariant())
   132  			}
   133  		})
   134  	}
   135  }