github.com/weaviate/weaviate@v1.24.6/modules/text2vec-transformers/clients/meta_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "net/http" 16 "net/http/httptest" 17 "testing" 18 "time" 19 20 "github.com/stretchr/testify/assert" 21 ) 22 23 func TestGetMeta(t *testing.T) { 24 t.Run("when common server is providing meta", func(t *testing.T) { 25 server := httptest.NewServer(&testMetaHandler{t: t}) 26 defer server.Close() 27 v := New(server.URL, server.URL, 0, nullLogger()) 28 meta, err := v.MetaInfo() 29 30 assert.Nil(t, err) 31 assert.NotNil(t, meta) 32 33 model := extractChildMap(t, meta, "model") 34 assert.NotNil(t, model["_name_or_path"]) 35 assert.NotNil(t, model["architectures"]) 36 assert.Contains(t, model["architectures"], "DistilBertModel") 37 ID2Label := extractChildMap(t, model, "id2label") 38 assert.NotNil(t, ID2Label["0"]) 39 assert.NotNil(t, ID2Label["1"]) 40 }) 41 42 t.Run("when passage and query servers are providing meta", func(t *testing.T) { 43 serverPassage := httptest.NewServer(&testMetaHandler{t: t, modelType: "passage"}) 44 serverQuery := httptest.NewServer(&testMetaHandler{t: t, modelType: "query"}) 45 defer serverPassage.Close() 46 defer serverQuery.Close() 47 v := New(serverPassage.URL, serverQuery.URL, 0, nullLogger()) 48 meta, err := v.MetaInfo() 49 50 assert.Nil(t, err) 51 assert.NotNil(t, meta) 52 53 passage := extractChildMap(t, meta, "passage") 54 passageModel := extractChildMap(t, passage, "model") 55 assert.NotNil(t, passageModel["_name_or_path"]) 56 assert.NotNil(t, passageModel["architectures"]) 57 assert.Contains(t, passageModel["architectures"], "DPRContextEncoder") 58 passageID2Label := extractChildMap(t, passageModel, "id2label") 59 assert.NotNil(t, passageID2Label["0"]) 60 assert.NotNil(t, passageID2Label["1"]) 61 62 query := extractChildMap(t, meta, "query") 63 queryModel := extractChildMap(t, query, "model") 64 assert.NotNil(t, queryModel["_name_or_path"]) 65 assert.NotNil(t, queryModel["architectures"]) 66 assert.Contains(t, queryModel["architectures"], "DPRQuestionEncoder") 67 queryID2Label := extractChildMap(t, queryModel, "id2label") 68 assert.NotNil(t, queryID2Label["0"]) 69 assert.NotNil(t, queryID2Label["1"]) 70 }) 71 72 t.Run("when passage and query servers are unavailable", func(t *testing.T) { 73 rt := time.Now().Add(time.Hour) 74 serverPassage := httptest.NewServer(&testMetaHandler{t: t, modelType: "passage", readyTime: rt}) 75 serverQuery := httptest.NewServer(&testMetaHandler{t: t, modelType: "query", readyTime: rt}) 76 defer serverPassage.Close() 77 defer serverQuery.Close() 78 v := New(serverPassage.URL, serverQuery.URL, 0, nullLogger()) 79 meta, err := v.MetaInfo() 80 81 assert.NotNil(t, err) 82 assert.Contains(t, err.Error(), "[passage] unexpected status code '503' of meta request") 83 assert.Contains(t, err.Error(), "[query] unexpected status code '503' of meta request") 84 assert.Nil(t, meta) 85 }) 86 } 87 88 type testMetaHandler struct { 89 t *testing.T 90 // the test handler will report as not ready before the time has passed 91 readyTime time.Time 92 modelType string 93 } 94 95 func (h *testMetaHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 96 assert.Equal(h.t, "/meta", r.URL.String()) 97 assert.Equal(h.t, http.MethodGet, r.Method) 98 99 if time.Since(h.readyTime) < 0 { 100 w.WriteHeader(http.StatusServiceUnavailable) 101 return 102 } 103 104 w.Write([]byte(h.metaInfo())) 105 } 106 107 func (h *testMetaHandler) metaInfo() string { 108 switch h.modelType { 109 case "passage": 110 return `{ 111 "model": { 112 "return_dict": true, 113 "output_hidden_states": false, 114 "output_attentions": false, 115 "torchscript": false, 116 "torch_dtype": "float32", 117 "use_bfloat16": false, 118 "pruned_heads": {}, 119 "tie_word_embeddings": true, 120 "is_encoder_decoder": false, 121 "is_decoder": false, 122 "cross_attention_hidden_size": null, 123 "add_cross_attention": false, 124 "tie_encoder_decoder": false, 125 "max_length": 20, 126 "min_length": 0, 127 "do_sample": false, 128 "early_stopping": false, 129 "num_beams": 1, 130 "num_beam_groups": 1, 131 "diversity_penalty": 0, 132 "temperature": 1, 133 "top_k": 50, 134 "top_p": 1, 135 "repetition_penalty": 1, 136 "length_penalty": 1, 137 "no_repeat_ngram_size": 0, 138 "encoder_no_repeat_ngram_size": 0, 139 "bad_words_ids": null, 140 "num_return_sequences": 1, 141 "chunk_size_feed_forward": 0, 142 "output_scores": false, 143 "return_dict_in_generate": false, 144 "forced_bos_token_id": null, 145 "forced_eos_token_id": null, 146 "remove_invalid_values": false, 147 "architectures": [ 148 "DPRContextEncoder" 149 ], 150 "finetuning_task": null, 151 "id2label": { 152 "0": "LABEL_0", 153 "1": "LABEL_1" 154 }, 155 "label2id": { 156 "LABEL_0": 0, 157 "LABEL_1": 1 158 }, 159 "tokenizer_class": null, 160 "prefix": null, 161 "bos_token_id": null, 162 "pad_token_id": 0, 163 "eos_token_id": null, 164 "sep_token_id": null, 165 "decoder_start_token_id": null, 166 "task_specific_params": null, 167 "problem_type": null, 168 "_name_or_path": "./models/model", 169 "transformers_version": "4.16.2", 170 "gradient_checkpointing": false, 171 "model_type": "dpr", 172 "vocab_size": 30522, 173 "hidden_size": 768, 174 "num_hidden_layers": 12, 175 "num_attention_heads": 12, 176 "hidden_act": "gelu", 177 "intermediate_size": 3072, 178 "hidden_dropout_prob": 0.1, 179 "attention_probs_dropout_prob": 0.1, 180 "max_position_embeddings": 512, 181 "type_vocab_size": 2, 182 "initializer_range": 0.02, 183 "layer_norm_eps": 1e-12, 184 "projection_dim": 0, 185 "position_embedding_type": "absolute" 186 } 187 }` 188 case "query": 189 return `{ 190 "model": { 191 "return_dict": true, 192 "output_hidden_states": false, 193 "output_attentions": false, 194 "torchscript": false, 195 "torch_dtype": "float32", 196 "use_bfloat16": false, 197 "pruned_heads": {}, 198 "tie_word_embeddings": true, 199 "is_encoder_decoder": false, 200 "is_decoder": false, 201 "cross_attention_hidden_size": null, 202 "add_cross_attention": false, 203 "tie_encoder_decoder": false, 204 "max_length": 20, 205 "min_length": 0, 206 "do_sample": false, 207 "early_stopping": false, 208 "num_beams": 1, 209 "num_beam_groups": 1, 210 "diversity_penalty": 0, 211 "temperature": 1, 212 "top_k": 50, 213 "top_p": 1, 214 "repetition_penalty": 1, 215 "length_penalty": 1, 216 "no_repeat_ngram_size": 0, 217 "encoder_no_repeat_ngram_size": 0, 218 "bad_words_ids": null, 219 "num_return_sequences": 1, 220 "chunk_size_feed_forward": 0, 221 "output_scores": false, 222 "return_dict_in_generate": false, 223 "forced_bos_token_id": null, 224 "forced_eos_token_id": null, 225 "remove_invalid_values": false, 226 "architectures": [ 227 "DPRQuestionEncoder" 228 ], 229 "finetuning_task": null, 230 "id2label": { 231 "0": "LABEL_0", 232 "1": "LABEL_1" 233 }, 234 "label2id": { 235 "LABEL_0": 0, 236 "LABEL_1": 1 237 }, 238 "tokenizer_class": null, 239 "prefix": null, 240 "bos_token_id": null, 241 "pad_token_id": 0, 242 "eos_token_id": null, 243 "sep_token_id": null, 244 "decoder_start_token_id": null, 245 "task_specific_params": null, 246 "problem_type": null, 247 "_name_or_path": "./models/model", 248 "transformers_version": "4.16.2", 249 "gradient_checkpointing": false, 250 "model_type": "dpr", 251 "vocab_size": 30522, 252 "hidden_size": 768, 253 "num_hidden_layers": 12, 254 "num_attention_heads": 12, 255 "hidden_act": "gelu", 256 "intermediate_size": 3072, 257 "hidden_dropout_prob": 0.1, 258 "attention_probs_dropout_prob": 0.1, 259 "max_position_embeddings": 512, 260 "type_vocab_size": 2, 261 "initializer_range": 0.02, 262 "layer_norm_eps": 1e-12, 263 "projection_dim": 0, 264 "position_embedding_type": "absolute" 265 } 266 }` 267 default: 268 return `{ 269 "model": { 270 "_name_or_path": "distilbert-base-uncased", 271 "activation": "gelu", 272 "add_cross_attention": false, 273 "architectures": [ 274 "DistilBertModel" 275 ], 276 "attention_dropout": 0.1, 277 "bad_words_ids": null, 278 "bos_token_id": null, 279 "chunk_size_feed_forward": 0, 280 "decoder_start_token_id": null, 281 "dim": 768, 282 "diversity_penalty": 0, 283 "do_sample": false, 284 "dropout": 0.1, 285 "early_stopping": false, 286 "encoder_no_repeat_ngram_size": 0, 287 "eos_token_id": null, 288 "finetuning_task": null, 289 "hidden_dim": 3072, 290 "id2label": { 291 "0": "LABEL_0", 292 "1": "LABEL_1" 293 }, 294 "initializer_range": 0.02, 295 "is_decoder": false, 296 "is_encoder_decoder": false, 297 "label2id": { 298 "LABEL_0": 0, 299 "LABEL_1": 1 300 }, 301 "length_penalty": 1, 302 "max_length": 20, 303 "max_position_embeddings": 512, 304 "min_length": 0, 305 "model_type": "distilbert", 306 "n_heads": 12, 307 "n_layers": 6, 308 "no_repeat_ngram_size": 0, 309 "num_beam_groups": 1, 310 "num_beams": 1, 311 "num_return_sequences": 1, 312 "output_attentions": false, 313 "output_hidden_states": false, 314 "output_scores": false, 315 "pad_token_id": 0, 316 "prefix": null, 317 "pruned_heads": {}, 318 "qa_dropout": 0.1, 319 "repetition_penalty": 1, 320 "return_dict": true, 321 "return_dict_in_generate": false, 322 "sep_token_id": null, 323 "seq_classif_dropout": 0.2, 324 "sinusoidal_pos_embds": false, 325 "task_specific_params": null, 326 "temperature": 1, 327 "tie_encoder_decoder": false, 328 "tie_weights_": true, 329 "tie_word_embeddings": true, 330 "tokenizer_class": null, 331 "top_k": 50, 332 "top_p": 1, 333 "torchscript": false, 334 "transformers_version": "4.3.2", 335 "use_bfloat16": false, 336 "vocab_size": 30522, 337 "xla_device": null 338 } 339 }` 340 } 341 } 342 343 func extractChildMap(t *testing.T, parent map[string]interface{}, name string) map[string]interface{} { 344 assert.NotNil(t, parent[name]) 345 child, ok := parent[name].(map[string]interface{}) 346 assert.True(t, ok) 347 assert.NotNil(t, child) 348 349 return child 350 }