github.com/anycable/anycable-go@v1.5.1/rpc/http_test.go (about)

     1  package rpc
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"sync/atomic"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/anycable/anycable-go/common"
    15  	"github.com/anycable/anycable-go/protocol"
    16  	pb "github.com/anycable/anycable-go/protos"
    17  	"github.com/anycable/anycable-go/utils"
    18  	"github.com/stretchr/testify/assert"
    19  	"github.com/stretchr/testify/require"
    20  	"google.golang.org/grpc/codes"
    21  	"google.golang.org/grpc/metadata"
    22  	"google.golang.org/grpc/status"
    23  )
    24  
    25  func TestHTTPServiceRPC(t *testing.T) {
    26  	var onRequest func(r *http.Request, w http.ResponseWriter)
    27  
    28  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    29  		if r.Method != "POST" {
    30  			w.WriteHeader(http.StatusMethodNotAllowed)
    31  			return
    32  		}
    33  
    34  		if onRequest == nil {
    35  			w.WriteHeader(http.StatusNotFound)
    36  			return
    37  		}
    38  
    39  		onRequest(r, w)
    40  	}))
    41  
    42  	defer ts.Close()
    43  
    44  	conf := NewConfig()
    45  	conf.Host = ts.URL
    46  
    47  	service, _ := NewHTTPService(&conf)
    48  
    49  	t.Run("Connect", func(t *testing.T) {
    50  		onRequest = func(r *http.Request, w http.ResponseWriter) {
    51  			require.Equal(t, "/connect", r.URL.Path)
    52  
    53  			body, err := io.ReadAll(r.Body)
    54  			require.NoError(t, err)
    55  
    56  			var req pb.ConnectionRequest
    57  			err = json.Unmarshal(body, &req)
    58  			require.NoError(t, err)
    59  
    60  			require.Equal(t, "ws://anycable.io/cable", req.Env.Url)
    61  			require.Equal(t, "foo=bar", req.Env.Headers["cookie"])
    62  
    63  			identifiers := fmt.Sprintf("%s-%s", r.Header.Get("x-anycable-meta-year"), r.Header.Get("x-anycable-meta-album"))
    64  
    65  			res := pb.ConnectionResponse{
    66  				Transmissions: []string{"welcome"},
    67  				Identifiers:   identifiers,
    68  				Status:        pb.Status_SUCCESS,
    69  			}
    70  
    71  			w.Write(utils.ToJSON(res)) // nolint: errcheck
    72  		}
    73  
    74  		md := metadata.Pairs("album", "Kamni", "year", "2008")
    75  		ctx := metadata.NewIncomingContext(context.Background(), md)
    76  		res, err := service.Connect(ctx, protocol.NewConnectMessage(
    77  			common.NewSessionEnv("ws://anycable.io/cable", &map[string]string{"cookie": "foo=bar"}),
    78  		))
    79  
    80  		require.NoError(t, err)
    81  
    82  		assert.Equal(t, pb.Status_SUCCESS, res.Status)
    83  		assert.Equal(t, []string{"welcome"}, res.Transmissions)
    84  		assert.Equal(t, "2008-Kamni", res.Identifiers)
    85  	})
    86  
    87  	t.Run("Disconnect", func(t *testing.T) {
    88  		onRequest = func(r *http.Request, w http.ResponseWriter) {
    89  			require.Equal(t, "/disconnect", r.URL.Path)
    90  
    91  			body, err := io.ReadAll(r.Body)
    92  			require.NoError(t, err)
    93  
    94  			var req pb.DisconnectRequest
    95  			err = json.Unmarshal(body, &req)
    96  			require.NoError(t, err)
    97  
    98  			require.Equal(t, "ws://anycable.io/cable", req.Env.Url)
    99  			require.Equal(t, "foo=bar", req.Env.Headers["cookie"])
   100  			require.Equal(t, "test-session", req.Identifiers)
   101  
   102  			res := pb.DisconnectResponse{
   103  				Status:   pb.Status_ERROR,
   104  				ErrorMsg: r.Header.Get("x-anycable-meta-error"),
   105  			}
   106  
   107  			w.Write(utils.ToJSON(res)) // nolint: errcheck
   108  		}
   109  
   110  		md := metadata.Pairs("error", "test error")
   111  		ctx := metadata.NewIncomingContext(context.Background(), md)
   112  		res, err := service.Disconnect(ctx, protocol.NewDisconnectMessage(
   113  			common.NewSessionEnv("ws://anycable.io/cable", &map[string]string{"cookie": "foo=bar"}),
   114  			"test-session",
   115  			[]string{},
   116  		))
   117  
   118  		require.NoError(t, err)
   119  
   120  		assert.Equal(t, pb.Status_ERROR, res.Status)
   121  		assert.Equal(t, "test error", res.ErrorMsg)
   122  	})
   123  
   124  	t.Run("Command", func(t *testing.T) {
   125  		onRequest = func(r *http.Request, w http.ResponseWriter) {
   126  			require.Equal(t, "/command", r.URL.Path)
   127  
   128  			body, err := io.ReadAll(r.Body)
   129  			require.NoError(t, err)
   130  
   131  			var req pb.CommandMessage
   132  			err = json.Unmarshal(body, &req)
   133  			require.NoError(t, err)
   134  
   135  			require.Equal(t, "chat_1", req.Identifier)
   136  			require.Equal(t, "subscribe", req.Command)
   137  
   138  			stream := r.Header.Get("x-anycable-meta-track")
   139  
   140  			res := pb.CommandResponse{
   141  				Transmissions: []string{"confirmed"},
   142  				Streams:       []string{stream},
   143  				Status:        pb.Status_SUCCESS,
   144  			}
   145  
   146  			w.Write(utils.ToJSON(res)) // nolint: errcheck
   147  		}
   148  
   149  		md := metadata.Pairs("track", "easy-way-out")
   150  		ctx := metadata.NewIncomingContext(context.Background(), md)
   151  		res, err := service.Command(ctx, protocol.NewCommandMessage(
   152  			common.NewSessionEnv("ws://anycable.io/cable", &map[string]string{"cookie": "foo=bar"}),
   153  			"subscribe",
   154  			"chat_1",
   155  			"test-session",
   156  			"{}",
   157  		))
   158  
   159  		require.NoError(t, err)
   160  
   161  		assert.Equal(t, pb.Status_SUCCESS, res.Status)
   162  		assert.Equal(t, []string{"confirmed"}, res.Transmissions)
   163  		assert.Equal(t, []string{"easy-way-out"}, res.Streams)
   164  	})
   165  }
   166  
   167  func TestHTTPServiceAuthentication(t *testing.T) {
   168  	conf := NewConfig()
   169  	conf.Secret = "secretto"
   170  
   171  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   172  		if r.Header.Get("Authorization") != "Bearer secretto" {
   173  			w.WriteHeader(http.StatusUnauthorized)
   174  			return
   175  		}
   176  
   177  		res := pb.ConnectionResponse{
   178  			Status: pb.Status_SUCCESS,
   179  		}
   180  
   181  		w.Write(utils.ToJSON(res)) // nolint: errcheck
   182  	}))
   183  
   184  	defer ts.Close()
   185  
   186  	conf.Host = ts.URL
   187  
   188  	service, _ := NewHTTPService(&conf)
   189  
   190  	request := protocol.NewConnectMessage(
   191  		common.NewSessionEnv("ws://anycable.io/cable", &map[string]string{"cookie": "foo=bar"}),
   192  	)
   193  
   194  	t.Run("Authentication_SUCCESS", func(t *testing.T) {
   195  		res, err := service.Connect(context.Background(), request)
   196  
   197  		require.NoError(t, err)
   198  
   199  		assert.Equal(t, pb.Status_SUCCESS, res.Status)
   200  	})
   201  
   202  	t.Run("Authentication_FAILURE", func(t *testing.T) {
   203  		newConf := NewConfig()
   204  		newConf.Secret = "not-a-secret"
   205  		newConf.Host = ts.URL
   206  
   207  		service, _ := NewHTTPService(&newConf)
   208  
   209  		_, err := service.Connect(context.Background(), request)
   210  
   211  		require.Error(t, err)
   212  
   213  		grpcErr, ok := status.FromError(err)
   214  
   215  		require.True(t, ok)
   216  
   217  		assert.Equal(t, codes.Unauthenticated, grpcErr.Code())
   218  	})
   219  }
   220  
   221  func TestHTTPServiceRequestTimeout(t *testing.T) {
   222  	completed := int64(0)
   223  
   224  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   225  		res := pb.ConnectionResponse{
   226  			Status: pb.Status_SUCCESS,
   227  		}
   228  
   229  		// Timers are not determenistic (especially on CI with OSX — don't know why)
   230  		// let's make sure the request is slow enough to be cancelled
   231  		for atomic.LoadInt64(&completed) == 0 {
   232  			time.Sleep(50 * time.Millisecond)
   233  		}
   234  
   235  		w.Write(utils.ToJSON(res)) // nolint: errcheck
   236  	}))
   237  
   238  	defer ts.Close()
   239  
   240  	conf := NewConfig()
   241  	conf.Host = ts.URL
   242  	conf.RequestTimeout = 50
   243  
   244  	service, _ := NewHTTPService(&conf)
   245  	request := protocol.NewConnectMessage(
   246  		common.NewSessionEnv("ws://anycable.io/cable", &map[string]string{"cookie": "foo=bar"}),
   247  	)
   248  
   249  	ctx := context.Background()
   250  
   251  	_, err := service.Connect(ctx, request)
   252  	atomic.AddInt64(&completed, 1)
   253  
   254  	require.Error(t, err)
   255  
   256  	grpcErr, ok := status.FromError(err)
   257  
   258  	require.True(t, ok)
   259  
   260  	assert.Equal(t, codes.DeadlineExceeded, grpcErr.Code())
   261  }
   262  
   263  func TestHTTPServiceBadRequests(t *testing.T) {
   264  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   265  		w.WriteHeader(http.StatusUnprocessableEntity)
   266  	}))
   267  
   268  	defer ts.Close()
   269  
   270  	conf := NewConfig()
   271  	conf.Host = ts.URL
   272  	conf.RequestTimeout = 50
   273  
   274  	request := protocol.NewConnectMessage(
   275  		common.NewSessionEnv("ws://anycable.io/cable", &map[string]string{"cookie": "foo=bar"}),
   276  	)
   277  
   278  	t.Run("unknown url", func(t *testing.T) {
   279  		newConf := NewConfig()
   280  		newConf.Host = "http://localhost:1234"
   281  
   282  		service, _ := NewHTTPService(&newConf)
   283  
   284  		ctx := context.Background()
   285  
   286  		_, err := service.Connect(ctx, request)
   287  
   288  		require.Error(t, err)
   289  
   290  		grpcErr, ok := status.FromError(err)
   291  
   292  		require.True(t, ok)
   293  
   294  		assert.Equal(t, codes.Unavailable, grpcErr.Code())
   295  	})
   296  
   297  	t.Run("bad request", func(t *testing.T) {
   298  		service, _ := NewHTTPService(&conf)
   299  
   300  		ctx := context.Background()
   301  
   302  		_, err := service.Connect(ctx, request)
   303  
   304  		require.Error(t, err)
   305  
   306  		grpcErr, ok := status.FromError(err)
   307  
   308  		require.True(t, ok)
   309  
   310  		assert.Equal(t, codes.InvalidArgument, grpcErr.Code())
   311  	})
   312  }
   313  
   314  func TestHTTPClientHelper_READY(t *testing.T) {
   315  	conf := NewConfig()
   316  	conf.Host = "http://localhost:1234"
   317  
   318  	service, _ := NewHTTPService(&conf)
   319  	h := NewHTTPClientHelper(service)
   320  
   321  	assert.NoError(t, h.Ready())
   322  
   323  	// by default, we open a breaker if there are >20% of errors
   324  	request := protocol.NewConnectMessage(
   325  		common.NewSessionEnv("ws://anycable.io/cable", &map[string]string{"cookie": "foo=bar"}),
   326  	)
   327  
   328  	for i := 0; i < 20; i++ {
   329  		_, err := service.Connect(context.Background(), request)
   330  		require.Error(t, err)
   331  	}
   332  
   333  	// Shouldn't be ready
   334  	assert.Error(t, h.Ready())
   335  }