github.com/weaviate/weaviate@v1.24.6/modules/text2vec-transformers/clients/meta_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package clients
    13  
    14  import (
    15  	"net/http"
    16  	"net/http/httptest"
    17  	"testing"
    18  	"time"
    19  
    20  	"github.com/stretchr/testify/assert"
    21  )
    22  
    23  func TestGetMeta(t *testing.T) {
    24  	t.Run("when common server is providing meta", func(t *testing.T) {
    25  		server := httptest.NewServer(&testMetaHandler{t: t})
    26  		defer server.Close()
    27  		v := New(server.URL, server.URL, 0, nullLogger())
    28  		meta, err := v.MetaInfo()
    29  
    30  		assert.Nil(t, err)
    31  		assert.NotNil(t, meta)
    32  
    33  		model := extractChildMap(t, meta, "model")
    34  		assert.NotNil(t, model["_name_or_path"])
    35  		assert.NotNil(t, model["architectures"])
    36  		assert.Contains(t, model["architectures"], "DistilBertModel")
    37  		ID2Label := extractChildMap(t, model, "id2label")
    38  		assert.NotNil(t, ID2Label["0"])
    39  		assert.NotNil(t, ID2Label["1"])
    40  	})
    41  
    42  	t.Run("when passage and query servers are providing meta", func(t *testing.T) {
    43  		serverPassage := httptest.NewServer(&testMetaHandler{t: t, modelType: "passage"})
    44  		serverQuery := httptest.NewServer(&testMetaHandler{t: t, modelType: "query"})
    45  		defer serverPassage.Close()
    46  		defer serverQuery.Close()
    47  		v := New(serverPassage.URL, serverQuery.URL, 0, nullLogger())
    48  		meta, err := v.MetaInfo()
    49  
    50  		assert.Nil(t, err)
    51  		assert.NotNil(t, meta)
    52  
    53  		passage := extractChildMap(t, meta, "passage")
    54  		passageModel := extractChildMap(t, passage, "model")
    55  		assert.NotNil(t, passageModel["_name_or_path"])
    56  		assert.NotNil(t, passageModel["architectures"])
    57  		assert.Contains(t, passageModel["architectures"], "DPRContextEncoder")
    58  		passageID2Label := extractChildMap(t, passageModel, "id2label")
    59  		assert.NotNil(t, passageID2Label["0"])
    60  		assert.NotNil(t, passageID2Label["1"])
    61  
    62  		query := extractChildMap(t, meta, "query")
    63  		queryModel := extractChildMap(t, query, "model")
    64  		assert.NotNil(t, queryModel["_name_or_path"])
    65  		assert.NotNil(t, queryModel["architectures"])
    66  		assert.Contains(t, queryModel["architectures"], "DPRQuestionEncoder")
    67  		queryID2Label := extractChildMap(t, queryModel, "id2label")
    68  		assert.NotNil(t, queryID2Label["0"])
    69  		assert.NotNil(t, queryID2Label["1"])
    70  	})
    71  
    72  	t.Run("when passage and query servers are unavailable", func(t *testing.T) {
    73  		rt := time.Now().Add(time.Hour)
    74  		serverPassage := httptest.NewServer(&testMetaHandler{t: t, modelType: "passage", readyTime: rt})
    75  		serverQuery := httptest.NewServer(&testMetaHandler{t: t, modelType: "query", readyTime: rt})
    76  		defer serverPassage.Close()
    77  		defer serverQuery.Close()
    78  		v := New(serverPassage.URL, serverQuery.URL, 0, nullLogger())
    79  		meta, err := v.MetaInfo()
    80  
    81  		assert.NotNil(t, err)
    82  		assert.Contains(t, err.Error(), "[passage] unexpected status code '503' of meta request")
    83  		assert.Contains(t, err.Error(), "[query] unexpected status code '503' of meta request")
    84  		assert.Nil(t, meta)
    85  	})
    86  }
    87  
    88  type testMetaHandler struct {
    89  	t *testing.T
    90  	// the test handler will report as not ready before the time has passed
    91  	readyTime time.Time
    92  	modelType string
    93  }
    94  
    95  func (h *testMetaHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    96  	assert.Equal(h.t, "/meta", r.URL.String())
    97  	assert.Equal(h.t, http.MethodGet, r.Method)
    98  
    99  	if time.Since(h.readyTime) < 0 {
   100  		w.WriteHeader(http.StatusServiceUnavailable)
   101  		return
   102  	}
   103  
   104  	w.Write([]byte(h.metaInfo()))
   105  }
   106  
        // metaInfo returns the canned JSON document for the handler's modelType.
        // Each document is an object with a single "model" key holding a model
        // configuration; the three variants are told apart by their
        // "architectures" entry.
   107  func (h *testMetaHandler) metaInfo() string {
   108  	switch h.modelType {
        	// Passage variant: architectures lists "DPRContextEncoder".
   109  	case "passage":
   110  		return `{
   111                "model": {
   112                  "return_dict": true,
   113                  "output_hidden_states": false,
   114                  "output_attentions": false,
   115                  "torchscript": false,
   116                  "torch_dtype": "float32",
   117                  "use_bfloat16": false,
   118                  "pruned_heads": {},
   119                  "tie_word_embeddings": true,
   120                  "is_encoder_decoder": false,
   121                  "is_decoder": false,
   122                  "cross_attention_hidden_size": null,
   123                  "add_cross_attention": false,
   124                  "tie_encoder_decoder": false,
   125                  "max_length": 20,
   126                  "min_length": 0,
   127                  "do_sample": false,
   128                  "early_stopping": false,
   129                  "num_beams": 1,
   130                  "num_beam_groups": 1,
   131                  "diversity_penalty": 0,
   132                  "temperature": 1,
   133                  "top_k": 50,
   134                  "top_p": 1,
   135                  "repetition_penalty": 1,
   136                  "length_penalty": 1,
   137                  "no_repeat_ngram_size": 0,
   138                  "encoder_no_repeat_ngram_size": 0,
   139                  "bad_words_ids": null,
   140                  "num_return_sequences": 1,
   141                  "chunk_size_feed_forward": 0,
   142                  "output_scores": false,
   143                  "return_dict_in_generate": false,
   144                  "forced_bos_token_id": null,
   145                  "forced_eos_token_id": null,
   146                  "remove_invalid_values": false,
   147                  "architectures": [
   148                    "DPRContextEncoder"
   149                  ],
   150                  "finetuning_task": null,
   151                  "id2label": {
   152                    "0": "LABEL_0",
   153                    "1": "LABEL_1"
   154                  },
   155                  "label2id": {
   156                    "LABEL_0": 0,
   157                    "LABEL_1": 1
   158                  },
   159                  "tokenizer_class": null,
   160                  "prefix": null,
   161                  "bos_token_id": null,
   162                  "pad_token_id": 0,
   163                  "eos_token_id": null,
   164                  "sep_token_id": null,
   165                  "decoder_start_token_id": null,
   166                  "task_specific_params": null,
   167                  "problem_type": null,
   168                  "_name_or_path": "./models/model",
   169                  "transformers_version": "4.16.2",
   170                  "gradient_checkpointing": false,
   171                  "model_type": "dpr",
   172                  "vocab_size": 30522,
   173                  "hidden_size": 768,
   174                  "num_hidden_layers": 12,
   175                  "num_attention_heads": 12,
   176                  "hidden_act": "gelu",
   177                  "intermediate_size": 3072,
   178                  "hidden_dropout_prob": 0.1,
   179                  "attention_probs_dropout_prob": 0.1,
   180                  "max_position_embeddings": 512,
   181                  "type_vocab_size": 2,
   182                  "initializer_range": 0.02,
   183                  "layer_norm_eps": 1e-12,
   184                  "projection_dim": 0,
   185                  "position_embedding_type": "absolute"
   186                }
   187              }`
        	// Query variant: architectures lists "DPRQuestionEncoder".
   188  	case "query":
   189  		return `{
   190                "model": {
   191                  "return_dict": true,
   192                  "output_hidden_states": false,
   193                  "output_attentions": false,
   194                  "torchscript": false,
   195                  "torch_dtype": "float32",
   196                  "use_bfloat16": false,
   197                  "pruned_heads": {},
   198                  "tie_word_embeddings": true,
   199                  "is_encoder_decoder": false,
   200                  "is_decoder": false,
   201                  "cross_attention_hidden_size": null,
   202                  "add_cross_attention": false,
   203                  "tie_encoder_decoder": false,
   204                  "max_length": 20,
   205                  "min_length": 0,
   206                  "do_sample": false,
   207                  "early_stopping": false,
   208                  "num_beams": 1,
   209                  "num_beam_groups": 1,
   210                  "diversity_penalty": 0,
   211                  "temperature": 1,
   212                  "top_k": 50,
   213                  "top_p": 1,
   214                  "repetition_penalty": 1,
   215                  "length_penalty": 1,
   216                  "no_repeat_ngram_size": 0,
   217                  "encoder_no_repeat_ngram_size": 0,
   218                  "bad_words_ids": null,
   219                  "num_return_sequences": 1,
   220                  "chunk_size_feed_forward": 0,
   221                  "output_scores": false,
   222                  "return_dict_in_generate": false,
   223                  "forced_bos_token_id": null,
   224                  "forced_eos_token_id": null,
   225                  "remove_invalid_values": false,
   226                  "architectures": [
   227                    "DPRQuestionEncoder"
   228                  ],
   229                  "finetuning_task": null,
   230                  "id2label": {
   231                    "0": "LABEL_0",
   232                    "1": "LABEL_1"
   233                  },
   234                  "label2id": {
   235                    "LABEL_0": 0,
   236                    "LABEL_1": 1
   237                  },
   238                  "tokenizer_class": null,
   239                  "prefix": null,
   240                  "bos_token_id": null,
   241                  "pad_token_id": 0,
   242                  "eos_token_id": null,
   243                  "sep_token_id": null,
   244                  "decoder_start_token_id": null,
   245                  "task_specific_params": null,
   246                  "problem_type": null,
   247                  "_name_or_path": "./models/model",
   248                  "transformers_version": "4.16.2",
   249                  "gradient_checkpointing": false,
   250                  "model_type": "dpr",
   251                  "vocab_size": 30522,
   252                  "hidden_size": 768,
   253                  "num_hidden_layers": 12,
   254                  "num_attention_heads": 12,
   255                  "hidden_act": "gelu",
   256                  "intermediate_size": 3072,
   257                  "hidden_dropout_prob": 0.1,
   258                  "attention_probs_dropout_prob": 0.1,
   259                  "max_position_embeddings": 512,
   260                  "type_vocab_size": 2,
   261                  "initializer_range": 0.02,
   262                  "layer_norm_eps": 1e-12,
   263                  "projection_dim": 0,
   264                  "position_embedding_type": "absolute"
   265                }
   266              }`
        	// Any other modelType (including the empty string): architectures
        	// lists "DistilBertModel".
   267  	default:
   268  		return `{
   269                "model": {
   270                  "_name_or_path": "distilbert-base-uncased",
   271                  "activation": "gelu",
   272                  "add_cross_attention": false,
   273                  "architectures": [
   274                    "DistilBertModel"
   275                  ],
   276                  "attention_dropout": 0.1,
   277                  "bad_words_ids": null,
   278                  "bos_token_id": null,
   279                  "chunk_size_feed_forward": 0,
   280                  "decoder_start_token_id": null,
   281                  "dim": 768,
   282                  "diversity_penalty": 0,
   283                  "do_sample": false,
   284                  "dropout": 0.1,
   285                  "early_stopping": false,
   286                  "encoder_no_repeat_ngram_size": 0,
   287                  "eos_token_id": null,
   288                  "finetuning_task": null,
   289                  "hidden_dim": 3072,
   290                  "id2label": {
   291                    "0": "LABEL_0",
   292                    "1": "LABEL_1"
   293                  },
   294                  "initializer_range": 0.02,
   295                  "is_decoder": false,
   296                  "is_encoder_decoder": false,
   297                  "label2id": {
   298                    "LABEL_0": 0,
   299                    "LABEL_1": 1
   300                  },
   301                  "length_penalty": 1,
   302                  "max_length": 20,
   303                  "max_position_embeddings": 512,
   304                  "min_length": 0,
   305                  "model_type": "distilbert",
   306                  "n_heads": 12,
   307                  "n_layers": 6,
   308                  "no_repeat_ngram_size": 0,
   309                  "num_beam_groups": 1,
   310                  "num_beams": 1,
   311                  "num_return_sequences": 1,
   312                  "output_attentions": false,
   313                  "output_hidden_states": false,
   314                  "output_scores": false,
   315                  "pad_token_id": 0,
   316                  "prefix": null,
   317                  "pruned_heads": {},
   318                  "qa_dropout": 0.1,
   319                  "repetition_penalty": 1,
   320                  "return_dict": true,
   321                  "return_dict_in_generate": false,
   322                  "sep_token_id": null,
   323                  "seq_classif_dropout": 0.2,
   324                  "sinusoidal_pos_embds": false,
   325                  "task_specific_params": null,
   326                  "temperature": 1,
   327                  "tie_encoder_decoder": false,
   328                  "tie_weights_": true,
   329                  "tie_word_embeddings": true,
   330                  "tokenizer_class": null,
   331                  "top_k": 50,
   332                  "top_p": 1,
   333                  "torchscript": false,
   334                  "transformers_version": "4.3.2",
   335                  "use_bfloat16": false,
   336                  "vocab_size": 30522,
   337                  "xla_device": null
   338                }
   339              }`
   340  	}
   341  }
   342  
   343  func extractChildMap(t *testing.T, parent map[string]interface{}, name string) map[string]interface{} {
   344  	assert.NotNil(t, parent[name])
   345  	child, ok := parent[name].(map[string]interface{})
   346  	assert.True(t, ok)
   347  	assert.NotNil(t, child)
   348  
   349  	return child
   350  }