golang.org/x/tools@v0.21.0/internal/jsonrpc2/jsonrpc2_test.go (about)

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package jsonrpc2_test
     6  
     7  import (
     8  	"context"
     9  	"encoding/json"
    10  	"flag"
    11  	"fmt"
    12  	"net"
    13  	"path"
    14  	"reflect"
    15  	"testing"
    16  
    17  	"golang.org/x/tools/internal/event/export/eventtest"
    18  	"golang.org/x/tools/internal/jsonrpc2"
    19  	"golang.org/x/tools/internal/stack/stacktest"
    20  )
    21  
    22  var logRPC = flag.Bool("logrpc", false, "Enable jsonrpc2 communication logging")
    23  
    24  type callTest struct {
    25  	method string
    26  	params interface{}
    27  	expect interface{}
    28  }
    29  
    30  var callTests = []callTest{
    31  	{"no_args", nil, true},
    32  	{"one_string", "fish", "got:fish"},
    33  	{"one_number", 10, "got:10"},
    34  	{"join", []string{"a", "b", "c"}, "a/b/c"},
    35  	//TODO: expand the test cases
    36  }
    37  
    38  func (test *callTest) newResults() interface{} {
    39  	switch e := test.expect.(type) {
    40  	case []interface{}:
    41  		var r []interface{}
    42  		for _, v := range e {
    43  			r = append(r, reflect.New(reflect.TypeOf(v)).Interface())
    44  		}
    45  		return r
    46  	case nil:
    47  		return nil
    48  	default:
    49  		return reflect.New(reflect.TypeOf(test.expect)).Interface()
    50  	}
    51  }
    52  
    53  func (test *callTest) verifyResults(t *testing.T, results interface{}) {
    54  	if results == nil {
    55  		return
    56  	}
    57  	val := reflect.Indirect(reflect.ValueOf(results)).Interface()
    58  	if !reflect.DeepEqual(val, test.expect) {
    59  		t.Errorf("%v:Results are incorrect, got %+v expect %+v", test.method, val, test.expect)
    60  	}
    61  }
    62  
    63  func TestCall(t *testing.T) {
    64  	stacktest.NoLeak(t)
    65  	ctx := eventtest.NewContext(context.Background(), t)
    66  	for _, headers := range []bool{false, true} {
    67  		name := "Plain"
    68  		if headers {
    69  			name = "Headers"
    70  		}
    71  		t.Run(name, func(t *testing.T) {
    72  			ctx := eventtest.NewContext(ctx, t)
    73  			a, b, done := prepare(ctx, t, headers)
    74  			defer done()
    75  			for _, test := range callTests {
    76  				t.Run(test.method, func(t *testing.T) {
    77  					ctx := eventtest.NewContext(ctx, t)
    78  					results := test.newResults()
    79  					if _, err := a.Call(ctx, test.method, test.params, results); err != nil {
    80  						t.Fatalf("%v:Call failed: %v", test.method, err)
    81  					}
    82  					test.verifyResults(t, results)
    83  					if _, err := b.Call(ctx, test.method, test.params, results); err != nil {
    84  						t.Fatalf("%v:Call failed: %v", test.method, err)
    85  					}
    86  					test.verifyResults(t, results)
    87  				})
    88  			}
    89  		})
    90  	}
    91  }
    92  
    93  func prepare(ctx context.Context, t *testing.T, withHeaders bool) (jsonrpc2.Conn, jsonrpc2.Conn, func()) {
    94  	// make a wait group that can be used to wait for the system to shut down
    95  	aPipe, bPipe := net.Pipe()
    96  	a := run(ctx, withHeaders, aPipe)
    97  	b := run(ctx, withHeaders, bPipe)
    98  	return a, b, func() {
    99  		a.Close()
   100  		b.Close()
   101  		<-a.Done()
   102  		<-b.Done()
   103  	}
   104  }
   105  
   106  func run(ctx context.Context, withHeaders bool, nc net.Conn) jsonrpc2.Conn {
   107  	var stream jsonrpc2.Stream
   108  	if withHeaders {
   109  		stream = jsonrpc2.NewHeaderStream(nc)
   110  	} else {
   111  		stream = jsonrpc2.NewRawStream(nc)
   112  	}
   113  	conn := jsonrpc2.NewConn(stream)
   114  	conn.Go(ctx, testHandler(*logRPC))
   115  	return conn
   116  }
   117  
   118  func testHandler(log bool) jsonrpc2.Handler {
   119  	return func(ctx context.Context, reply jsonrpc2.Replier, req jsonrpc2.Request) error {
   120  		switch req.Method() {
   121  		case "no_args":
   122  			if len(req.Params()) > 0 {
   123  				return reply(ctx, nil, fmt.Errorf("%w: expected no params", jsonrpc2.ErrInvalidParams))
   124  			}
   125  			return reply(ctx, true, nil)
   126  		case "one_string":
   127  			var v string
   128  			if err := json.Unmarshal(req.Params(), &v); err != nil {
   129  				return reply(ctx, nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err))
   130  			}
   131  			return reply(ctx, "got:"+v, nil)
   132  		case "one_number":
   133  			var v int
   134  			if err := json.Unmarshal(req.Params(), &v); err != nil {
   135  				return reply(ctx, nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err))
   136  			}
   137  			return reply(ctx, fmt.Sprintf("got:%d", v), nil)
   138  		case "join":
   139  			var v []string
   140  			if err := json.Unmarshal(req.Params(), &v); err != nil {
   141  				return reply(ctx, nil, fmt.Errorf("%w: %s", jsonrpc2.ErrParse, err))
   142  			}
   143  			return reply(ctx, path.Join(v...), nil)
   144  		default:
   145  			return jsonrpc2.MethodNotFound(ctx, reply, req)
   146  		}
   147  	}
   148  }