golang.org/x/tools@v0.21.1-0.20240520172518-788d39e776b1/internal/jsonrpc2_v2/jsonrpc2_test.go (about)

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package jsonrpc2_test
     6  
     7  import (
     8  	"context"
     9  	"encoding/json"
    10  	"fmt"
    11  	"path"
    12  	"reflect"
    13  	"testing"
    14  
    15  	"golang.org/x/tools/internal/event/export/eventtest"
    16  	jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2"
    17  	"golang.org/x/tools/internal/stack/stacktest"
    18  )
    19  
    20  var callTests = []invoker{
    21  	call{"no_args", nil, true},
    22  	call{"one_string", "fish", "got:fish"},
    23  	call{"one_number", 10, "got:10"},
    24  	call{"join", []string{"a", "b", "c"}, "a/b/c"},
    25  	sequence{"notify", []invoker{
    26  		notify{"set", 3},
    27  		notify{"add", 5},
    28  		call{"get", nil, 8},
    29  	}},
    30  	sequence{"preempt", []invoker{
    31  		async{"a", "wait", "a"},
    32  		notify{"unblock", "a"},
    33  		collect{"a", true, false},
    34  	}},
    35  	sequence{"basic cancel", []invoker{
    36  		async{"b", "wait", "b"},
    37  		cancel{"b"},
    38  		collect{"b", nil, true},
    39  	}},
    40  	sequence{"queue", []invoker{
    41  		async{"a", "wait", "a"},
    42  		notify{"set", 1},
    43  		notify{"add", 2},
    44  		notify{"add", 3},
    45  		notify{"add", 4},
    46  		call{"peek", nil, 0}, // accumulator will not have any adds yet
    47  		notify{"unblock", "a"},
    48  		collect{"a", true, false},
    49  		call{"get", nil, 10}, // accumulator now has all the adds
    50  	}},
    51  	sequence{"fork", []invoker{
    52  		async{"a", "fork", "a"},
    53  		notify{"set", 1},
    54  		notify{"add", 2},
    55  		notify{"add", 3},
    56  		notify{"add", 4},
    57  		call{"get", nil, 10}, // fork will not have blocked the adds
    58  		notify{"unblock", "a"},
    59  		collect{"a", true, false},
    60  	}},
    61  	sequence{"concurrent", []invoker{
    62  		async{"a", "fork", "a"},
    63  		notify{"unblock", "a"},
    64  		async{"b", "fork", "b"},
    65  		notify{"unblock", "b"},
    66  		collect{"a", true, false},
    67  		collect{"b", true, false},
    68  	}},
    69  }
    70  
    71  type binder struct {
    72  	framer  jsonrpc2.Framer
    73  	runTest func(*handler)
    74  }
    75  
    76  type handler struct {
    77  	conn        *jsonrpc2.Connection
    78  	accumulator int
    79  	waiters     chan map[string]chan struct{}
    80  	calls       map[string]*jsonrpc2.AsyncCall
    81  }
    82  
    83  type invoker interface {
    84  	Name() string
    85  	Invoke(t *testing.T, ctx context.Context, h *handler)
    86  }
    87  
    88  type notify struct {
    89  	method string
    90  	params interface{}
    91  }
    92  
    93  type call struct {
    94  	method string
    95  	params interface{}
    96  	expect interface{}
    97  }
    98  
    99  type async struct {
   100  	name   string
   101  	method string
   102  	params interface{}
   103  }
   104  
   105  type collect struct {
   106  	name   string
   107  	expect interface{}
   108  	fails  bool
   109  }
   110  
   111  type cancel struct {
   112  	name string
   113  }
   114  
   115  type sequence struct {
   116  	name  string
   117  	tests []invoker
   118  }
   119  
   120  type echo call
   121  
   122  type cancelParams struct{ ID int64 }
   123  
   124  func TestConnectionRaw(t *testing.T) {
   125  	testConnection(t, jsonrpc2.RawFramer())
   126  }
   127  
   128  func TestConnectionHeader(t *testing.T) {
   129  	testConnection(t, jsonrpc2.HeaderFramer())
   130  }
   131  
   132  func testConnection(t *testing.T, framer jsonrpc2.Framer) {
   133  	stacktest.NoLeak(t)
   134  	ctx := eventtest.NewContext(context.Background(), t)
   135  	listener, err := jsonrpc2.NetPipeListener(ctx)
   136  	if err != nil {
   137  		t.Fatal(err)
   138  	}
   139  	server := jsonrpc2.NewServer(ctx, listener, binder{framer, nil})
   140  	defer func() {
   141  		listener.Close()
   142  		server.Wait()
   143  	}()
   144  
   145  	for _, test := range callTests {
   146  		t.Run(test.Name(), func(t *testing.T) {
   147  			client, err := jsonrpc2.Dial(ctx,
   148  				listener.Dialer(), binder{framer, func(h *handler) {
   149  					defer h.conn.Close()
   150  					ctx := eventtest.NewContext(ctx, t)
   151  					test.Invoke(t, ctx, h)
   152  					if call, ok := test.(*call); ok {
   153  						// also run all simple call tests in echo mode
   154  						(*echo)(call).Invoke(t, ctx, h)
   155  					}
   156  				}})
   157  			if err != nil {
   158  				t.Fatal(err)
   159  			}
   160  			client.Wait()
   161  		})
   162  	}
   163  }
   164  
   165  func (test notify) Name() string { return test.method }
   166  func (test notify) Invoke(t *testing.T, ctx context.Context, h *handler) {
   167  	if err := h.conn.Notify(ctx, test.method, test.params); err != nil {
   168  		t.Fatalf("%v:Notify failed: %v", test.method, err)
   169  	}
   170  }
   171  
   172  func (test call) Name() string { return test.method }
   173  func (test call) Invoke(t *testing.T, ctx context.Context, h *handler) {
   174  	results := newResults(test.expect)
   175  	if err := h.conn.Call(ctx, test.method, test.params).Await(ctx, results); err != nil {
   176  		t.Fatalf("%v:Call failed: %v", test.method, err)
   177  	}
   178  	verifyResults(t, test.method, results, test.expect)
   179  }
   180  
   181  func (test echo) Invoke(t *testing.T, ctx context.Context, h *handler) {
   182  	results := newResults(test.expect)
   183  	if err := h.conn.Call(ctx, "echo", []interface{}{test.method, test.params}).Await(ctx, results); err != nil {
   184  		t.Fatalf("%v:Echo failed: %v", test.method, err)
   185  	}
   186  	verifyResults(t, test.method, results, test.expect)
   187  }
   188  
   189  func (test async) Name() string { return test.name }
   190  func (test async) Invoke(t *testing.T, ctx context.Context, h *handler) {
   191  	h.calls[test.name] = h.conn.Call(ctx, test.method, test.params)
   192  }
   193  
   194  func (test collect) Name() string { return test.name }
   195  func (test collect) Invoke(t *testing.T, ctx context.Context, h *handler) {
   196  	o := h.calls[test.name]
   197  	results := newResults(test.expect)
   198  	err := o.Await(ctx, results)
   199  	switch {
   200  	case test.fails && err == nil:
   201  		t.Fatalf("%v:Collect was supposed to fail", test.name)
   202  	case !test.fails && err != nil:
   203  		t.Fatalf("%v:Collect failed: %v", test.name, err)
   204  	}
   205  	verifyResults(t, test.name, results, test.expect)
   206  }
   207  
   208  func (test cancel) Name() string { return test.name }
   209  func (test cancel) Invoke(t *testing.T, ctx context.Context, h *handler) {
   210  	o := h.calls[test.name]
   211  	if err := h.conn.Notify(ctx, "cancel", &cancelParams{o.ID().Raw().(int64)}); err != nil {
   212  		t.Fatalf("%v:Collect failed: %v", test.name, err)
   213  	}
   214  }
   215  
   216  func (test sequence) Name() string { return test.name }
   217  func (test sequence) Invoke(t *testing.T, ctx context.Context, h *handler) {
   218  	for _, child := range test.tests {
   219  		child.Invoke(t, ctx, h)
   220  	}
   221  }
   222  
   223  // newResults makes a new empty copy of the expected type to put the results into
   224  func newResults(expect interface{}) interface{} {
   225  	switch e := expect.(type) {
   226  	case []interface{}:
   227  		var r []interface{}
   228  		for _, v := range e {
   229  			r = append(r, reflect.New(reflect.TypeOf(v)).Interface())
   230  		}
   231  		return r
   232  	case nil:
   233  		return nil
   234  	default:
   235  		return reflect.New(reflect.TypeOf(expect)).Interface()
   236  	}
   237  }
   238  
   239  // verifyResults compares the results to the expected values
   240  func verifyResults(t *testing.T, method string, results interface{}, expect interface{}) {
   241  	if expect == nil {
   242  		if results != nil {
   243  			t.Errorf("%v:Got results %+v where none expected", method, expect)
   244  		}
   245  		return
   246  	}
   247  	val := reflect.Indirect(reflect.ValueOf(results)).Interface()
   248  	if !reflect.DeepEqual(val, expect) {
   249  		t.Errorf("%v:Results are incorrect, got %+v expect %+v", method, val, expect)
   250  	}
   251  }
   252  
   253  func (b binder) Bind(ctx context.Context, conn *jsonrpc2.Connection) jsonrpc2.ConnectionOptions {
   254  	h := &handler{
   255  		conn:    conn,
   256  		waiters: make(chan map[string]chan struct{}, 1),
   257  		calls:   make(map[string]*jsonrpc2.AsyncCall),
   258  	}
   259  	h.waiters <- make(map[string]chan struct{})
   260  	if b.runTest != nil {
   261  		go b.runTest(h)
   262  	}
   263  	return jsonrpc2.ConnectionOptions{
   264  		Framer:    b.framer,
   265  		Preempter: h,
   266  		Handler:   h,
   267  	}
   268  }
   269  
   270  func (h *handler) waiter(name string) chan struct{} {
   271  	waiters := <-h.waiters
   272  	defer func() { h.waiters <- waiters }()
   273  	waiter, found := waiters[name]
   274  	if !found {
   275  		waiter = make(chan struct{})
   276  		waiters[name] = waiter
   277  	}
   278  	return waiter
   279  }
   280  
   281  func (h *handler) Preempt(ctx context.Context, req *jsonrpc2.Request) (interface{}, error) {
   282  	switch req.Method {
   283  	case "unblock":
   284  		var name string
   285  		if err := json.Unmarshal(req.Params, &name); err != nil {
   286  			return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   287  		}
   288  		close(h.waiter(name))
   289  		return nil, nil
   290  	case "peek":
   291  		if len(req.Params) > 0 {
   292  			return nil, fmt.Errorf("%w: expected no params", jsonrpc2.ErrInvalidParams)
   293  		}
   294  		return h.accumulator, nil
   295  	case "cancel":
   296  		var params cancelParams
   297  		if err := json.Unmarshal(req.Params, &params); err != nil {
   298  			return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   299  		}
   300  		h.conn.Cancel(jsonrpc2.Int64ID(params.ID))
   301  		return nil, nil
   302  	default:
   303  		return nil, jsonrpc2.ErrNotHandled
   304  	}
   305  }
   306  
   307  func (h *handler) Handle(ctx context.Context, req *jsonrpc2.Request) (interface{}, error) {
   308  	switch req.Method {
   309  	case "no_args":
   310  		if len(req.Params) > 0 {
   311  			return nil, fmt.Errorf("%w: expected no params", jsonrpc2.ErrInvalidParams)
   312  		}
   313  		return true, nil
   314  	case "one_string":
   315  		var v string
   316  		if err := json.Unmarshal(req.Params, &v); err != nil {
   317  			return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   318  		}
   319  		return "got:" + v, nil
   320  	case "one_number":
   321  		var v int
   322  		if err := json.Unmarshal(req.Params, &v); err != nil {
   323  			return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   324  		}
   325  		return fmt.Sprintf("got:%d", v), nil
   326  	case "set":
   327  		var v int
   328  		if err := json.Unmarshal(req.Params, &v); err != nil {
   329  			return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   330  		}
   331  		h.accumulator = v
   332  		return nil, nil
   333  	case "add":
   334  		var v int
   335  		if err := json.Unmarshal(req.Params, &v); err != nil {
   336  			return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   337  		}
   338  		h.accumulator += v
   339  		return nil, nil
   340  	case "get":
   341  		if len(req.Params) > 0 {
   342  			return nil, fmt.Errorf("%w: expected no params", jsonrpc2.ErrInvalidParams)
   343  		}
   344  		return h.accumulator, nil
   345  	case "join":
   346  		var v []string
   347  		if err := json.Unmarshal(req.Params, &v); err != nil {
   348  			return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   349  		}
   350  		return path.Join(v...), nil
   351  	case "echo":
   352  		var v []interface{}
   353  		if err := json.Unmarshal(req.Params, &v); err != nil {
   354  			return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   355  		}
   356  		var result interface{}
   357  		err := h.conn.Call(ctx, v[0].(string), v[1]).Await(ctx, &result)
   358  		return result, err
   359  	case "wait":
   360  		var name string
   361  		if err := json.Unmarshal(req.Params, &name); err != nil {
   362  			return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   363  		}
   364  		select {
   365  		case <-h.waiter(name):
   366  			return true, nil
   367  		case <-ctx.Done():
   368  			return nil, ctx.Err()
   369  		}
   370  	case "fork":
   371  		var name string
   372  		if err := json.Unmarshal(req.Params, &name); err != nil {
   373  			return nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err)
   374  		}
   375  		waitFor := h.waiter(name)
   376  		go func() {
   377  			select {
   378  			case <-waitFor:
   379  				h.conn.Respond(req.ID, true, nil)
   380  			case <-ctx.Done():
   381  				h.conn.Respond(req.ID, nil, ctx.Err())
   382  			}
   383  		}()
   384  		return nil, jsonrpc2.ErrAsyncResponse
   385  	default:
   386  		return nil, jsonrpc2.ErrNotHandled
   387  	}
   388  }