tractor.dev/toolkit-go@v0.0.0-20241010005851-214d91207d07/duplex/rpc/rpc_test.go (about)

     1  package rpc
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	"strings"
     9  	"testing"
    10  	"time"
    11  
    12  	"tractor.dev/toolkit-go/duplex/codec"
    13  	"tractor.dev/toolkit-go/duplex/mux"
    14  )
    15  
    16  func fatal(t *testing.T, err error) {
    17  	t.Helper()
    18  	if err != nil {
    19  		t.Fatal(err)
    20  	}
    21  }
    22  
    23  func newTestPair(handler Handler) (*Client, *Server) {
    24  	ar, bw := io.Pipe()
    25  	br, aw := io.Pipe()
    26  	sessA, _ := mux.DialIO(aw, ar)
    27  	sessB, _ := mux.DialIO(bw, br)
    28  
    29  	srv := &Server{
    30  		Codec:   codec.JSONCodec{},
    31  		Handler: handler,
    32  	}
    33  	go srv.Respond(sessA, nil)
    34  
    35  	return NewClient(sessB, codec.JSONCodec{}), srv
    36  }
    37  
    38  func TestServerNoCodec(t *testing.T) {
    39  	defer func() {
    40  		if r := recover(); r == nil {
    41  			t.Errorf("did not panic from unset codec")
    42  		}
    43  	}()
    44  
    45  	ar, _ := io.Pipe()
    46  	_, aw := io.Pipe()
    47  	sessA, _ := mux.DialIO(aw, ar)
    48  
    49  	srv := &Server{
    50  		Handler: NotFoundHandler(),
    51  	}
    52  	go sessA.Close()
    53  	srv.Respond(sessA, nil)
    54  }
    55  
    56  func TestRespondMux(t *testing.T) {
    57  	ctx := context.Background()
    58  
    59  	t.Run("selector mux", func(t *testing.T) {
    60  		mux := NewRespondMux()
    61  		mux.Handle("foo", HandlerFunc(func(r Responder, c *Call) {
    62  			r.Return("foo")
    63  		}))
    64  		mux.Handle("bar", HandlerFunc(func(r Responder, c *Call) {
    65  			r.Return("bar")
    66  		}))
    67  
    68  		client, _ := newTestPair(mux)
    69  		defer client.Close()
    70  
    71  		var out string
    72  		_, err := client.Call(ctx, "foo", nil, &out)
    73  		fatal(t, err)
    74  		if out != "foo" {
    75  			t.Fatal("unexpected return:", out)
    76  		}
    77  
    78  		_, err = client.Call(ctx, "bar", nil, &out)
    79  		fatal(t, err)
    80  		if out != "bar" {
    81  			t.Fatal("unexpected return:", out)
    82  		}
    83  	})
    84  
    85  	t.Run("selector not found error", func(t *testing.T) {
    86  		mux := NewRespondMux()
    87  		mux.Handle("foo", HandlerFunc(func(r Responder, c *Call) {
    88  			r.Return("foo")
    89  		}))
    90  
    91  		client, _ := newTestPair(mux)
    92  		defer client.Close()
    93  
    94  		var out string
    95  		_, err := client.Call(ctx, "baz", nil, &out)
    96  		if err == nil {
    97  			t.Fatal("expected error")
    98  		}
    99  		if err != nil {
   100  			rErr, ok := err.(RemoteError)
   101  			if !ok {
   102  				t.Fatal("unexpected error:", err)
   103  			}
   104  			if rErr.Error() != "remote: not found: /baz" {
   105  				t.Fatal("unexpected error:", rErr)
   106  			}
   107  		}
   108  	})
   109  
   110  	t.Run("default handler mux", func(t *testing.T) {
   111  		mux := NewRespondMux()
   112  		mux.Handle("foo", HandlerFunc(func(r Responder, c *Call) {
   113  			r.Return("foo")
   114  		}))
   115  		mux.Handle("", HandlerFunc(func(r Responder, c *Call) {
   116  			r.Return(fmt.Errorf("default"))
   117  		}))
   118  
   119  		client, _ := newTestPair(mux)
   120  		defer client.Close()
   121  
   122  		var out string
   123  		_, err := client.Call(ctx, "baz", nil, &out)
   124  		if err == nil {
   125  			t.Fatal("expected error")
   126  		}
   127  		if err != nil {
   128  			rErr, ok := err.(RemoteError)
   129  			if !ok {
   130  				t.Fatal("unexpected error:", err)
   131  			}
   132  			if rErr.Error() != "remote: default" {
   133  				t.Fatal("unexpected error:", rErr)
   134  			}
   135  		}
   136  
   137  		_, err = client.Call(ctx, "foo", nil, &out)
   138  		if err != nil {
   139  			t.Fatal("unexpected error:", err)
   140  		}
   141  		if out != "foo" {
   142  			t.Fatal("unexpected return:", out)
   143  		}
   144  	})
   145  
   146  	t.Run("sub muxing", func(t *testing.T) {
   147  		mux := NewRespondMux()
   148  		submux := NewRespondMux()
   149  		mux.Handle("foo.bar", submux)
   150  		mux.Handle("", HandlerFunc(func(r Responder, c *Call) {
   151  			r.Return(fmt.Errorf("default"))
   152  		}))
   153  		submux.Handle("baz", HandlerFunc(func(r Responder, c *Call) {
   154  			r.Return("foobarbaz")
   155  		}))
   156  
   157  		client, _ := newTestPair(mux)
   158  		defer client.Close()
   159  
   160  		var out string
   161  		_, err := client.Call(ctx, "foo.bar.baz", nil, &out)
   162  		fatal(t, err)
   163  		if out != "foobarbaz" {
   164  			t.Fatal("unexpected return:", out)
   165  		}
   166  	})
   167  
   168  	t.Run("selector normalizing", func(t *testing.T) {
   169  		mux := NewRespondMux()
   170  		mux.Handle("foo.bar", HandlerFunc(func(r Responder, c *Call) {
   171  			r.Return("foobar")
   172  		}))
   173  
   174  		client, _ := newTestPair(mux)
   175  		defer client.Close()
   176  
   177  		var out string
   178  		_, err := client.Call(ctx, "/foo/bar", nil, &out)
   179  		fatal(t, err)
   180  		if out != "foobar" {
   181  			t.Fatal("unexpected return:", out)
   182  		}
   183  	})
   184  
   185  	t.Run("selector catchall", func(t *testing.T) {
   186  		mux := NewRespondMux()
   187  		mux.Handle("foo.bar.", HandlerFunc(func(r Responder, c *Call) {
   188  			r.Return("foobar")
   189  		}))
   190  
   191  		client, _ := newTestPair(mux)
   192  		defer client.Close()
   193  
   194  		var out string
   195  		_, err := client.Call(ctx, "foo.bar.baz", nil, &out)
   196  		fatal(t, err)
   197  		if out != "foobar" {
   198  			t.Fatal("unexpected return:", out)
   199  		}
   200  	})
   201  
   202  	t.Run("remove handler", func(t *testing.T) {
   203  		mux := NewRespondMux()
   204  		mux.Handle("foo", HandlerFunc(func(r Responder, c *Call) {
   205  			r.Return("foo")
   206  		}))
   207  
   208  		client, _ := newTestPair(mux)
   209  		defer client.Close()
   210  
   211  		_, err := client.Call(ctx, "foo", nil, nil)
   212  		fatal(t, err)
   213  
   214  		mux.Remove("foo")
   215  
   216  		_, err = client.Call(ctx, "foo", nil, nil)
   217  		if err == nil {
   218  			t.Fatal("expected error")
   219  		}
   220  	})
   221  
   222  	t.Run("bad handler: nil", func(t *testing.T) {
   223  		defer func() {
   224  			if r := recover(); r == nil {
   225  				t.Errorf("did not panic from nil handler")
   226  			}
   227  		}()
   228  		mux := NewRespondMux()
   229  		mux.Handle("foo.bar", nil)
   230  	})
   231  
   232  	t.Run("bad handle: exists", func(t *testing.T) {
   233  		defer func() {
   234  			if r := recover(); r == nil {
   235  				t.Errorf("did not panic from existing handle")
   236  			}
   237  		}()
   238  		mux := NewRespondMux()
   239  		mux.Handle("foo", NotFoundHandler())
   240  		mux.Handle("foo", NotFoundHandler())
   241  	})
   242  }
   243  
   244  func TestRPC(t *testing.T) {
   245  	ctx := context.Background()
   246  
   247  	t.Run("unary rpc", func(t *testing.T) {
   248  		client, _ := newTestPair(HandlerFunc(func(r Responder, c *Call) {
   249  			var in string
   250  			fatal(t, c.Receive(&in))
   251  			r.Return(in)
   252  		}))
   253  		defer client.Close()
   254  
   255  		var out string
   256  		resp, err := client.Call(ctx, "", "Hello world", &out)
   257  		fatal(t, err)
   258  		if resp.Continue() {
   259  			t.Fatal("unexpected continue")
   260  		}
   261  		if out != "Hello world" {
   262  			t.Fatalf("unexpected return: %#v", out)
   263  		}
   264  	})
   265  
   266  	t.Run("unary rpc remote error", func(t *testing.T) {
   267  		client, _ := newTestPair(HandlerFunc(func(r Responder, c *Call) {
   268  			var in interface{}
   269  			fatal(t, c.Receive(&in))
   270  			r.Return(fmt.Errorf("internal server error"))
   271  		}))
   272  		defer client.Close()
   273  
   274  		var out string
   275  		_, err := client.Call(ctx, "", "Hello world", &out)
   276  		if err == nil {
   277  			t.Fatal("expected error")
   278  		}
   279  		if err != nil {
   280  			rErr, ok := err.(RemoteError)
   281  			if !ok {
   282  				t.Fatal("unexpected error:", err)
   283  			}
   284  			if rErr.Error() != "remote: internal server error" {
   285  				t.Fatal("unexpected error:", rErr)
   286  			}
   287  		}
   288  	})
   289  
   290  	t.Run("multi-return rpc", func(t *testing.T) {
   291  		client, _ := newTestPair(HandlerFunc(func(r Responder, c *Call) {
   292  			var in string
   293  			fatal(t, c.Receive(&in))
   294  			r.Return(in, strings.ToUpper(in))
   295  		}))
   296  		defer client.Close()
   297  
   298  		var out, out2 string
   299  		resp, err := client.Call(ctx, "", "Hello world", &out, &out2)
   300  		fatal(t, err)
   301  		if resp.Continue() {
   302  			t.Fatal("unexpected continue")
   303  		}
   304  		if out != "Hello world" {
   305  			t.Errorf("unexpected return 1: %#v", out)
   306  		}
   307  		if out2 != "HELLO WORLD" {
   308  			t.Errorf("unexpected return 2: %#v", out)
   309  		}
   310  	})
   311  
   312  	t.Run("server streaming rpc", func(t *testing.T) {
   313  		client, _ := newTestPair(HandlerFunc(func(r Responder, c *Call) {
   314  			var in string
   315  			fatal(t, c.Receive(&in))
   316  			_, err := r.Continue(nil)
   317  			fatal(t, err)
   318  			fatal(t, r.Send(in))
   319  			fatal(t, r.Send(in))
   320  			fatal(t, r.Send(in))
   321  		}))
   322  		defer client.Close()
   323  
   324  		resp, err := client.Call(ctx, "", "Hello world", nil)
   325  		fatal(t, err)
   326  		if !resp.Continue() {
   327  			t.Fatal("expected continue")
   328  		}
   329  		for i := 0; i < 3; i++ {
   330  			var rcv string
   331  			fatal(t, resp.Receive(&rcv))
   332  			if rcv != "Hello world" {
   333  				t.Fatalf("unexpected receive [%d]: %#v", i, rcv)
   334  			}
   335  		}
   336  
   337  	})
   338  
   339  	t.Run("client streaming rpc", func(t *testing.T) {
   340  		client, _ := newTestPair(HandlerFunc(func(r Responder, c *Call) {
   341  			for i := 0; i < 3; i++ {
   342  				var rcv string
   343  				fatal(t, c.Receive(&rcv))
   344  				if rcv != "Hello world" {
   345  					t.Fatalf("unexpected server receive [%d]: %#v", i, rcv)
   346  				}
   347  			}
   348  		}))
   349  		defer client.Close()
   350  
   351  		sender := make(chan interface{})
   352  		go func() {
   353  			for i := 0; i < 3; i++ {
   354  				sender <- "Hello world"
   355  			}
   356  			close(sender)
   357  		}()
   358  		_, err := client.Call(ctx, "", sender, nil)
   359  		fatal(t, err)
   360  
   361  	})
   362  
   363  	t.Run("bidirectional streaming rpc", func(t *testing.T) {
   364  		client, _ := newTestPair(HandlerFunc(func(r Responder, c *Call) {
   365  			var rcv string
   366  			for i := 0; i < 3; i++ {
   367  				fatal(t, c.Receive(&rcv))
   368  				if rcv != "Hello world" {
   369  					t.Fatalf("unexpected server receive [%d]: %#v", i, rcv)
   370  				}
   371  			}
   372  			_, err := r.Continue(nil)
   373  			fatal(t, err)
   374  			fatal(t, r.Send(rcv))
   375  			fatal(t, r.Send(rcv))
   376  			fatal(t, r.Send(rcv))
   377  		}))
   378  		defer client.Close()
   379  
   380  		sender := make(chan interface{})
   381  		go func() {
   382  			for i := 0; i < 3; i++ {
   383  				sender <- "Hello world"
   384  			}
   385  			close(sender)
   386  		}()
   387  		resp, err := client.Call(ctx, "", sender, nil)
   388  		fatal(t, err)
   389  		if !resp.Continue() {
   390  			t.Fatal("expected continue")
   391  		}
   392  		for i := 0; i < 3; i++ {
   393  			var rcv string
   394  			fatal(t, resp.Receive(&rcv))
   395  			if rcv != "Hello world" {
   396  				t.Fatalf("unexpected client receive [%d]: %#v", i, rcv)
   397  			}
   398  		}
   399  	})
   400  
   401  	t.Run("bidirectional channel byte stream", func(t *testing.T) {
   402  		client, _ := newTestPair(HandlerFunc(func(r Responder, c *Call) {
   403  			fatal(t, c.Receive(nil))
   404  			ch, err := r.Continue(nil)
   405  			fatal(t, err)
   406  			io.Copy(ch, ch)
   407  			ch.Close()
   408  		}))
   409  		defer client.Close()
   410  
   411  		resp, err := client.Call(ctx, "", nil, nil)
   412  		fatal(t, err)
   413  		if !resp.Continue() {
   414  			t.Fatal("expected continue")
   415  		}
   416  		_, err = io.WriteString(resp.Channel, "Hello world")
   417  		fatal(t, err)
   418  		fatal(t, resp.Channel.CloseWrite())
   419  		b, err := ioutil.ReadAll(resp.Channel)
   420  		fatal(t, err)
   421  		if string(b) != "Hello world" {
   422  			t.Fatalf("unexpected data: %#v", b)
   423  		}
   424  	})
   425  
   426  	t.Run("bidirectional channel codec stream", func(t *testing.T) {
   427  		client, _ := newTestPair(HandlerFunc(func(r Responder, c *Call) {
   428  			fatal(t, c.Receive(nil))
   429  			_, err := r.Continue(nil)
   430  			fatal(t, err)
   431  
   432  			var rcv string
   433  			for i := 0; i < 3; i++ {
   434  				fatal(t, c.Receive(&rcv))
   435  				if rcv != "Hello world" {
   436  					t.Fatalf("unexpected server receive [%d]: %#v", i, rcv)
   437  				}
   438  			}
   439  			fatal(t, r.Send(rcv))
   440  			fatal(t, r.Send(rcv))
   441  			fatal(t, r.Send(rcv))
   442  		}))
   443  		defer client.Close()
   444  
   445  		resp, err := client.Call(ctx, "", nil, nil)
   446  		fatal(t, err)
   447  		if !resp.Continue() {
   448  			t.Fatal("expected continue")
   449  		}
   450  		fatal(t, resp.Send("Hello world"))
   451  		fatal(t, resp.Send("Hello world"))
   452  		fatal(t, resp.Send("Hello world"))
   453  		for i := 0; i < 3; i++ {
   454  			var rcv string
   455  			fatal(t, resp.Receive(&rcv))
   456  			if rcv != "Hello world" {
   457  				t.Fatalf("unexpected client receive [%d]: %#v", i, rcv)
   458  			}
   459  		}
   460  	})
   461  
   462  	t.Run("call timeout", func(t *testing.T) {
   463  		client, _ := newTestPair(HandlerFunc(func(r Responder, c *Call) {
   464  			time.Sleep(200 * time.Millisecond)
   465  			fatal(t, c.Receive(nil))
   466  			_, err := r.Continue(nil)
   467  			fatal(t, err)
   468  
   469  			var rcv string
   470  			for i := 0; i < 3; i++ {
   471  				fatal(t, c.Receive(&rcv))
   472  				if rcv != "Hello world" {
   473  					t.Fatalf("unexpected server receive [%d]: %#v", i, rcv)
   474  				}
   475  			}
   476  			fatal(t, r.Send(rcv))
   477  			fatal(t, r.Send(rcv))
   478  			fatal(t, r.Send(rcv))
   479  		}))
   480  		defer client.Close()
   481  
   482  		ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
   483  		defer cancel()
   484  
   485  		_, err := client.Call(ctx, "", []any{"foo"}, nil)
   486  		expectedError := "context deadline exceeded"
   487  		if fmt.Sprintf("%v", err) != expectedError {
   488  			t.Fatalf("expected error: %v\ngot: %v", expectedError, err)
   489  		}
   490  	})
   491  
   492  }