github.com/v2fly/tools@v0.100.0/internal/jsonrpc2_v2/jsonrpc2_test.go (about)

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package jsonrpc2_test
     6  
     7  import (
     8  	"context"
     9  	"encoding/json"
    10  	"fmt"
    11  	"path"
    12  	"reflect"
    13  	"testing"
    14  	"time"
    15  
    16  	"github.com/v2fly/tools/internal/event/export/eventtest"
    17  	jsonrpc2 "github.com/v2fly/tools/internal/jsonrpc2_v2"
    18  	"github.com/v2fly/tools/internal/stack/stacktest"
    19  	errors "golang.org/x/xerrors"
    20  )
    21  
    22  var callTests = []invoker{
    23  	call{"no_args", nil, true},
    24  	call{"one_string", "fish", "got:fish"},
    25  	call{"one_number", 10, "got:10"},
    26  	call{"join", []string{"a", "b", "c"}, "a/b/c"},
    27  	sequence{"notify", []invoker{
    28  		notify{"set", 3},
    29  		notify{"add", 5},
    30  		call{"get", nil, 8},
    31  	}},
    32  	sequence{"preempt", []invoker{
    33  		async{"a", "wait", "a"},
    34  		notify{"unblock", "a"},
    35  		collect{"a", true, false},
    36  	}},
    37  	sequence{"basic cancel", []invoker{
    38  		async{"b", "wait", "b"},
    39  		cancel{"b"},
    40  		collect{"b", nil, true},
    41  	}},
    42  	sequence{"queue", []invoker{
    43  		async{"a", "wait", "a"},
    44  		notify{"set", 1},
    45  		notify{"add", 2},
    46  		notify{"add", 3},
    47  		notify{"add", 4},
    48  		call{"peek", nil, 0}, // accumulator will not have any adds yet
    49  		notify{"unblock", "a"},
    50  		collect{"a", true, false},
    51  		call{"get", nil, 10}, // accumulator now has all the adds
    52  	}},
    53  	sequence{"fork", []invoker{
    54  		async{"a", "fork", "a"},
    55  		notify{"set", 1},
    56  		notify{"add", 2},
    57  		notify{"add", 3},
    58  		notify{"add", 4},
    59  		call{"get", nil, 10}, // fork will not have blocked the adds
    60  		notify{"unblock", "a"},
    61  		collect{"a", true, false},
    62  	}},
    63  }
    64  
// binder configures each new connection in these tests through its Bind
// method; it is passed to both jsonrpc2.Serve and jsonrpc2.Dial.
type binder struct {
	// framer is the message framing installed on every bound connection.
	framer jsonrpc2.Framer
	// runTest, when non-nil, is started in its own goroutine by Bind with the
	// handler of the newly bound connection. The server side passes nil.
	runTest func(*handler)
}
    69  
// handler holds the per-connection state used by the tests. It is installed
// as both the Preempter and the Handler of the connection it is bound to.
type handler struct {
	// conn is the connection this handler was bound to.
	conn *jsonrpc2.Connection
	// accumulator is written by the "set" and "add" methods and read by
	// "get" and "peek".
	accumulator int
	// waitersBox is a one-element channel that holds the map of named waiter
	// channels; waiter takes the map out while it works on it and puts it
	// back afterwards.
	waitersBox chan map[string]chan struct{}
	// calls records in-flight calls started by async, keyed by the name that
	// collect and cancel later look them up with.
	calls map[string]*jsonrpc2.AsyncCall
}
    76  
// invoker is a single step, or group of steps, of a connection test.
type invoker interface {
	// Name returns the label of the step; testConnection uses it as the
	// sub-test name for top-level entries.
	Name() string
	// Invoke runs the step against the connection owned by h, reporting
	// failures through t.
	Invoke(t *testing.T, ctx context.Context, h *handler)
}
    81  
// notify is an invoker that sends a notification and does not look at any
// reply.
type notify struct {
	// method is the name of the method to notify; it is also the step name.
	method string
	// params is the value sent as the parameters of the notification.
	params interface{}
}
    86  
// call is an invoker that makes a call, waits for the reply and compares it
// with the expected value.
type call struct {
	// method is the name of the method to call; it is also the step name.
	method string
	// params is the value sent as the parameters of the call.
	params interface{}
	// expect is the value the decoded result must be deeply equal to; nil
	// means that no result is decoded.
	expect interface{}
}
    92  
// async is an invoker that starts a call without waiting for its reply, and
// remembers it in handler.calls so that a later collect or cancel step can
// refer to it.
type async struct {
	// name is the key under which the pending call is stored.
	name string
	// method is the name of the method to call.
	method string
	// params is the value sent as the parameters of the call.
	params interface{}
}
    98  
// collect is an invoker that waits for a call previously started by an async
// step and checks its outcome.
type collect struct {
	// name is the key of the pending call in handler.calls.
	name string
	// expect is the value the decoded result must be deeply equal to; nil
	// means that no result is decoded.
	expect interface{}
	// fails states whether waiting for the call is required to return an
	// error.
	fails bool
}
   104  
// cancel is an invoker that asks the peer, through a "cancel" notification,
// to cancel a call previously started by an async step.
type cancel struct {
	// name is the key of the pending call in handler.calls.
	name string
}
   108  
// sequence is an invoker that runs a list of child invokers one after the
// other on the same connection.
type sequence struct {
	// name is the label of the whole sequence.
	name string
	// tests are the steps to run, in order.
	tests []invoker
}
   113  
// echo has the same fields as call but a different Invoke: the request is
// sent to the peer's "echo" method, which performs the described call back
// over the same connection and relays the result. As a distinct defined type
// it does not carry over call's methods.
type echo call
   115  
// cancelParams is the parameter payload of the "cancel" notification; ID is
// the integer identifier of the call that should be cancelled.
type cancelParams struct{ ID int64 }
   117  
   118  func TestConnectionRaw(t *testing.T) {
   119  	testConnection(t, jsonrpc2.RawFramer())
   120  }
   121  
   122  func TestConnectionHeader(t *testing.T) {
   123  	testConnection(t, jsonrpc2.HeaderFramer())
   124  }
   125  
   126  func testConnection(t *testing.T, framer jsonrpc2.Framer) {
   127  	stacktest.NoLeak(t)
   128  	ctx := eventtest.NewContext(context.Background(), t)
   129  	listener, err := jsonrpc2.NetPipe(ctx)
   130  	if err != nil {
   131  		t.Fatal(err)
   132  	}
   133  	server, err := jsonrpc2.Serve(ctx, listener, binder{framer, nil})
   134  	if err != nil {
   135  		t.Fatal(err)
   136  	}
   137  	defer func() {
   138  		listener.Close()
   139  		server.Wait()
   140  	}()
   141  
   142  	for _, test := range callTests {
   143  		t.Run(test.Name(), func(t *testing.T) {
   144  			client, err := jsonrpc2.Dial(ctx,
   145  				listener.Dialer(), binder{framer, func(h *handler) {
   146  					defer h.conn.Close()
   147  					ctx := eventtest.NewContext(ctx, t)
   148  					test.Invoke(t, ctx, h)
   149  					if call, ok := test.(*call); ok {
   150  						// also run all simple call tests in echo mode
   151  						(*echo)(call).Invoke(t, ctx, h)
   152  					}
   153  				}})
   154  			if err != nil {
   155  				t.Fatal(err)
   156  			}
   157  			client.Wait()
   158  		})
   159  	}
   160  }
   161  
   162  func (test notify) Name() string { return test.method }
   163  func (test notify) Invoke(t *testing.T, ctx context.Context, h *handler) {
   164  	if err := h.conn.Notify(ctx, test.method, test.params); err != nil {
   165  		t.Fatalf("%v:Notify failed: %v", test.method, err)
   166  	}
   167  }
   168  
   169  func (test call) Name() string { return test.method }
   170  func (test call) Invoke(t *testing.T, ctx context.Context, h *handler) {
   171  	results := newResults(test.expect)
   172  	if err := h.conn.Call(ctx, test.method, test.params).Await(ctx, results); err != nil {
   173  		t.Fatalf("%v:Call failed: %v", test.method, err)
   174  	}
   175  	verifyResults(t, test.method, results, test.expect)
   176  }
   177  
   178  func (test echo) Invoke(t *testing.T, ctx context.Context, h *handler) {
   179  	results := newResults(test.expect)
   180  	if err := h.conn.Call(ctx, "echo", []interface{}{test.method, test.params}).Await(ctx, results); err != nil {
   181  		t.Fatalf("%v:Echo failed: %v", test.method, err)
   182  	}
   183  	verifyResults(t, test.method, results, test.expect)
   184  }
   185  
   186  func (test async) Name() string { return test.name }
   187  func (test async) Invoke(t *testing.T, ctx context.Context, h *handler) {
   188  	h.calls[test.name] = h.conn.Call(ctx, test.method, test.params)
   189  }
   190  
   191  func (test collect) Name() string { return test.name }
   192  func (test collect) Invoke(t *testing.T, ctx context.Context, h *handler) {
   193  	o := h.calls[test.name]
   194  	results := newResults(test.expect)
   195  	err := o.Await(ctx, results)
   196  	switch {
   197  	case test.fails && err == nil:
   198  		t.Fatalf("%v:Collect was supposed to fail", test.name)
   199  	case !test.fails && err != nil:
   200  		t.Fatalf("%v:Collect failed: %v", test.name, err)
   201  	}
   202  	verifyResults(t, test.name, results, test.expect)
   203  }
   204  
   205  func (test cancel) Name() string { return test.name }
   206  func (test cancel) Invoke(t *testing.T, ctx context.Context, h *handler) {
   207  	o := h.calls[test.name]
   208  	if err := h.conn.Notify(ctx, "cancel", &cancelParams{o.ID().Raw().(int64)}); err != nil {
   209  		t.Fatalf("%v:Collect failed: %v", test.name, err)
   210  	}
   211  }
   212  
   213  func (test sequence) Name() string { return test.name }
   214  func (test sequence) Invoke(t *testing.T, ctx context.Context, h *handler) {
   215  	for _, child := range test.tests {
   216  		child.Invoke(t, ctx, h)
   217  	}
   218  }
   219  
   220  // newResults makes a new empty copy of the expected type to put the results into
   221  func newResults(expect interface{}) interface{} {
   222  	switch e := expect.(type) {
   223  	case []interface{}:
   224  		var r []interface{}
   225  		for _, v := range e {
   226  			r = append(r, reflect.New(reflect.TypeOf(v)).Interface())
   227  		}
   228  		return r
   229  	case nil:
   230  		return nil
   231  	default:
   232  		return reflect.New(reflect.TypeOf(expect)).Interface()
   233  	}
   234  }
   235  
   236  // verifyResults compares the results to the expected values
   237  func verifyResults(t *testing.T, method string, results interface{}, expect interface{}) {
   238  	if expect == nil {
   239  		if results != nil {
   240  			t.Errorf("%v:Got results %+v where none expeted", method, expect)
   241  		}
   242  		return
   243  	}
   244  	val := reflect.Indirect(reflect.ValueOf(results)).Interface()
   245  	if !reflect.DeepEqual(val, expect) {
   246  		t.Errorf("%v:Results are incorrect, got %+v expect %+v", method, val, expect)
   247  	}
   248  }
   249  
   250  func (b binder) Bind(ctx context.Context, conn *jsonrpc2.Connection) (jsonrpc2.ConnectionOptions, error) {
   251  	h := &handler{
   252  		conn:       conn,
   253  		waitersBox: make(chan map[string]chan struct{}, 1),
   254  		calls:      make(map[string]*jsonrpc2.AsyncCall),
   255  	}
   256  	h.waitersBox <- make(map[string]chan struct{})
   257  	if b.runTest != nil {
   258  		go b.runTest(h)
   259  	}
   260  	return jsonrpc2.ConnectionOptions{
   261  		Framer:    b.framer,
   262  		Preempter: h,
   263  		Handler:   h,
   264  	}, nil
   265  }
   266  
   267  func (h *handler) waiter(name string) chan struct{} {
   268  	waiters := <-h.waitersBox
   269  	defer func() { h.waitersBox <- waiters }()
   270  	waiter, found := waiters[name]
   271  	if !found {
   272  		waiter = make(chan struct{})
   273  		waiters[name] = waiter
   274  	}
   275  	return waiter
   276  }
   277  
   278  func (h *handler) Preempt(ctx context.Context, req *jsonrpc2.Request) (interface{}, error) {
   279  	switch req.Method {
   280  	case "unblock":
   281  		var name string
   282  		if err := json.Unmarshal(req.Params, &name); err != nil {
   283  			return nil, errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   284  		}
   285  		close(h.waiter(name))
   286  		return nil, nil
   287  	case "peek":
   288  		if len(req.Params) > 0 {
   289  			return nil, errors.Errorf("%w: expected no params", jsonrpc2.ErrInvalidParams)
   290  		}
   291  		return h.accumulator, nil
   292  	case "cancel":
   293  		var params cancelParams
   294  		if err := json.Unmarshal(req.Params, &params); err != nil {
   295  			return nil, errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   296  		}
   297  		h.conn.Cancel(jsonrpc2.Int64ID(params.ID))
   298  		return nil, nil
   299  	default:
   300  		return nil, jsonrpc2.ErrNotHandled
   301  	}
   302  }
   303  
   304  func (h *handler) Handle(ctx context.Context, req *jsonrpc2.Request) (interface{}, error) {
   305  	switch req.Method {
   306  	case "no_args":
   307  		if len(req.Params) > 0 {
   308  			return nil, errors.Errorf("%w: expected no params", jsonrpc2.ErrInvalidParams)
   309  		}
   310  		return true, nil
   311  	case "one_string":
   312  		var v string
   313  		if err := json.Unmarshal(req.Params, &v); err != nil {
   314  			return nil, errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   315  		}
   316  		return "got:" + v, nil
   317  	case "one_number":
   318  		var v int
   319  		if err := json.Unmarshal(req.Params, &v); err != nil {
   320  			return nil, errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   321  		}
   322  		return fmt.Sprintf("got:%d", v), nil
   323  	case "set":
   324  		var v int
   325  		if err := json.Unmarshal(req.Params, &v); err != nil {
   326  			return nil, errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   327  		}
   328  		h.accumulator = v
   329  		return nil, nil
   330  	case "add":
   331  		var v int
   332  		if err := json.Unmarshal(req.Params, &v); err != nil {
   333  			return nil, errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   334  		}
   335  		h.accumulator += v
   336  		return nil, nil
   337  	case "get":
   338  		if len(req.Params) > 0 {
   339  			return nil, errors.Errorf("%w: expected no params", jsonrpc2.ErrInvalidParams)
   340  		}
   341  		return h.accumulator, nil
   342  	case "join":
   343  		var v []string
   344  		if err := json.Unmarshal(req.Params, &v); err != nil {
   345  			return nil, errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   346  		}
   347  		return path.Join(v...), nil
   348  	case "echo":
   349  		var v []interface{}
   350  		if err := json.Unmarshal(req.Params, &v); err != nil {
   351  			return nil, errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   352  		}
   353  		var result interface{}
   354  		err := h.conn.Call(ctx, v[0].(string), v[1]).Await(ctx, &result)
   355  		return result, err
   356  	case "wait":
   357  		var name string
   358  		if err := json.Unmarshal(req.Params, &name); err != nil {
   359  			return nil, errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   360  		}
   361  		select {
   362  		case <-h.waiter(name):
   363  			return true, nil
   364  		case <-ctx.Done():
   365  			return nil, ctx.Err()
   366  		case <-time.After(time.Second):
   367  			return nil, errors.Errorf("wait for %q timed out", name)
   368  		}
   369  	case "fork":
   370  		var name string
   371  		if err := json.Unmarshal(req.Params, &name); err != nil {
   372  			return nil, errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   373  		}
   374  		waitFor := h.waiter(name)
   375  		go func() {
   376  			select {
   377  			case <-waitFor:
   378  				h.conn.Respond(req.ID, true, nil)
   379  			case <-ctx.Done():
   380  				h.conn.Respond(req.ID, nil, ctx.Err())
   381  			case <-time.After(time.Second):
   382  				h.conn.Respond(req.ID, nil, errors.Errorf("wait for %q timed out", name))
   383  			}
   384  		}()
   385  		return nil, jsonrpc2.ErrAsyncResponse
   386  	default:
   387  		return nil, jsonrpc2.ErrNotHandled
   388  	}
   389  }