github.com/weaviate/weaviate@v1.24.6/modules/ner-transformers/clients/ner_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package clients
    13  
    14  import (
    15  	"context"
    16  	"encoding/json"
    17  	"net/http"
    18  	"net/http/httptest"
    19  	"testing"
    20  
    21  	"github.com/stretchr/testify/assert"
    22  	"github.com/stretchr/testify/require"
    23  	"github.com/weaviate/weaviate/modules/ner-transformers/ent"
    24  )
    25  
    26  func TestGetAnswer(t *testing.T) {
    27  	t.Run("when the server has a successful answer (with distance)", func(t *testing.T) {
    28  		server := httptest.NewServer(&testNERHandler{
    29  			t: t,
    30  			res: nerResponse{
    31  				nerInput: nerInput{
    32  					Text: "I work at Apple",
    33  				},
    34  				Tokens: []tokenResponse{
    35  					{
    36  						Entity:        "I-ORG",
    37  						Distance:      0.3,
    38  						Word:          "Apple",
    39  						StartPosition: 20,
    40  						EndPosition:   25,
    41  					},
    42  				},
    43  			},
    44  		})
    45  		defer server.Close()
    46  		c := New(server.URL, 0, nullLogger())
    47  		res, err := c.GetTokens(context.Background(), "prop",
    48  			"I work at Apple")
    49  
    50  		assert.Nil(t, err)
    51  		assert.Equal(t, []ent.TokenResult{
    52  			{
    53  				Entity:        "I-ORG",
    54  				Distance:      0.3,
    55  				Word:          "Apple",
    56  				StartPosition: 20,
    57  				EndPosition:   25,
    58  				Property:      "prop",
    59  			},
    60  		}, res)
    61  	})
    62  
    63  	t.Run("when the server has a successful answer (with certainty)", func(t *testing.T) {
    64  		server := httptest.NewServer(&testNERHandler{
    65  			t: t,
    66  			res: nerResponse{
    67  				nerInput: nerInput{
    68  					Text: "I work at Apple",
    69  				},
    70  				Tokens: []tokenResponse{
    71  					{
    72  						Entity:        "I-ORG",
    73  						Certainty:     0.7,
    74  						Word:          "Apple",
    75  						StartPosition: 20,
    76  						EndPosition:   25,
    77  					},
    78  				},
    79  			},
    80  		})
    81  		defer server.Close()
    82  		c := New(server.URL, 0, nullLogger())
    83  		res, err := c.GetTokens(context.Background(), "prop",
    84  			"I work at Apple")
    85  
    86  		assert.Nil(t, err)
    87  		assert.Equal(t, []ent.TokenResult{
    88  			{
    89  				Entity:        "I-ORG",
    90  				Certainty:     0.7,
    91  				Word:          "Apple",
    92  				StartPosition: 20,
    93  				EndPosition:   25,
    94  				Property:      "prop",
    95  			},
    96  		}, res)
    97  	})
    98  
    99  	t.Run("when the server has a an error", func(t *testing.T) {
   100  		server := httptest.NewServer(&testNERHandler{
   101  			t: t,
   102  			res: nerResponse{
   103  				Error: "some error from the server",
   104  			},
   105  		})
   106  		defer server.Close()
   107  		c := New(server.URL, 0, nullLogger())
   108  		_, err := c.GetTokens(context.Background(), "prop",
   109  			"I work at Apple")
   110  
   111  		require.NotNil(t, err)
   112  		assert.Contains(t, err.Error(), "some error from the server")
   113  	})
   114  }
   115  
   116  type testNERHandler struct {
   117  	t *testing.T
   118  	// the test handler will report as not ready before the time has passed
   119  	res nerResponse
   120  }
   121  
   122  func (f *testNERHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   123  	assert.Equal(f.t, "/ner/", r.URL.String())
   124  	assert.Equal(f.t, http.MethodPost, r.Method)
   125  
   126  	if f.res.Error != "" {
   127  		w.WriteHeader(500)
   128  	}
   129  
   130  	jsonBytes, _ := json.Marshal(f.res)
   131  	w.Write(jsonBytes)
   132  }