github.com/weaviate/weaviate@v1.24.6/modules/text2vec-contextionary/extensions/rest_storage_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package extensions
    13  
    14  import (
    15  	"bytes"
    16  	"fmt"
    17  	"net/http"
    18  	"net/http/httptest"
    19  	"sort"
    20  	"testing"
    21  
    22  	"github.com/stretchr/testify/assert"
    23  )
    24  
    25  func Test_StorageHandlers(t *testing.T) {
    26  	ls := newFakeLoaderStorer()
    27  	h := NewRESTHandlers(ls, nil)
    28  
    29  	extensionAKey := "my-first-extension"
    30  	extensionAValue := []byte("some-value")
    31  
    32  	extensionBKey := "my-other-extension"
    33  	extensionBValue := []byte("some-other-value")
    34  
    35  	t.Run("retrieving a non existent concept", func(t *testing.T) {
    36  		r := httptest.NewRequest("GET", "/my-concept", nil)
    37  		w := httptest.NewRecorder()
    38  		h.StorageHandler().ServeHTTP(w, r)
    39  
    40  		res := w.Result()
    41  		defer res.Body.Close()
    42  		assert.Equal(t, http.StatusNotFound, res.StatusCode)
    43  	})
    44  
    45  	t.Run("storing two extensions", func(t *testing.T) {
    46  		t.Run("extension A", func(t *testing.T) {
    47  			body := bytes.NewReader(extensionAValue)
    48  			r := httptest.NewRequest("PUT", fmt.Sprintf("/%s", extensionAKey), body)
    49  			w := httptest.NewRecorder()
    50  			h.StorageHandler().ServeHTTP(w, r)
    51  
    52  			res := w.Result()
    53  			defer res.Body.Close()
    54  			assert.Equal(t, http.StatusOK, res.StatusCode)
    55  		})
    56  
    57  		t.Run("extension B", func(t *testing.T) {
    58  			body := bytes.NewReader(extensionBValue)
    59  			r := httptest.NewRequest("PUT", fmt.Sprintf("/%s", extensionBKey), body)
    60  			w := httptest.NewRecorder()
    61  			h.StorageHandler().ServeHTTP(w, r)
    62  
    63  			res := w.Result()
    64  			defer res.Body.Close()
    65  			assert.Equal(t, http.StatusOK, res.StatusCode)
    66  		})
    67  	})
    68  
    69  	t.Run("when storing fails", func(t *testing.T) {
    70  		ls.storeError = fmt.Errorf("oops")
    71  		body := bytes.NewReader(extensionAValue)
    72  		r := httptest.NewRequest("PUT", "/some-extension", body)
    73  
    74  		w := httptest.NewRecorder()
    75  		h.StorageHandler().ServeHTTP(w, r)
    76  
    77  		res := w.Result()
    78  		defer res.Body.Close()
    79  		assert.Equal(t, http.StatusInternalServerError, res.StatusCode)
    80  	})
    81  
    82  	t.Run("storing with an empty concept", func(t *testing.T) {
    83  		body := bytes.NewReader(extensionAValue)
    84  		r := httptest.NewRequest("PUT", "/", body)
    85  
    86  		w := httptest.NewRecorder()
    87  		h.StorageHandler().ServeHTTP(w, r)
    88  
    89  		res := w.Result()
    90  		defer res.Body.Close()
    91  		assert.Equal(t, http.StatusNotFound, res.StatusCode)
    92  	})
    93  
    94  	t.Run("retrieving two extensions", func(t *testing.T) {
    95  		t.Run("extension A", func(t *testing.T) {
    96  			r := httptest.NewRequest("GET", fmt.Sprintf("/%s", extensionAKey), nil)
    97  			w := httptest.NewRecorder()
    98  			h.StorageHandler().ServeHTTP(w, r)
    99  
   100  			res := w.Result()
   101  			defer res.Body.Close()
   102  			assert.Equal(t, http.StatusOK, res.StatusCode)
   103  			assert.Equal(t, extensionAValue, w.Body.Bytes())
   104  		})
   105  
   106  		t.Run("extension B", func(t *testing.T) {
   107  			r := httptest.NewRequest("GET", fmt.Sprintf("/%s", extensionBKey), nil)
   108  			w := httptest.NewRecorder()
   109  			h.StorageHandler().ServeHTTP(w, r)
   110  
   111  			res := w.Result()
   112  			defer res.Body.Close()
   113  			assert.Equal(t, http.StatusOK, res.StatusCode)
   114  			assert.Equal(t, extensionBValue, w.Body.Bytes())
   115  		})
   116  
   117  		t.Run("full dump with trailing slash", func(t *testing.T) {
   118  			r := httptest.NewRequest("GET", "/", nil)
   119  			w := httptest.NewRecorder()
   120  			h.StorageHandler().ServeHTTP(w, r)
   121  			expectedValue := []byte("some-value\nsome-other-value\n")
   122  
   123  			res := w.Result()
   124  			defer res.Body.Close()
   125  			assert.Equal(t, http.StatusOK, res.StatusCode)
   126  			assert.Equal(t, expectedValue, w.Body.Bytes())
   127  		})
   128  	})
   129  
   130  	t.Run("when loading fails", func(t *testing.T) {
   131  		ls.loadError = fmt.Errorf("oops")
   132  		body := bytes.NewReader(extensionAValue)
   133  		r := httptest.NewRequest("GET", "/some-extension", body)
   134  
   135  		w := httptest.NewRecorder()
   136  		h.StorageHandler().ServeHTTP(w, r)
   137  
   138  		res := w.Result()
   139  		defer res.Body.Close()
   140  		assert.Equal(t, http.StatusInternalServerError, res.StatusCode)
   141  	})
   142  }
   143  
   144  type fakeLoaderStorer struct {
   145  	store      map[string][]byte
   146  	storeError error
   147  	loadError  error
   148  }
   149  
   150  func newFakeLoaderStorer() *fakeLoaderStorer {
   151  	return &fakeLoaderStorer{
   152  		store: map[string][]byte{},
   153  	}
   154  }
   155  
   156  func (f *fakeLoaderStorer) Store(concept string, value []byte) error {
   157  	if f.storeError == nil {
   158  		f.store[concept] = value
   159  	}
   160  	return f.storeError
   161  }
   162  
   163  func (f *fakeLoaderStorer) Load(concept string) ([]byte, error) {
   164  	return f.store[concept], f.loadError
   165  }
   166  
   167  func (f *fakeLoaderStorer) LoadAll() ([]byte, error) {
   168  	var keys [][]byte
   169  	for key := range f.store {
   170  		keys = append(keys, []byte(key))
   171  	}
   172  
   173  	sort.Slice(keys, func(a, b int) bool {
   174  		return bytes.Compare(keys[a], keys[b]) == -1
   175  	})
   176  
   177  	buf := bytes.NewBuffer(nil)
   178  	for _, key := range keys {
   179  		buf.Write(f.store[string(key)])
   180  		buf.Write([]byte("\n"))
   181  	}
   182  
   183  	return buf.Bytes(), nil
   184  }