github.com/weaviate/weaviate@v1.24.6/adapters/clients/remote_index_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  //	_       _
    13  //
    14  // __      _____  __ ___   ___  __ _| |_ ___
    15  //
    16  //	\ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
    17  //	 \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
    18  //	  \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
    19  //
    20  //	 Copyright © 2016 - 2022 SeMI Technologies B.V. All rights reserved.
    21  //
    22  //	 CONTACT: hello@semi.technology
    23  package clients
    24  
    25  import (
    26  	"context"
    27  	"fmt"
    28  	"io"
    29  	"net/http"
    30  	"net/http/httptest"
    31  	"strings"
    32  	"testing"
    33  	"time"
    34  
    35  	"github.com/stretchr/testify/assert"
    36  	"github.com/weaviate/weaviate/adapters/handlers/rest/clusterapi"
    37  )
    38  
    39  func TestRemoteIndexIncreaseRF(t *testing.T) {
    40  	t.Parallel()
    41  
    42  	ctx := context.Background()
    43  	path := "/replicas/indices/C1/replication-factor:increase"
    44  	fs := newFakeRemoteIndexServer(t, http.MethodPut, path)
    45  	ts := fs.server(t)
    46  	defer ts.Close()
    47  	client := newRemoteIndex(ts.Client())
    48  	t.Run("ConnectionError", func(t *testing.T) {
    49  		err := client.IncreaseReplicationFactor(ctx, "", "C1", nil)
    50  		assert.NotNil(t, err)
    51  		assert.Contains(t, err.Error(), "connect")
    52  	})
    53  	n := 0
    54  	fs.doAfter = func(w http.ResponseWriter, r *http.Request) {
    55  		if n == 0 {
    56  			w.WriteHeader(http.StatusInternalServerError)
    57  		} else if n == 1 {
    58  			w.WriteHeader(http.StatusTooManyRequests)
    59  		} else {
    60  			w.WriteHeader(http.StatusNoContent)
    61  		}
    62  		n++
    63  	}
    64  	t.Run("Success", func(t *testing.T) {
    65  		err := client.IncreaseReplicationFactor(ctx, fs.host, "C1", nil)
    66  		assert.Nil(t, err)
    67  	})
    68  }
    69  
    70  func TestRemoteIndexReInitShardIn(t *testing.T) {
    71  	t.Parallel()
    72  
    73  	ctx := context.Background()
    74  	path := "/indices/C1/shards/S1:reinit"
    75  	fs := newFakeRemoteIndexServer(t, http.MethodPut, path)
    76  	ts := fs.server(t)
    77  	defer ts.Close()
    78  	client := newRemoteIndex(ts.Client())
    79  	t.Run("ConnectionError", func(t *testing.T) {
    80  		err := client.ReInitShard(ctx, "", "C1", "S1")
    81  		assert.NotNil(t, err)
    82  		assert.Contains(t, err.Error(), "connect")
    83  	})
    84  	n := 0
    85  	fs.doAfter = func(w http.ResponseWriter, r *http.Request) {
    86  		if n == 0 {
    87  			w.WriteHeader(http.StatusInternalServerError)
    88  		} else if n == 1 {
    89  			w.WriteHeader(http.StatusTooManyRequests)
    90  		} else {
    91  			w.WriteHeader(http.StatusNoContent)
    92  		}
    93  		n++
    94  	}
    95  	t.Run("Success", func(t *testing.T) {
    96  		err := client.ReInitShard(ctx, fs.host, "C1", "S1")
    97  		assert.Nil(t, err)
    98  	})
    99  }
   100  
   101  func TestRemoteIndexCreateShard(t *testing.T) {
   102  	t.Parallel()
   103  
   104  	ctx := context.Background()
   105  	path := "/indices/C1/shards/S1"
   106  	fs := newFakeRemoteIndexServer(t, http.MethodPost, path)
   107  	ts := fs.server(t)
   108  	defer ts.Close()
   109  	client := newRemoteIndex(ts.Client())
   110  	t.Run("ConnectionError", func(t *testing.T) {
   111  		err := client.CreateShard(ctx, "", "C1", "S1")
   112  		assert.NotNil(t, err)
   113  		assert.Contains(t, err.Error(), "connect")
   114  	})
   115  	n := 0
   116  	fs.doAfter = func(w http.ResponseWriter, r *http.Request) {
   117  		if n == 0 {
   118  			w.WriteHeader(http.StatusInternalServerError)
   119  		} else if n == 1 {
   120  			w.WriteHeader(http.StatusTooManyRequests)
   121  		} else {
   122  			w.WriteHeader(http.StatusCreated)
   123  		}
   124  		n++
   125  	}
   126  	t.Run("Success", func(t *testing.T) {
   127  		err := client.CreateShard(ctx, fs.host, "C1", "S1")
   128  		assert.Nil(t, err)
   129  	})
   130  }
   131  
   132  func TestRemoteIndexUpdateShardStatus(t *testing.T) {
   133  	t.Parallel()
   134  
   135  	ctx := context.Background()
   136  	path := "/indices/C1/shards/S1/status"
   137  	fs := newFakeRemoteIndexServer(t, http.MethodPost, path)
   138  	ts := fs.server(t)
   139  	defer ts.Close()
   140  	client := newRemoteIndex(ts.Client())
   141  	t.Run("ConnectionError", func(t *testing.T) {
   142  		err := client.UpdateShardStatus(ctx, "", "C1", "S1", "NewStatus")
   143  		assert.NotNil(t, err)
   144  		assert.Contains(t, err.Error(), "connect")
   145  	})
   146  	n := 0
   147  	fs.doAfter = func(w http.ResponseWriter, r *http.Request) {
   148  		if n == 0 {
   149  			w.WriteHeader(http.StatusInternalServerError)
   150  		} else if n == 1 {
   151  			w.WriteHeader(http.StatusTooManyRequests)
   152  		}
   153  		n++
   154  	}
   155  	t.Run("Success", func(t *testing.T) {
   156  		err := client.UpdateShardStatus(ctx, fs.host, "C1", "S1", "NewStatus")
   157  		assert.Nil(t, err)
   158  	})
   159  }
   160  
   161  func TestRemoteIndexShardStatus(t *testing.T) {
   162  	t.Parallel()
   163  	var (
   164  		ctx    = context.Background()
   165  		path   = "/indices/C1/shards/S1/status"
   166  		fs     = newFakeRemoteIndexServer(t, http.MethodGet, path)
   167  		Status = "READONLY"
   168  	)
   169  	ts := fs.server(t)
   170  	defer ts.Close()
   171  	client := newRemoteIndex(ts.Client())
   172  	t.Run("ConnectionError", func(t *testing.T) {
   173  		_, err := client.GetShardStatus(ctx, "", "C1", "S1")
   174  		assert.NotNil(t, err)
   175  		assert.Contains(t, err.Error(), "connect")
   176  	})
   177  	n := 0
   178  	fs.doAfter = func(w http.ResponseWriter, r *http.Request) {
   179  		if n == 0 {
   180  			w.WriteHeader(http.StatusInternalServerError)
   181  		} else if n == 1 {
   182  			w.WriteHeader(http.StatusTooManyRequests)
   183  		} else if n == 2 {
   184  			w.Header().Set("content-type", "any")
   185  		} else if n == 3 {
   186  			clusterapi.IndicesPayloads.GetShardStatusResults.SetContentTypeHeader(w)
   187  		} else {
   188  			clusterapi.IndicesPayloads.GetShardStatusResults.SetContentTypeHeader(w)
   189  			bytes, _ := clusterapi.IndicesPayloads.GetShardStatusResults.Marshal(Status)
   190  			w.Write(bytes)
   191  		}
   192  		n++
   193  	}
   194  
   195  	t.Run("ContentType", func(t *testing.T) {
   196  		_, err := client.GetShardStatus(ctx, fs.host, "C1", "S1")
   197  		assert.NotNil(t, err)
   198  	})
   199  	t.Run("Status", func(t *testing.T) {
   200  		_, err := client.GetShardStatus(ctx, fs.host, "C1", "S1")
   201  		assert.NotNil(t, err)
   202  	})
   203  	t.Run("Success", func(t *testing.T) {
   204  		st, err := client.GetShardStatus(ctx, fs.host, "C1", "S1")
   205  		assert.Nil(t, err)
   206  		assert.Equal(t, "READONLY", st)
   207  	})
   208  }
   209  
   210  func TestRemoteIndexPutFile(t *testing.T) {
   211  	t.Parallel()
   212  	var (
   213  		ctx  = context.Background()
   214  		path = "/indices/C1/shards/S1/files/file1"
   215  		fs   = newFakeRemoteIndexServer(t, http.MethodPost, path)
   216  	)
   217  	ts := fs.server(t)
   218  	defer ts.Close()
   219  	client := newRemoteIndex(ts.Client())
   220  
   221  	rsc := struct {
   222  		*strings.Reader
   223  		io.Closer
   224  	}{
   225  		strings.NewReader("hello, world"),
   226  		io.NopCloser(nil),
   227  	}
   228  	t.Run("ConnectionError", func(t *testing.T) {
   229  		err := client.PutFile(ctx, "", "C1", "S1", "file1", rsc)
   230  		assert.NotNil(t, err)
   231  		assert.Contains(t, err.Error(), "connect")
   232  	})
   233  	n := 0
   234  	fs.doAfter = func(w http.ResponseWriter, r *http.Request) {
   235  		if n == 0 {
   236  			w.WriteHeader(http.StatusInternalServerError)
   237  		} else if n == 1 {
   238  			w.WriteHeader(http.StatusTooManyRequests)
   239  		} else {
   240  			w.WriteHeader(http.StatusNoContent)
   241  		}
   242  		n++
   243  	}
   244  
   245  	t.Run("Success", func(t *testing.T) {
   246  		err := client.PutFile(ctx, fs.host, "C1", "S1", "file1", rsc)
   247  		assert.Nil(t, err)
   248  	})
   249  }
   250  
   251  func newRemoteIndex(httpClient *http.Client) *RemoteIndex {
   252  	ri := NewRemoteIndex(httpClient)
   253  	ri.minBackOff = time.Millisecond * 1
   254  	ri.maxBackOff = time.Millisecond * 10
   255  	ri.timeoutUnit = time.Millisecond * 20
   256  	return ri
   257  }
   258  
   259  type fakeRemoteIndexServer struct {
   260  	method   string
   261  	path     string
   262  	host     string
   263  	doBefore func(w http.ResponseWriter, r *http.Request) error
   264  	doAfter  func(w http.ResponseWriter, r *http.Request)
   265  }
   266  
   267  func newFakeRemoteIndexServer(t *testing.T, method, path string) *fakeRemoteIndexServer {
   268  	f := &fakeRemoteIndexServer{
   269  		method: method,
   270  		path:   path,
   271  	}
   272  	f.doBefore = func(w http.ResponseWriter, r *http.Request) error {
   273  		if r.Method != f.method {
   274  			w.WriteHeader(http.StatusBadRequest)
   275  			return fmt.Errorf("method want %s got %s", method, r.Method)
   276  		}
   277  		if f.path != r.URL.Path {
   278  			w.WriteHeader(http.StatusBadRequest)
   279  			return fmt.Errorf("path want %s got %s", path, r.URL.Path)
   280  		}
   281  		return nil
   282  	}
   283  	return f
   284  }
   285  
   286  func (f *fakeRemoteIndexServer) server(t *testing.T) *httptest.Server {
   287  	if f.doBefore == nil {
   288  		f.doBefore = func(w http.ResponseWriter, r *http.Request) error {
   289  			if r.Method != f.method {
   290  				w.WriteHeader(http.StatusBadRequest)
   291  				return fmt.Errorf("method want %s got %s", f.method, r.Method)
   292  			}
   293  			if f.path != r.URL.Path {
   294  				w.WriteHeader(http.StatusBadRequest)
   295  				return fmt.Errorf("path want %s got %s", f.path, r.URL.Path)
   296  			}
   297  			return nil
   298  		}
   299  	}
   300  	handler := func(w http.ResponseWriter, r *http.Request) {
   301  		if err := f.doBefore(w, r); err != nil {
   302  			t.Error(err)
   303  			return
   304  		}
   305  		if f.doAfter != nil {
   306  			f.doAfter(w, r)
   307  		}
   308  	}
   309  	serv := httptest.NewServer(http.HandlerFunc(handler))
   310  	f.host = serv.URL[7:]
   311  	return serv
   312  }