github.com/weaviate/weaviate@v1.24.6/modules/text2vec-contextionary/extensions/rest_storage_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package extensions 13 14 import ( 15 "bytes" 16 "fmt" 17 "net/http" 18 "net/http/httptest" 19 "sort" 20 "testing" 21 22 "github.com/stretchr/testify/assert" 23 ) 24 25 func Test_StorageHandlers(t *testing.T) { 26 ls := newFakeLoaderStorer() 27 h := NewRESTHandlers(ls, nil) 28 29 extensionAKey := "my-first-extension" 30 extensionAValue := []byte("some-value") 31 32 extensionBKey := "my-other-extension" 33 extensionBValue := []byte("some-other-value") 34 35 t.Run("retrieving a non existent concept", func(t *testing.T) { 36 r := httptest.NewRequest("GET", "/my-concept", nil) 37 w := httptest.NewRecorder() 38 h.StorageHandler().ServeHTTP(w, r) 39 40 res := w.Result() 41 defer res.Body.Close() 42 assert.Equal(t, http.StatusNotFound, res.StatusCode) 43 }) 44 45 t.Run("storing two extensions", func(t *testing.T) { 46 t.Run("extension A", func(t *testing.T) { 47 body := bytes.NewReader(extensionAValue) 48 r := httptest.NewRequest("PUT", fmt.Sprintf("/%s", extensionAKey), body) 49 w := httptest.NewRecorder() 50 h.StorageHandler().ServeHTTP(w, r) 51 52 res := w.Result() 53 defer res.Body.Close() 54 assert.Equal(t, http.StatusOK, res.StatusCode) 55 }) 56 57 t.Run("extension B", func(t *testing.T) { 58 body := bytes.NewReader(extensionBValue) 59 r := httptest.NewRequest("PUT", fmt.Sprintf("/%s", extensionBKey), body) 60 w := httptest.NewRecorder() 61 h.StorageHandler().ServeHTTP(w, r) 62 63 res := w.Result() 64 defer res.Body.Close() 65 assert.Equal(t, http.StatusOK, res.StatusCode) 66 }) 67 }) 68 69 t.Run("when storing fails", func(t *testing.T) { 70 ls.storeError = fmt.Errorf("oops") 71 body := bytes.NewReader(extensionAValue) 72 r := httptest.NewRequest("PUT", "/some-extension", body) 73 74 w := httptest.NewRecorder() 75 h.StorageHandler().ServeHTTP(w, r) 76 77 res := w.Result() 78 defer res.Body.Close() 79 assert.Equal(t, http.StatusInternalServerError, res.StatusCode) 80 }) 81 82 t.Run("storing with an empty concept", func(t *testing.T) { 83 body := bytes.NewReader(extensionAValue) 84 r := httptest.NewRequest("PUT", "/", body) 85 86 w := httptest.NewRecorder() 87 h.StorageHandler().ServeHTTP(w, r) 88 89 res := w.Result() 90 defer res.Body.Close() 91 assert.Equal(t, http.StatusNotFound, res.StatusCode) 92 }) 93 94 t.Run("retrieving two extensions", func(t *testing.T) { 95 t.Run("extension A", func(t *testing.T) { 96 r := httptest.NewRequest("GET", fmt.Sprintf("/%s", extensionAKey), nil) 97 w := httptest.NewRecorder() 98 h.StorageHandler().ServeHTTP(w, r) 99 100 res := w.Result() 101 defer res.Body.Close() 102 assert.Equal(t, http.StatusOK, res.StatusCode) 103 assert.Equal(t, extensionAValue, w.Body.Bytes()) 104 }) 105 106 t.Run("extension B", func(t *testing.T) { 107 r := httptest.NewRequest("GET", fmt.Sprintf("/%s", extensionBKey), nil) 108 w := httptest.NewRecorder() 109 h.StorageHandler().ServeHTTP(w, r) 110 111 res := w.Result() 112 defer res.Body.Close() 113 assert.Equal(t, http.StatusOK, res.StatusCode) 114 assert.Equal(t, extensionBValue, w.Body.Bytes()) 115 }) 116 117 t.Run("full dump with trailing slash", func(t *testing.T) { 118 r := httptest.NewRequest("GET", "/", nil) 119 w := httptest.NewRecorder() 120 h.StorageHandler().ServeHTTP(w, r) 121 expectedValue := []byte("some-value\nsome-other-value\n") 122 123 res := w.Result() 124 defer res.Body.Close() 125 assert.Equal(t, http.StatusOK, res.StatusCode) 126 assert.Equal(t, expectedValue, w.Body.Bytes()) 127 }) 128 }) 129 130 t.Run("when loading fails", func(t *testing.T) { 131 ls.loadError = fmt.Errorf("oops") 132 body := bytes.NewReader(extensionAValue) 133 r := httptest.NewRequest("GET", "/some-extension", body) 134 135 w := httptest.NewRecorder() 136 h.StorageHandler().ServeHTTP(w, r) 137 138 res := w.Result() 139 defer res.Body.Close() 140 assert.Equal(t, http.StatusInternalServerError, res.StatusCode) 141 }) 142 } 143 144 type fakeLoaderStorer struct { 145 store map[string][]byte 146 storeError error 147 loadError error 148 } 149 150 func newFakeLoaderStorer() *fakeLoaderStorer { 151 return &fakeLoaderStorer{ 152 store: map[string][]byte{}, 153 } 154 } 155 156 func (f *fakeLoaderStorer) Store(concept string, value []byte) error { 157 if f.storeError == nil { 158 f.store[concept] = value 159 } 160 return f.storeError 161 } 162 163 func (f *fakeLoaderStorer) Load(concept string) ([]byte, error) { 164 return f.store[concept], f.loadError 165 } 166 167 func (f *fakeLoaderStorer) LoadAll() ([]byte, error) { 168 var keys [][]byte 169 for key := range f.store { 170 keys = append(keys, []byte(key)) 171 } 172 173 sort.Slice(keys, func(a, b int) bool { 174 return bytes.Compare(keys[a], keys[b]) == -1 175 }) 176 177 buf := bytes.NewBuffer(nil) 178 for _, key := range keys { 179 buf.Write(f.store[string(key)]) 180 buf.Write([]byte("\n")) 181 } 182 183 return buf.Bytes(), nil 184 }