tractor.dev/toolkit-go@v0.0.0-20241010005851-214d91207d07/duplex/fn/handler_test.go (about)

     1  package fn
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"reflect"
     8  	"strings"
     9  	"testing"
    10  
    11  	"tractor.dev/toolkit-go/duplex/codec"
    12  	"tractor.dev/toolkit-go/duplex/rpc"
    13  	"tractor.dev/toolkit-go/duplex/rpc/rpctest"
    14  )
    15  
    16  func TestHandlerFromBadData(t *testing.T) {
    17  	defer func() {
    18  		if r := recover(); r == nil {
    19  			t.Errorf("did not panic from bad argument data")
    20  		}
    21  	}()
    22  	HandlerFrom(2)
    23  }
    24  
    25  type subfake struct {
    26  	A string
    27  }
    28  
    29  type fake struct {
    30  	A subfake
    31  	B int
    32  }
    33  
    34  type id int
    35  
    36  func TestHandlerFromFunc(t *testing.T) {
    37  	t.Run("int sum", func(t *testing.T) {
    38  		client, _ := rpctest.NewPair(HandlerFrom(func(a, b int) int {
    39  			return a + b
    40  		}), codec.JSONCodec{})
    41  		defer client.Close()
    42  
    43  		var sum int
    44  		if _, err := client.Call(context.Background(), "", []interface{}{2, 3}, &sum); err != nil {
    45  			t.Fatal(err)
    46  		}
    47  		if sum != 5 {
    48  			t.Fatalf("unexpected sum: %v", sum)
    49  		}
    50  	})
    51  
    52  	t.Run("defined type arg and return", func(t *testing.T) {
    53  		client, _ := rpctest.NewPair(HandlerFrom(func(a id) id {
    54  			return a
    55  		}), codec.JSONCodec{})
    56  		defer client.Close()
    57  
    58  		var ret id
    59  		if _, err := client.Call(context.Background(), "", Args{id(64)}, &ret); err != nil {
    60  			t.Fatal(err)
    61  		}
    62  		if ret != 64 {
    63  			t.Fatalf("unexpected return value: %v", ret)
    64  		}
    65  	})
    66  
    67  	t.Run("struct arguments", func(t *testing.T) {
    68  		client, _ := rpctest.NewPair(HandlerFrom(func(a fake, b subfake) {
    69  			if a.A.A != "Hello" {
    70  				t.Fatalf("unexpected field value in struct: %v", a)
    71  			}
    72  			if b.A != "world" {
    73  				t.Fatalf("unexpected field value in struct: %v", b)
    74  			}
    75  		}), codec.JSONCodec{})
    76  		defer client.Close()
    77  
    78  		if _, err := client.Call(context.Background(), "", Args{fake{A: subfake{A: "Hello"}}, subfake{A: "world"}}, nil); err != nil {
    79  			t.Fatal(err)
    80  		}
    81  	})
    82  
    83  	t.Run("nil error", func(t *testing.T) {
    84  		client, _ := rpctest.NewPair(HandlerFrom(func(a, b int) error {
    85  			return nil
    86  		}), codec.JSONCodec{})
    87  		defer client.Close()
    88  
    89  		if _, err := client.Call(context.Background(), "", []interface{}{2, 3}, nil); err != nil {
    90  			t.Fatal(err)
    91  		}
    92  	})
    93  
    94  	t.Run("not enough args", func(t *testing.T) {
    95  		client, _ := rpctest.NewPair(HandlerFrom(func(a, b int) int {
    96  			return a + b
    97  		}), codec.JSONCodec{})
    98  		defer client.Close()
    99  
   100  		var sum int
   101  		_, err := client.Call(context.Background(), "", []interface{}{2}, &sum)
   102  		if err == nil || !strings.Contains(err.Error(), "expected 2 params") {
   103  			t.Fatalf("unexpected error: %v", err)
   104  		}
   105  	})
   106  
   107  	t.Run("too many args", func(t *testing.T) {
   108  		client, _ := rpctest.NewPair(HandlerFrom(func(a, b int) int {
   109  			return a + b
   110  		}), codec.JSONCodec{})
   111  		defer client.Close()
   112  
   113  		var sum int
   114  		_, err := client.Call(context.Background(), "", []interface{}{2, 3, 5}, &sum)
   115  		if err == nil || !strings.Contains(err.Error(), "expected 2 params") {
   116  			t.Fatalf("unexpected error: %v", err)
   117  		}
   118  	})
   119  
   120  	t.Run("with call", func(t *testing.T) {
   121  		client, _ := rpctest.NewPair(HandlerFrom(func(a, b int, call *rpc.Call) int {
   122  			if call.Selector() != "/sum" {
   123  				t.Fatalf("unexpected selector: %v", call.Selector())
   124  			}
   125  			return a + b
   126  		}), codec.JSONCodec{})
   127  		defer client.Close()
   128  
   129  		var sum int
   130  		if _, err := client.Call(context.Background(), "sum", []interface{}{2, 3}, &sum); err != nil {
   131  			t.Fatal(err)
   132  		}
   133  		if sum != 5 {
   134  			t.Fatalf("unexpected sum: %v", sum)
   135  		}
   136  	})
   137  
   138  	t.Run("return error", func(t *testing.T) {
   139  		client, _ := rpctest.NewPair(HandlerFrom(func(a, b int) error {
   140  			return errors.New("test")
   141  		}), codec.JSONCodec{})
   142  		defer client.Close()
   143  
   144  		var sum int
   145  		_, err := client.Call(context.Background(), "", []interface{}{2, 3}, &sum)
   146  		if err == nil || !strings.Contains(err.Error(), "test") {
   147  			t.Fatalf("unexpected error: %v", err)
   148  		}
   149  	})
   150  
   151  	t.Run("return error with value", func(t *testing.T) {
   152  		client, _ := rpctest.NewPair(HandlerFrom(func(a, b int) (int, error) {
   153  			return a + b, errors.New("test")
   154  		}), codec.JSONCodec{})
   155  		defer client.Close()
   156  
   157  		var sum int
   158  		_, err := client.Call(context.Background(), "", []interface{}{2, 3}, &sum)
   159  		if err == nil || !strings.Contains(err.Error(), "test") {
   160  			t.Fatalf("unexpected error: %v", err)
   161  		}
   162  	})
   163  
   164  	t.Run("no return", func(t *testing.T) {
   165  		client, _ := rpctest.NewPair(HandlerFrom(func(a, b int) {
   166  			return
   167  		}), codec.JSONCodec{})
   168  		defer client.Close()
   169  
   170  		var sum int
   171  		_, err := client.Call(context.Background(), "", []interface{}{2, 3}, &sum)
   172  		if err != nil {
   173  			t.Fatalf("unexpected error: %v", err)
   174  		}
   175  	})
   176  
   177  	t.Run("channel return value stream", func(t *testing.T) {
   178  		client, _ := rpctest.NewPair(HandlerFrom(func() (chan string, string) {
   179  			ch := make(chan string)
   180  			go func() {
   181  				ch <- "one"
   182  				ch <- "two"
   183  				close(ch)
   184  			}()
   185  			return ch, "ok"
   186  		}), codec.JSONCodec{})
   187  		defer client.Close()
   188  
   189  		ch := make(chan string)
   190  		ctx := context.Background()
   191  		var ret string
   192  		resp, err := client.Call(ctx, "", nil, &ret)
   193  		if err != nil {
   194  			t.Fatalf("unexpected error: %v", err)
   195  		}
   196  		if ret != "ok" {
   197  			t.Fatalf("unexpected ret: %v", ret)
   198  		}
   199  		go rpc.ReceiveNotify(ctx, resp, ch)
   200  		var vals []string
   201  		for s := range ch {
   202  			vals = append(vals, s)
   203  		}
   204  		if !reflect.DeepEqual(vals, []string{"one", "two"}) {
   205  			t.Fatalf("unexpected streamed values: %v", vals)
   206  		}
   207  
   208  	})
   209  
   210  	t.Run("channel arg stream", func(t *testing.T) {
   211  		client, _ := rpctest.NewPair(HandlerFrom(func(ch chan string) string {
   212  			go func() {
   213  				ch <- "one"
   214  				ch <- "two"
   215  				close(ch)
   216  			}()
   217  			return "ok"
   218  		}), codec.JSONCodec{})
   219  		defer client.Close()
   220  
   221  		ch := make(chan string)
   222  		ctx := context.Background()
   223  		var ret string
   224  		resp, err := client.Call(ctx, "", nil, &ret)
   225  		if err != nil {
   226  			t.Fatalf("unexpected error: %v", err)
   227  		}
   228  		if ret != "ok" {
   229  			t.Fatalf("unexpected ret: %v", ret)
   230  		}
   231  		go rpc.ReceiveNotify(ctx, resp, ch)
   232  		var vals []string
   233  		for s := range ch {
   234  			vals = append(vals, s)
   235  		}
   236  		if !reflect.DeepEqual(vals, []string{"one", "two"}) {
   237  			t.Fatalf("unexpected streamed values: %v", vals)
   238  		}
   239  
   240  	})
   241  
   242  }
   243  
   244  type mockMethods struct{}
   245  
   246  func (m *mockMethods) Foo() string {
   247  	return "Foo"
   248  }
   249  
   250  func (m *mockMethods) Bar() {}
   251  
   252  func TestHandlerFromMethods(t *testing.T) {
   253  	handler := HandlerFrom(&mockMethods{})
   254  	mux, ok := handler.(*rpc.RespondMux)
   255  	if !ok {
   256  		t.Fatal("expected handler to be rpc.RespondMux")
   257  	}
   258  	h, _ := mux.Match("Foo")
   259  	if h == nil {
   260  		t.Fatal("expected Foo handler")
   261  	}
   262  	h, _ = mux.Match("Bar")
   263  	if h == nil {
   264  		t.Fatal("expected Bar handler")
   265  	}
   266  
   267  	client, _ := rpctest.NewPair(mux, codec.JSONCodec{})
   268  	defer client.Close()
   269  
   270  	var ret string
   271  	if _, err := client.Call(context.Background(), "Foo", nil, &ret); err != nil {
   272  		t.Fatal(err)
   273  	}
   274  	if ret != "Foo" {
   275  		t.Fatalf("unexpected ret: %v", ret)
   276  	}
   277  }
   278  
   279  func TestHandlerFromMethodsInterface(t *testing.T) {
   280  	handler := HandlerFrom[interface {
   281  		Foo() string
   282  	}](&mockMethods{})
   283  	mux, ok := handler.(*rpc.RespondMux)
   284  	if !ok {
   285  		t.Fatal("expected handler to be rpc.RespondMux")
   286  	}
   287  	h, _ := mux.Match("Foo")
   288  	if h == nil {
   289  		t.Fatal("expected Foo handler")
   290  	}
   291  	h, _ = mux.Match("Bar")
   292  	if h != nil {
   293  		t.Fatal("expected no handler for Bar method not on interface")
   294  	}
   295  
   296  	client, _ := rpctest.NewPair(mux, codec.JSONCodec{})
   297  	defer client.Close()
   298  
   299  	var ret string
   300  	if _, err := client.Call(context.Background(), "Foo", nil, &ret); err != nil {
   301  		t.Fatal(err)
   302  	}
   303  	if ret != "Foo" {
   304  		t.Fatalf("unexpected ret: %v", ret)
   305  	}
   306  }
   307  
   308  func TestHandlerFromMethodsInterfaceDifferentMethod(t *testing.T) {
   309  	// Also check a different method to ensure that the reflection code is
   310  	// matching the correct method based on the interface and not just getting
   311  	// Method(0) which matches up in the first test.
   312  	handler := HandlerFrom[interface {
   313  		Bar()
   314  	}](&mockMethods{})
   315  	mux, ok := handler.(*rpc.RespondMux)
   316  	if !ok {
   317  		t.Fatal("expected handler to be rpc.RespondMux")
   318  	}
   319  	h, _ := mux.Match("Bar")
   320  	if h == nil {
   321  		t.Fatal("expected Bar handler")
   322  	}
   323  	h, _ = mux.Match("Foo")
   324  	if h != nil {
   325  		t.Fatal("expected no handler for Foo method not on interface")
   326  	}
   327  
   328  	client, _ := rpctest.NewPair(mux, codec.JSONCodec{})
   329  	defer client.Close()
   330  
   331  	var ret string
   332  	if _, err := client.Call(context.Background(), "Bar", nil, &ret); err != nil {
   333  		t.Fatal(err)
   334  	}
   335  	if ret != "" {
   336  		t.Fatalf("unexpected ret: %v", ret)
   337  	}
   338  }
   339  
   340  type handlerFuncMethod struct{}
   341  
   342  func (*handlerFuncMethod) Bar(r rpc.Responder, c *rpc.Call) {
   343  	var args []any
   344  	if err := c.Receive(&args); err != nil {
   345  		r.Return(fmt.Errorf("fn: args: %s", err.Error()))
   346  		return
   347  	}
   348  
   349  	r.Return("returned from Responder")
   350  }
   351  
   352  func TestMethodHandlerFunc(t *testing.T) {
   353  	handler := HandlerFrom(&handlerFuncMethod{})
   354  	client, _ := rpctest.NewPair(handler, codec.JSONCodec{})
   355  	defer client.Close()
   356  
   357  	var ret string
   358  	if _, err := client.Call(context.Background(), "Bar", nil, &ret); err != nil {
   359  		t.Fatal(err)
   360  	}
   361  	if ret != "returned from Responder" {
   362  		t.Fatalf("unexpected ret: %v", ret)
   363  	}
   364  }