github.com/powerman/golang-tools@v0.1.11-0.20220410185822-5ad214d8d803/internal/jsonrpc2_v2/jsonrpc2_test.go (about)

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package jsonrpc2_test
     6  
     7  import (
     8  	"context"
     9  	"encoding/json"
    10  	"fmt"
    11  	"path"
    12  	"reflect"
    13  	"testing"
    14  	"time"
    15  
    16  	"github.com/powerman/golang-tools/internal/event/export/eventtest"
    17  	jsonrpc2 "github.com/powerman/golang-tools/internal/jsonrpc2_v2"
    18  	"github.com/powerman/golang-tools/internal/stack/stacktest"
    19  	errors "golang.org/x/xerrors"
    20  )
    21  
    22  var callTests = []invoker{
    23  	call{"no_args", nil, true},
    24  	call{"one_string", "fish", "got:fish"},
    25  	call{"one_number", 10, "got:10"},
    26  	call{"join", []string{"a", "b", "c"}, "a/b/c"},
    27  	sequence{"notify", []invoker{
    28  		notify{"set", 3},
    29  		notify{"add", 5},
    30  		call{"get", nil, 8},
    31  	}},
    32  	sequence{"preempt", []invoker{
    33  		async{"a", "wait", "a"},
    34  		notify{"unblock", "a"},
    35  		collect{"a", true, false},
    36  	}},
    37  	sequence{"basic cancel", []invoker{
    38  		async{"b", "wait", "b"},
    39  		cancel{"b"},
    40  		collect{"b", nil, true},
    41  	}},
    42  	sequence{"queue", []invoker{
    43  		async{"a", "wait", "a"},
    44  		notify{"set", 1},
    45  		notify{"add", 2},
    46  		notify{"add", 3},
    47  		notify{"add", 4},
    48  		call{"peek", nil, 0}, // accumulator will not have any adds yet
    49  		notify{"unblock", "a"},
    50  		collect{"a", true, false},
    51  		call{"get", nil, 10}, // accumulator now has all the adds
    52  	}},
    53  	sequence{"fork", []invoker{
    54  		async{"a", "fork", "a"},
    55  		notify{"set", 1},
    56  		notify{"add", 2},
    57  		notify{"add", 3},
    58  		notify{"add", 4},
    59  		call{"get", nil, 10}, // fork will not have blocked the adds
    60  		notify{"unblock", "a"},
    61  		collect{"a", true, false},
    62  	}},
    63  	sequence{"concurrent", []invoker{
    64  		async{"a", "fork", "a"},
    65  		notify{"unblock", "a"},
    66  		async{"b", "fork", "b"},
    67  		notify{"unblock", "b"},
    68  		collect{"a", true, false},
    69  		collect{"b", true, false},
    70  	}},
    71  }
    72  
    73  type binder struct {
    74  	framer  jsonrpc2.Framer
    75  	runTest func(*handler)
    76  }
    77  
    78  type handler struct {
    79  	conn        *jsonrpc2.Connection
    80  	accumulator int
    81  	waitersBox  chan map[string]chan struct{}
    82  	calls       map[string]*jsonrpc2.AsyncCall
    83  }
    84  
    85  type invoker interface {
    86  	Name() string
    87  	Invoke(t *testing.T, ctx context.Context, h *handler)
    88  }
    89  
    90  type notify struct {
    91  	method string
    92  	params interface{}
    93  }
    94  
    95  type call struct {
    96  	method string
    97  	params interface{}
    98  	expect interface{}
    99  }
   100  
   101  type async struct {
   102  	name   string
   103  	method string
   104  	params interface{}
   105  }
   106  
   107  type collect struct {
   108  	name   string
   109  	expect interface{}
   110  	fails  bool
   111  }
   112  
   113  type cancel struct {
   114  	name string
   115  }
   116  
   117  type sequence struct {
   118  	name  string
   119  	tests []invoker
   120  }
   121  
   122  type echo call
   123  
   124  type cancelParams struct{ ID int64 }
   125  
   126  func TestConnectionRaw(t *testing.T) {
   127  	testConnection(t, jsonrpc2.RawFramer())
   128  }
   129  
   130  func TestConnectionHeader(t *testing.T) {
   131  	testConnection(t, jsonrpc2.HeaderFramer())
   132  }
   133  
   134  func testConnection(t *testing.T, framer jsonrpc2.Framer) {
   135  	stacktest.NoLeak(t)
   136  	ctx := eventtest.NewContext(context.Background(), t)
   137  	listener, err := jsonrpc2.NetPipeListener(ctx)
   138  	if err != nil {
   139  		t.Fatal(err)
   140  	}
   141  	server, err := jsonrpc2.Serve(ctx, listener, binder{framer, nil})
   142  	if err != nil {
   143  		t.Fatal(err)
   144  	}
   145  	defer func() {
   146  		listener.Close()
   147  		server.Wait()
   148  	}()
   149  
   150  	for _, test := range callTests {
   151  		t.Run(test.Name(), func(t *testing.T) {
   152  			client, err := jsonrpc2.Dial(ctx,
   153  				listener.Dialer(), binder{framer, func(h *handler) {
   154  					defer h.conn.Close()
   155  					ctx := eventtest.NewContext(ctx, t)
   156  					test.Invoke(t, ctx, h)
   157  					if call, ok := test.(*call); ok {
   158  						// also run all simple call tests in echo mode
   159  						(*echo)(call).Invoke(t, ctx, h)
   160  					}
   161  				}})
   162  			if err != nil {
   163  				t.Fatal(err)
   164  			}
   165  			client.Wait()
   166  		})
   167  	}
   168  }
   169  
   170  func (test notify) Name() string { return test.method }
   171  func (test notify) Invoke(t *testing.T, ctx context.Context, h *handler) {
   172  	if err := h.conn.Notify(ctx, test.method, test.params); err != nil {
   173  		t.Fatalf("%v:Notify failed: %v", test.method, err)
   174  	}
   175  }
   176  
   177  func (test call) Name() string { return test.method }
   178  func (test call) Invoke(t *testing.T, ctx context.Context, h *handler) {
   179  	results := newResults(test.expect)
   180  	if err := h.conn.Call(ctx, test.method, test.params).Await(ctx, results); err != nil {
   181  		t.Fatalf("%v:Call failed: %v", test.method, err)
   182  	}
   183  	verifyResults(t, test.method, results, test.expect)
   184  }
   185  
   186  func (test echo) Invoke(t *testing.T, ctx context.Context, h *handler) {
   187  	results := newResults(test.expect)
   188  	if err := h.conn.Call(ctx, "echo", []interface{}{test.method, test.params}).Await(ctx, results); err != nil {
   189  		t.Fatalf("%v:Echo failed: %v", test.method, err)
   190  	}
   191  	verifyResults(t, test.method, results, test.expect)
   192  }
   193  
   194  func (test async) Name() string { return test.name }
   195  func (test async) Invoke(t *testing.T, ctx context.Context, h *handler) {
   196  	h.calls[test.name] = h.conn.Call(ctx, test.method, test.params)
   197  }
   198  
   199  func (test collect) Name() string { return test.name }
   200  func (test collect) Invoke(t *testing.T, ctx context.Context, h *handler) {
   201  	o := h.calls[test.name]
   202  	results := newResults(test.expect)
   203  	err := o.Await(ctx, results)
   204  	switch {
   205  	case test.fails && err == nil:
   206  		t.Fatalf("%v:Collect was supposed to fail", test.name)
   207  	case !test.fails && err != nil:
   208  		t.Fatalf("%v:Collect failed: %v", test.name, err)
   209  	}
   210  	verifyResults(t, test.name, results, test.expect)
   211  }
   212  
   213  func (test cancel) Name() string { return test.name }
   214  func (test cancel) Invoke(t *testing.T, ctx context.Context, h *handler) {
   215  	o := h.calls[test.name]
   216  	if err := h.conn.Notify(ctx, "cancel", &cancelParams{o.ID().Raw().(int64)}); err != nil {
   217  		t.Fatalf("%v:Collect failed: %v", test.name, err)
   218  	}
   219  }
   220  
   221  func (test sequence) Name() string { return test.name }
   222  func (test sequence) Invoke(t *testing.T, ctx context.Context, h *handler) {
   223  	for _, child := range test.tests {
   224  		child.Invoke(t, ctx, h)
   225  	}
   226  }
   227  
   228  // newResults makes a new empty copy of the expected type to put the results into
   229  func newResults(expect interface{}) interface{} {
   230  	switch e := expect.(type) {
   231  	case []interface{}:
   232  		var r []interface{}
   233  		for _, v := range e {
   234  			r = append(r, reflect.New(reflect.TypeOf(v)).Interface())
   235  		}
   236  		return r
   237  	case nil:
   238  		return nil
   239  	default:
   240  		return reflect.New(reflect.TypeOf(expect)).Interface()
   241  	}
   242  }
   243  
   244  // verifyResults compares the results to the expected values
   245  func verifyResults(t *testing.T, method string, results interface{}, expect interface{}) {
   246  	if expect == nil {
   247  		if results != nil {
   248  			t.Errorf("%v:Got results %+v where none expeted", method, expect)
   249  		}
   250  		return
   251  	}
   252  	val := reflect.Indirect(reflect.ValueOf(results)).Interface()
   253  	if !reflect.DeepEqual(val, expect) {
   254  		t.Errorf("%v:Results are incorrect, got %+v expect %+v", method, val, expect)
   255  	}
   256  }
   257  
   258  func (b binder) Bind(ctx context.Context, conn *jsonrpc2.Connection) (jsonrpc2.ConnectionOptions, error) {
   259  	h := &handler{
   260  		conn:       conn,
   261  		waitersBox: make(chan map[string]chan struct{}, 1),
   262  		calls:      make(map[string]*jsonrpc2.AsyncCall),
   263  	}
   264  	h.waitersBox <- make(map[string]chan struct{})
   265  	if b.runTest != nil {
   266  		go b.runTest(h)
   267  	}
   268  	return jsonrpc2.ConnectionOptions{
   269  		Framer:    b.framer,
   270  		Preempter: h,
   271  		Handler:   h,
   272  	}, nil
   273  }
   274  
   275  func (h *handler) waiter(name string) chan struct{} {
   276  	waiters := <-h.waitersBox
   277  	defer func() { h.waitersBox <- waiters }()
   278  	waiter, found := waiters[name]
   279  	if !found {
   280  		waiter = make(chan struct{})
   281  		waiters[name] = waiter
   282  	}
   283  	return waiter
   284  }
   285  
   286  func (h *handler) Preempt(ctx context.Context, req *jsonrpc2.Request) (interface{}, error) {
   287  	switch req.Method {
   288  	case "unblock":
   289  		var name string
   290  		if err := json.Unmarshal(req.Params, &name); err != nil {
   291  			return nil, errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   292  		}
   293  		close(h.waiter(name))
   294  		return nil, nil
   295  	case "peek":
   296  		if len(req.Params) > 0 {
   297  			return nil, errors.Errorf("%w: expected no params", jsonrpc2.ErrInvalidParams)
   298  		}
   299  		return h.accumulator, nil
   300  	case "cancel":
   301  		var params cancelParams
   302  		if err := json.Unmarshal(req.Params, &params); err != nil {
   303  			return nil, errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   304  		}
   305  		h.conn.Cancel(jsonrpc2.Int64ID(params.ID))
   306  		return nil, nil
   307  	default:
   308  		return nil, jsonrpc2.ErrNotHandled
   309  	}
   310  }
   311  
   312  func (h *handler) Handle(ctx context.Context, req *jsonrpc2.Request) (interface{}, error) {
   313  	switch req.Method {
   314  	case "no_args":
   315  		if len(req.Params) > 0 {
   316  			return nil, errors.Errorf("%w: expected no params", jsonrpc2.ErrInvalidParams)
   317  		}
   318  		return true, nil
   319  	case "one_string":
   320  		var v string
   321  		if err := json.Unmarshal(req.Params, &v); err != nil {
   322  			return nil, errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   323  		}
   324  		return "got:" + v, nil
   325  	case "one_number":
   326  		var v int
   327  		if err := json.Unmarshal(req.Params, &v); err != nil {
   328  			return nil, errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   329  		}
   330  		return fmt.Sprintf("got:%d", v), nil
   331  	case "set":
   332  		var v int
   333  		if err := json.Unmarshal(req.Params, &v); err != nil {
   334  			return nil, errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   335  		}
   336  		h.accumulator = v
   337  		return nil, nil
   338  	case "add":
   339  		var v int
   340  		if err := json.Unmarshal(req.Params, &v); err != nil {
   341  			return nil, errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   342  		}
   343  		h.accumulator += v
   344  		return nil, nil
   345  	case "get":
   346  		if len(req.Params) > 0 {
   347  			return nil, errors.Errorf("%w: expected no params", jsonrpc2.ErrInvalidParams)
   348  		}
   349  		return h.accumulator, nil
   350  	case "join":
   351  		var v []string
   352  		if err := json.Unmarshal(req.Params, &v); err != nil {
   353  			return nil, errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   354  		}
   355  		return path.Join(v...), nil
   356  	case "echo":
   357  		var v []interface{}
   358  		if err := json.Unmarshal(req.Params, &v); err != nil {
   359  			return nil, errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   360  		}
   361  		var result interface{}
   362  		err := h.conn.Call(ctx, v[0].(string), v[1]).Await(ctx, &result)
   363  		return result, err
   364  	case "wait":
   365  		var name string
   366  		if err := json.Unmarshal(req.Params, &name); err != nil {
   367  			return nil, errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   368  		}
   369  		select {
   370  		case <-h.waiter(name):
   371  			return true, nil
   372  		case <-ctx.Done():
   373  			return nil, ctx.Err()
   374  		case <-time.After(time.Second):
   375  			return nil, errors.Errorf("wait for %q timed out", name)
   376  		}
   377  	case "fork":
   378  		var name string
   379  		if err := json.Unmarshal(req.Params, &name); err != nil {
   380  			return nil, errors.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   381  		}
   382  		waitFor := h.waiter(name)
   383  		go func() {
   384  			select {
   385  			case <-waitFor:
   386  				h.conn.Respond(req.ID, true, nil)
   387  			case <-ctx.Done():
   388  				h.conn.Respond(req.ID, nil, ctx.Err())
   389  			case <-time.After(time.Second):
   390  				h.conn.Respond(req.ID, nil, errors.Errorf("wait for %q timed out", name))
   391  			}
   392  		}()
   393  		return nil, jsonrpc2.ErrAsyncResponse
   394  	default:
   395  		return nil, jsonrpc2.ErrNotHandled
   396  	}
   397  }