github.com/instill-ai/component@v0.16.0-beta/pkg/connector/pinecone/v0/connector_test.go

     1  package pinecone
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"io"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"testing"
    10  
    11  	qt "github.com/frankban/quicktest"
    12  	"go.uber.org/zap"
    13  	"google.golang.org/protobuf/types/known/structpb"
    14  
    15  	"github.com/instill-ai/component/pkg/base"
    16  	"github.com/instill-ai/component/pkg/connector/util/httpclient"
    17  	"github.com/instill-ai/x/errmsg"
    18  )
    19  
    20  const (
    21  	pineconeKey = "secret-key"
    22  	namespace   = "pantone"
    23  	threshold   = 0.9
    24  
    25  	upsertOK = `{"upsertedCount": 1}`
    26  
    27  	queryOK = `
    28  {
    29  	"namespace": "color-schemes",
    30  	"matches": [
    31  		{
    32  			"id": "A",
    33  			"values": [ 2.23 ],
    34  			"metadata": { "color": "pumpkin" },
    35  			"score": 0.99
    36  		},
    37  		{
    38  			"id": "B",
    39  			"values": [ 3.32 ],
    40  			"metadata": { "color": "cerulean" },
    41  			"score": 0.87
    42  		}
    43  	]
    44  }`
    45  
    46  	errResp = `
    47  {
    48    "code": 3,
    49    "message": "Cannot provide both ID and vector at the same time.",
    50    "details": []
    51  }`
    52  )
    53  
    54  var (
    55  	vectorA = vector{
    56  		ID:       "A",
    57  		Values:   []float64{2.23},
    58  		Metadata: map[string]any{"color": "pumpkin"},
    59  	}
    60  	vectorB = vector{
    61  		ID:       "B",
    62  		Values:   []float64{3.32},
    63  		Metadata: map[string]any{"color": "cerulean"},
    64  	}
    65  	queryByVector = queryInput{
    66  		Namespace:       "color-schemes",
    67  		TopK:            1,
    68  		Vector:          vectorA.Values,
    69  		IncludeValues:   true,
    70  		IncludeMetadata: true,
    71  		Filter: map[string]any{
    72  			"color": map[string]any{
    73  				"$in": []string{"green", "cerulean", "pumpkin"},
    74  			},
    75  		},
    76  	}
    77  	queryWithThreshold = func(q queryInput, th float64) queryInput {
    78  		q.MinScore = th
    79  		return q
    80  	}
    81  	queryByID = queryInput{
    82  		Namespace:       "color-schemes",
    83  		TopK:            1,
    84  		Vector:          vectorA.Values,
    85  		ID:              vectorA.ID,
    86  		IncludeValues:   true,
    87  		IncludeMetadata: true,
    88  	}
    89  )
    90  
    91  func TestConnector_Execute(t *testing.T) {
    92  	c := qt.New(t)
    93  
    94  	testcases := []struct {
    95  		name string
    96  
    97  		task     string
    98  		execIn   any
    99  		wantExec any
   100  
   101  		wantClientPath string
   102  		wantClientReq  any
   103  		clientResp     string
   104  	}{
   105  		{
   106  			name: "ok - upsert",
   107  
   108  			task: taskUpsert,
   109  			execIn: upsertInput{
   110  				vector:    vectorA,
   111  				Namespace: namespace,
   112  			},
   113  			wantExec: upsertOutput{RecordsUpserted: 1},
   114  
   115  			wantClientPath: upsertPath,
   116  			wantClientReq:  upsertReq{Vectors: []vector{vectorA}, Namespace: namespace},
   117  			clientResp:     upsertOK,
   118  		},
   119  		{
   120  			name: "ok - query by vector",
   121  
   122  			task:   taskQuery,
   123  			execIn: queryByVector,
   124  			wantExec: queryResp{
   125  				Namespace: "color-schemes",
   126  				Matches: []match{
   127  					{
   128  						vector: vectorA,
   129  						Score:  0.99,
   130  					},
   131  					{
   132  						vector: vectorB,
   133  						Score:  0.87,
   134  					},
   135  				},
   136  			},
   137  
   138  			wantClientPath: queryPath,
   139  			wantClientReq:  queryByVector.asRequest(),
   140  			clientResp:     queryOK,
   141  		},
   142  		{
   143  			name: "ok - filter out below threshold score",
   144  
   145  			task:   taskQuery,
   146  			execIn: queryWithThreshold(queryByVector, threshold),
   147  			wantExec: queryResp{
   148  				Namespace: "color-schemes",
   149  				Matches: []match{
   150  					{
   151  						vector: vectorA,
   152  						Score:  0.99,
   153  					},
   154  				},
   155  			},
   156  
   157  			wantClientPath: queryPath,
   158  			wantClientReq:  queryByVector.asRequest(),
   159  			clientResp:     queryOK,
   160  		},
   161  		{
   162  			name: "ok - query by ID",
   163  
   164  			task:   taskQuery,
   165  			execIn: queryByID,
   166  			wantExec: queryResp{
   167  				Namespace: "color-schemes",
   168  				Matches: []match{
   169  					{
   170  						vector: vectorA,
   171  						Score:  0.99,
   172  					},
   173  					{
   174  						vector: vectorB,
   175  						Score:  0.87,
   176  					},
   177  				},
   178  			},
   179  
   180  			wantClientPath: queryPath,
   181  			wantClientReq: queryReq{
   182  				// Vector is wiped from the request.
   183  				Namespace:       "color-schemes",
   184  				TopK:            1,
   185  				ID:              vectorA.ID,
   186  				IncludeValues:   true,
   187  				IncludeMetadata: true,
   188  			},
   189  			clientResp: queryOK,
   190  		},
   191  	}
   192  
   193  	logger := zap.NewNop()
   194  	connector := Init(logger, nil)
   195  
   196  	for _, tc := range testcases {
   197  		c.Run(tc.name, func(c *qt.C) {
   198  			h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   199  				// For now only POST methods are considered. When this changes,
   200  				// this will need to be asserted per-path.
   201  				c.Check(r.Method, qt.Equals, http.MethodPost)
   202  				c.Check(r.URL.Path, qt.Equals, tc.wantClientPath)
   203  
   204  				c.Check(r.Header.Get("Content-Type"), qt.Equals, httpclient.MIMETypeJSON)
   205  				c.Check(r.Header.Get("Accept"), qt.Equals, httpclient.MIMETypeJSON)
   206  				c.Check(r.Header.Get("Api-Key"), qt.Equals, pineconeKey)
   207  
   208  				c.Assert(r.Body, qt.IsNotNil)
   209  				defer r.Body.Close()
   210  
   211  				body, err := io.ReadAll(r.Body)
   212  				c.Assert(err, qt.IsNil)
   213  				c.Check(body, qt.JSONEquals, tc.wantClientReq)
   214  
   215  				w.Header().Set("Content-Type", httpclient.MIMETypeJSON)
   216  				fmt.Fprintln(w, tc.clientResp)
   217  			})
   218  
   219  			pineconeServer := httptest.NewServer(h)
   220  			c.Cleanup(pineconeServer.Close)
   221  
   222  			connection, _ := structpb.NewStruct(map[string]any{
   223  				"api_key": pineconeKey,
   224  				"url":     pineconeServer.URL,
   225  			})
   226  
   227  			exec, err := connector.CreateExecution(nil, connection, tc.task)
   228  			c.Assert(err, qt.IsNil)
   229  
   230  			pbIn, err := base.ConvertToStructpb(tc.execIn)
   231  			c.Assert(err, qt.IsNil)
   232  
   233  			got, err := exec.Execution.Execute([]*structpb.Struct{pbIn})
   234  			c.Check(err, qt.IsNil)
   235  
   236  			c.Assert(got, qt.HasLen, 1)
   237  			wantJSON, err := json.Marshal(tc.wantExec)
   238  			c.Assert(err, qt.IsNil)
   239  			c.Check(wantJSON, qt.JSONEquals, got[0].AsMap())
   240  		})
   241  	}
   242  
   243  	c.Run("nok - 400", func(c *qt.C) {
   244  		h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   245  			w.Header().Set("Content-Type", "application/json")
   246  			w.WriteHeader(http.StatusBadRequest)
   247  			fmt.Fprintln(w, errResp)
   248  		})
   249  
   250  		pineconeServer := httptest.NewServer(h)
   251  		c.Cleanup(pineconeServer.Close)
   252  
   253  		connection, _ := structpb.NewStruct(map[string]any{
   254  			"url": pineconeServer.URL,
   255  		})
   256  
   257  		exec, err := connector.CreateExecution(nil, connection, taskUpsert)
   258  		c.Assert(err, qt.IsNil)
   259  
   260  		pbIn := new(structpb.Struct)
   261  		_, err = exec.Execution.Execute([]*structpb.Struct{pbIn})
   262  		c.Check(err, qt.IsNotNil)
   263  
   264  		want := "Pinecone responded with a 400 status code. Cannot provide both ID and vector at the same time."
   265  		c.Check(errmsg.Message(err), qt.Equals, want)
   266  	})
   267  
   268  	c.Run("nok - URL misconfiguration", func(c *qt.C) {
   269  		connection, _ := structpb.NewStruct(map[string]any{
   270  			"url": "http://no-such.host",
   271  		})
   272  
   273  		exec, err := connector.CreateExecution(nil, connection, taskUpsert)
   274  		c.Assert(err, qt.IsNil)
   275  
   276  		pbIn := new(structpb.Struct)
   277  		_, err = exec.Execution.Execute([]*structpb.Struct{pbIn})
   278  		c.Check(err, qt.IsNotNil)
   279  
   280  		want := "Failed to call http://no-such.host/.*. Please check that the connector configuration is correct."
   281  		c.Check(errmsg.Message(err), qt.Matches, want)
   282  	})
   283  }