github.com/goplus/gop@v1.2.6/x/jsonrpc2/jsonrpc2test/cases/testcase.go (about)

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package cases
     6  
     7  import (
     8  	"context"
     9  	"encoding/json"
    10  	"fmt"
    11  	"log"
    12  	"path"
    13  	"reflect"
    14  	"sync"
    15  	"testing"
    16  	"time"
    17  
    18  	"github.com/goplus/gop/x/jsonrpc2"
    19  	"github.com/goplus/gop/x/jsonrpc2/internal/stack/stacktest"
    20  )
    21  
// callTests is the table of scenarios executed by Test. Each top-level entry
// becomes one subtest, named by the entry's Name method, and is run over its
// own freshly dialed connection.
var callTests = []invoker{
	// Plain request/response calls served by handler.Handle.
	call{"no_args", nil, true},
	call{"one_string", "fish", "got:fish"},
	call{"one_number", 10, "got:10"},
	call{"join", []string{"a", "b", "c"}, "a/b/c"},
	// Notifications update the accumulator; the final call reads it back.
	sequence{"notify", []invoker{
		notify{"set", 3},
		notify{"add", 5},
		call{"get", nil, 8},
	}},
	// "unblock" is served by handler.Preempt, so it can release a "wait"
	// that is still blocked inside handler.Handle.
	sequence{"preempt", []invoker{
		async{"a", "wait", "a"},
		notify{"unblock", "a"},
		collect{"a", true, false},
	}},
	// "cancel" is served by handler.Preempt, which cancels the call by ID;
	// collecting the cancelled call is expected to fail.
	sequence{"basic cancel", []invoker{
		async{"b", "wait", "b"},
		cancel{"b"},
		collect{"b", nil, true},
	}},
	sequence{"queue", []invoker{
		async{"a", "wait", "a"},
		notify{"set", 1},
		notify{"add", 2},
		notify{"add", 3},
		notify{"add", 4},
		// call{"peek", nil, 0}, // accumulator will not have any adds yet
		notify{"unblock", "a"},
		collect{"a", true, false},
		call{"get", nil, 10}, // accumulator now has all the adds
	}},
	sequence{"fork", []invoker{
		async{"a", "fork", "a"},
		notify{"set", 1},
		notify{"add", 2},
		notify{"add", 3},
		notify{"add", 4},
		call{"get", nil, 10}, // fork will not have blocked the adds
		notify{"unblock", "a"},
		collect{"a", true, false},
	}},
	// Two forked calls are started, unblocked and collected on one connection.
	sequence{"concurrent", []invoker{
		async{"a", "fork", "a"},
		notify{"unblock", "a"},
		async{"b", "fork", "b"},
		notify{"unblock", "b"},
		collect{"a", true, false},
		collect{"b", true, false},
	}},
}
    72  
// binder produces the connection options for one end of a test connection.
// Test passes one binder to jsonrpc2.NewServer and another to jsonrpc2.Dial.
type binder struct {
	// framer is copied into ConnectionOptions.Framer by Bind.
	framer  jsonrpc2.Framer
	// runTest, when non-nil, is started by Bind in its own goroutine with
	// the handler created for the new connection.
	runTest func(*handler)
}
    77  
// handler holds the per-connection state created by binder.Bind. It is
// installed as both the Preempter and the Handler of the connection.
type handler struct {
	// conn is the connection this handler was bound to.
	conn        *jsonrpc2.Connection
	// accumulator is the integer written by the "set" and "add" methods and
	// returned by "get" and "peek".
	accumulator int

	// mutex guards waiters.
	mutex   sync.Mutex
	// waiters maps a name to the channel that "wait"/"fork" block on and
	// that "unblock" closes.
	waiters map[string]chan struct{}

	// calls holds the calls started by async steps, keyed by step name, so
	// that later collect and cancel steps can find them.
	calls map[string]*jsonrpc2.AsyncCall
}
    87  
// invoker is one step (or group of steps) of a test scenario.
type invoker interface {
	// Name returns the label of the step; Test uses it as the subtest name
	// for top-level entries of callTests.
	Name() string
	// Invoke performs the step against the connection owned by h, reporting
	// failures through t.
	Invoke(t *testing.T, ctx context.Context, h *handler)
}
    92  
// notify is a step that sends a notification (no response is awaited).
type notify struct {
	method string      // method name to notify; also the step name
	params interface{} // parameters sent with the notification
}
    97  
// call is a step that performs a call, waits for the response and compares
// it with expect.
type call struct {
	method string      // method name to call; also the step name
	params interface{} // parameters sent with the call
	expect interface{} // expected result; nil means no result is expected
}
   103  
// async is a step that starts a call without waiting for it and records the
// pending call in handler.calls under name.
type async struct {
	name   string      // key under which the pending call is stored
	method string      // method name to call
	params interface{} // parameters sent with the call
}
   109  
// collect is a step that awaits a call previously started by an async step
// with the same name.
type collect struct {
	name   string      // key of the pending call in handler.calls
	expect interface{} // expected result; nil means no result is expected
	fails  bool        // true when awaiting the call must return an error
}
   115  
// cancel is a step that sends a "cancel" notification carrying the ID of the
// call previously started by an async step with the same name.
type cancel struct {
	name string // key of the pending call in handler.calls
}
   119  
// sequence is a named group of steps that are invoked one after another on
// the same connection.
type sequence struct {
	name  string    // name of the scenario
	tests []invoker // steps, invoked in slice order
}
   124  
// echo is a call that is routed through the peer's "echo" method instead of
// being invoked directly; it shares call's fields and expected result.
type echo call
   126  
// cancelParams is the parameter object of the "cancel" notification; ID is
// the integer ID of the call to cancel.
type cancelParams struct{ ID int64 }
   128  
   129  func Test(t *testing.T, ctx context.Context, listener jsonrpc2.Listener, framer jsonrpc2.Framer, noLeak bool) {
   130  	if noLeak {
   131  		stacktest.NoLeak(t)
   132  	}
   133  	server := jsonrpc2.NewServer(ctx, listener, binder{framer, nil})
   134  	defer func() {
   135  		listener.Close()
   136  		if noLeak {
   137  			server.Wait()
   138  		}
   139  	}()
   140  	for _, test := range callTests {
   141  		t.Run(test.Name(), func(t *testing.T) {
   142  			client, err := jsonrpc2.Dial(ctx,
   143  				listener.Dialer(), binder{framer, func(h *handler) {
   144  					defer h.conn.Close()
   145  					ctx := context.Background()
   146  					test.Invoke(t, ctx, h)
   147  					if call, ok := test.(*call); ok {
   148  						// also run all simple call tests in echo mode
   149  						(*echo)(call).Invoke(t, ctx, h)
   150  					}
   151  				}}, nil)
   152  			if err != nil {
   153  				t.Fatal(err)
   154  			}
   155  			client.Wait()
   156  		})
   157  	}
   158  }
   159  
   160  func (test notify) Name() string { return test.method }
   161  func (test notify) Invoke(t *testing.T, ctx context.Context, h *handler) {
   162  	if err := h.conn.Notify(ctx, test.method, test.params); err != nil {
   163  		t.Fatalf("%v:Notify failed: %v", test.method, err)
   164  	}
   165  }
   166  
   167  func (test call) Name() string { return test.method }
   168  func (test call) Invoke(t *testing.T, ctx context.Context, h *handler) {
   169  	results := newResults(test.expect)
   170  	if err := h.conn.Call(ctx, test.method, test.params).Await(ctx, results); err != nil {
   171  		t.Fatalf("%v:Call failed: %v", test.method, err)
   172  	}
   173  	verifyResults(t, test.method, results, test.expect)
   174  }
   175  
   176  func (test echo) Invoke(t *testing.T, ctx context.Context, h *handler) {
   177  	results := newResults(test.expect)
   178  	if err := h.conn.Call(ctx, "echo", []interface{}{test.method, test.params}).Await(ctx, results); err != nil {
   179  		t.Fatalf("%v:Echo failed: %v", test.method, err)
   180  	}
   181  	verifyResults(t, test.method, results, test.expect)
   182  }
   183  
   184  func (test async) Name() string { return test.name }
   185  func (test async) Invoke(t *testing.T, ctx context.Context, h *handler) {
   186  	h.calls[test.name] = h.conn.Call(ctx, test.method, test.params)
   187  }
   188  
   189  func (test collect) Name() string { return test.name }
   190  func (test collect) Invoke(t *testing.T, ctx context.Context, h *handler) {
   191  	o := h.calls[test.name]
   192  	results := newResults(test.expect)
   193  	err := o.Await(ctx, results)
   194  	switch {
   195  	case test.fails && err == nil:
   196  		t.Fatalf("%v:Collect was supposed to fail", test.name)
   197  	case !test.fails && err != nil:
   198  		t.Fatalf("%v:Collect failed: %v", test.name, err)
   199  	}
   200  	verifyResults(t, test.name, results, test.expect)
   201  }
   202  
   203  func (test cancel) Name() string { return test.name }
   204  func (test cancel) Invoke(t *testing.T, ctx context.Context, h *handler) {
   205  	o := h.calls[test.name]
   206  	if err := h.conn.Notify(ctx, "cancel", &cancelParams{o.ID().Raw().(int64)}); err != nil {
   207  		t.Fatalf("%v:Collect failed: %v", test.name, err)
   208  	}
   209  }
   210  
   211  func (test sequence) Name() string { return test.name }
   212  func (test sequence) Invoke(t *testing.T, ctx context.Context, h *handler) {
   213  	for _, child := range test.tests {
   214  		child.Invoke(t, ctx, h)
   215  	}
   216  }
   217  
   218  // newResults makes a new empty copy of the expected type to put the results into
   219  func newResults(expect interface{}) interface{} {
   220  	switch e := expect.(type) {
   221  	case []interface{}:
   222  		var r []interface{}
   223  		for _, v := range e {
   224  			r = append(r, reflect.New(reflect.TypeOf(v)).Interface())
   225  		}
   226  		return r
   227  	case nil:
   228  		return nil
   229  	default:
   230  		return reflect.New(reflect.TypeOf(expect)).Interface()
   231  	}
   232  }
   233  
   234  // verifyResults compares the results to the expected values
   235  func verifyResults(t *testing.T, method string, results interface{}, expect interface{}) {
   236  	if expect == nil {
   237  		if results != nil {
   238  			t.Errorf("%v:Got results %+v where none expeted", method, expect)
   239  		}
   240  		return
   241  	}
   242  	val := reflect.Indirect(reflect.ValueOf(results)).Interface()
   243  	if !reflect.DeepEqual(val, expect) {
   244  		t.Errorf("%v:Results are incorrect, got %+v expect %+v", method, val, expect)
   245  	}
   246  }
   247  
   248  func (b binder) Bind(ctx context.Context, conn *jsonrpc2.Connection) jsonrpc2.ConnectionOptions {
   249  	h := &handler{
   250  		conn:    conn,
   251  		waiters: make(map[string]chan struct{}),
   252  		calls:   make(map[string]*jsonrpc2.AsyncCall),
   253  	}
   254  	if b.runTest != nil {
   255  		go b.runTest(h)
   256  	}
   257  	return jsonrpc2.ConnectionOptions{
   258  		Framer:    b.framer,
   259  		Preempter: h,
   260  		Handler:   h,
   261  	}
   262  }
   263  
   264  func (h *handler) waiter(name string) chan struct{} {
   265  	log.Println("waiter:", name)
   266  	h.mutex.Lock()
   267  	defer h.mutex.Unlock()
   268  	waiter := make(chan struct{})
   269  	h.waiters[name] = waiter
   270  	return waiter
   271  }
   272  
   273  func (h *handler) closeWaiter(name string) {
   274  	log.Println("closeWaiter:", name)
   275  	for !h.tryCloseWaiter(name) {
   276  		time.Sleep(time.Millisecond)
   277  	}
   278  }
   279  
   280  func (h *handler) tryCloseWaiter(name string) (ok bool) {
   281  	h.mutex.Lock()
   282  	defer h.mutex.Unlock()
   283  	waiter, ok := h.waiters[name]
   284  	if ok {
   285  		delete(h.waiters, name)
   286  		close(waiter)
   287  	}
   288  	return
   289  }
   290  
   291  func (h *handler) Preempt(ctx context.Context, req *jsonrpc2.Request) (interface{}, error) {
   292  	switch req.Method {
   293  	case "unblock":
   294  		var name string
   295  		if err := json.Unmarshal(req.Params, &name); err != nil {
   296  			return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   297  		}
   298  		h.closeWaiter(name)
   299  		return nil, nil
   300  	case "peek":
   301  		if len(req.Params) > 0 {
   302  			return nil, fmt.Errorf("%w: expected no params", jsonrpc2.ErrInvalidParams)
   303  		}
   304  		return h.accumulator, nil
   305  	case "cancel":
   306  		var params cancelParams
   307  		if err := json.Unmarshal(req.Params, &params); err != nil {
   308  			return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   309  		}
   310  		h.conn.Cancel(jsonrpc2.Int64ID(params.ID))
   311  		return nil, nil
   312  	default:
   313  		return nil, jsonrpc2.ErrNotHandled
   314  	}
   315  }
   316  
   317  func (h *handler) Handle(ctx context.Context, req *jsonrpc2.Request) (interface{}, error) {
   318  	switch req.Method {
   319  	case "no_args":
   320  		if len(req.Params) > 0 {
   321  			return nil, fmt.Errorf("%w: expected no params", jsonrpc2.ErrInvalidParams)
   322  		}
   323  		return true, nil
   324  	case "one_string":
   325  		var v string
   326  		if err := json.Unmarshal(req.Params, &v); err != nil {
   327  			return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   328  		}
   329  		return "got:" + v, nil
   330  	case "one_number":
   331  		var v int
   332  		if err := json.Unmarshal(req.Params, &v); err != nil {
   333  			return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   334  		}
   335  		return fmt.Sprintf("got:%d", v), nil
   336  	case "set":
   337  		var v int
   338  		if err := json.Unmarshal(req.Params, &v); err != nil {
   339  			return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   340  		}
   341  		h.accumulator = v
   342  		return nil, nil
   343  	case "add":
   344  		var v int
   345  		if err := json.Unmarshal(req.Params, &v); err != nil {
   346  			return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   347  		}
   348  		h.accumulator += v
   349  		return nil, nil
   350  	case "get":
   351  		if len(req.Params) > 0 {
   352  			return nil, fmt.Errorf("%w: expected no params", jsonrpc2.ErrInvalidParams)
   353  		}
   354  		return h.accumulator, nil
   355  	case "join":
   356  		var v []string
   357  		if err := json.Unmarshal(req.Params, &v); err != nil {
   358  			return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   359  		}
   360  		return path.Join(v...), nil
   361  	case "echo":
   362  		var v []interface{}
   363  		if err := json.Unmarshal(req.Params, &v); err != nil {
   364  			return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   365  		}
   366  		var result interface{}
   367  		err := h.conn.Call(ctx, v[0].(string), v[1]).Await(ctx, &result)
   368  		return result, err
   369  	case "wait":
   370  		var name string
   371  		if err := json.Unmarshal(req.Params, &name); err != nil {
   372  			return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   373  		}
   374  		select {
   375  		case <-h.waiter(name):
   376  			return true, nil
   377  		case <-ctx.Done():
   378  			return nil, ctx.Err()
   379  		}
   380  	case "fork":
   381  		var name string
   382  		if err := json.Unmarshal(req.Params, &name); err != nil {
   383  			return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   384  		}
   385  		waitFor := h.waiter(name)
   386  		go func() {
   387  			select {
   388  			case <-waitFor:
   389  				h.conn.Respond(req.ID, true, nil)
   390  			case <-ctx.Done():
   391  				h.conn.Respond(req.ID, nil, ctx.Err())
   392  			}
   393  		}()
   394  		return nil, jsonrpc2.ErrAsyncResponse
   395  	default:
   396  		return nil, jsonrpc2.ErrNotHandled
   397  	}
   398  }