github.com/weaviate/weaviate@v1.24.6/usecases/objects/query_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package objects
    13  
    14  import (
    15  	"context"
    16  	"errors"
    17  	"testing"
    18  
    19  	"github.com/stretchr/testify/assert"
    20  	"github.com/weaviate/weaviate/entities/additional"
    21  	"github.com/weaviate/weaviate/entities/models"
    22  	"github.com/weaviate/weaviate/entities/schema"
    23  	"github.com/weaviate/weaviate/entities/search"
    24  )
    25  
    26  func TestQuery(t *testing.T) {
    27  	t.Parallel()
    28  	var (
    29  		cls    = "MyClass"
    30  		m      = newFakeGetManager(schema.Schema{})
    31  		errAny = errors.New("any")
    32  	)
    33  	params := QueryParams{
    34  		Class: cls,
    35  		Limit: ptInt64(10),
    36  	}
    37  	inputs := QueryInput{
    38  		Class: cls,
    39  		Limit: 10,
    40  	}
    41  	tests := []struct {
    42  		class             string
    43  		name              string
    44  		param             QueryParams
    45  		mockedErr         *Error
    46  		authErr           error
    47  		lockErr           error
    48  		wantCode          int
    49  		mockedDBResponse  []search.Result
    50  		wantResponse      []*models.Object
    51  		wantQueryInput    QueryInput
    52  		wantUsageTracking bool
    53  	}{
    54  		{
    55  			name:           "not found",
    56  			class:          cls,
    57  			param:          params,
    58  			mockedErr:      &Error{Code: StatusNotFound},
    59  			wantCode:       StatusNotFound,
    60  			wantQueryInput: inputs,
    61  		},
    62  		{
    63  			name:           "forbidden",
    64  			class:          cls,
    65  			param:          params,
    66  			authErr:        errAny,
    67  			wantCode:       StatusForbidden,
    68  			wantQueryInput: inputs,
    69  		},
    70  		{
    71  			name:           "internal error",
    72  			class:          cls,
    73  			param:          params,
    74  			lockErr:        errAny,
    75  			wantCode:       StatusInternalServerError,
    76  			wantQueryInput: inputs,
    77  		},
    78  		{
    79  			name:  "happy path",
    80  			class: cls,
    81  			param: params,
    82  			mockedDBResponse: []search.Result{
    83  				{
    84  					ClassName: cls,
    85  					Schema: map[string]interface{}{
    86  						"foo": "bar",
    87  					},
    88  					Dims: 3,
    89  					Dist: 0,
    90  				},
    91  			},
    92  			wantResponse: []*models.Object{{
    93  				Class:         cls,
    94  				VectorWeights: map[string]string(nil),
    95  				Properties: map[string]interface{}{
    96  					"foo": "bar",
    97  				},
    98  			}},
    99  			wantQueryInput: inputs,
   100  		},
   101  		{
   102  			name:  "happy path with explicit vector requested",
   103  			class: cls,
   104  			param: QueryParams{
   105  				Class:      cls,
   106  				Limit:      ptInt64(10),
   107  				Additional: additional.Properties{Vector: true},
   108  			},
   109  			mockedDBResponse: []search.Result{
   110  				{
   111  					ClassName: cls,
   112  					Schema: map[string]interface{}{
   113  						"foo": "bar",
   114  					},
   115  					Dims: 3,
   116  				},
   117  			},
   118  			wantResponse: []*models.Object{{
   119  				Class:         cls,
   120  				VectorWeights: map[string]string(nil),
   121  				Properties: map[string]interface{}{
   122  					"foo": "bar",
   123  				},
   124  			}},
   125  			wantQueryInput: QueryInput{
   126  				Class:      cls,
   127  				Limit:      10,
   128  				Additional: additional.Properties{Vector: true},
   129  			},
   130  			wantUsageTracking: true,
   131  		},
   132  		{
   133  			name:           "bad request",
   134  			class:          cls,
   135  			param:          QueryParams{Class: cls, Offset: ptInt64(1), Limit: &m.config.Config.QueryMaximumResults},
   136  			wantCode:       StatusBadRequest,
   137  			wantQueryInput: inputs,
   138  		},
   139  	}
   140  	for i, tc := range tests {
   141  		t.Run(tc.name, func(t *testing.T) {
   142  			m.authorizer.Err = tc.authErr
   143  			m.locks.Err = tc.lockErr
   144  			if tc.authErr == nil && tc.lockErr == nil {
   145  				m.repo.On("Query", &tc.wantQueryInput).Return(tc.mockedDBResponse, tc.mockedErr).Once()
   146  			}
   147  			if tc.wantUsageTracking {
   148  				m.metrics.On("AddUsageDimensions", cls, "get_rest", "list_include_vector",
   149  					tc.mockedDBResponse[0].Dims)
   150  			}
   151  			res, err := m.Manager.Query(context.Background(), nil, &tc.param)
   152  			code := 0
   153  			if err != nil {
   154  				code = err.Code
   155  			}
   156  			if tc.wantCode != code {
   157  				t.Errorf("case %d expected:%v got:%v", i+1, tc.wantCode, code)
   158  			}
   159  
   160  			assert.Equal(t, tc.wantResponse, res)
   161  		})
   162  	}
   163  }