golang.org/x/tools/gopls@v0.15.3/internal/lsprpc/middleware_test.go (about)

     1  // Copyright 2021 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package lsprpc_test
     6  
     7  import (
     8  	"context"
     9  	"encoding/json"
    10  	"errors"
    11  	"fmt"
    12  	"sync"
    13  	"testing"
    14  	"time"
    15  
    16  	. "golang.org/x/tools/gopls/internal/lsprpc"
    17  	"golang.org/x/tools/internal/event"
    18  	jsonrpc2_v2 "golang.org/x/tools/internal/jsonrpc2_v2"
    19  )
    20  
    21  var noopBinder = BinderFunc(func(context.Context, *jsonrpc2_v2.Connection) jsonrpc2_v2.ConnectionOptions {
    22  	return jsonrpc2_v2.ConnectionOptions{}
    23  })
    24  
    25  func TestHandshakeMiddleware(t *testing.T) {
    26  	sh := &Handshaker{
    27  		metadata: metadata{
    28  			"answer": 42,
    29  		},
    30  	}
    31  	ctx := context.Background()
    32  	env := new(TestEnv)
    33  	defer env.Shutdown(t)
    34  	l, _ := env.serve(ctx, t, sh.Middleware(noopBinder))
    35  	conn := env.dial(ctx, t, l.Dialer(), noopBinder, false)
    36  	ch := &Handshaker{
    37  		metadata: metadata{
    38  			"question": 6 * 9,
    39  		},
    40  	}
    41  
    42  	check := func(connected bool) error {
    43  		clients := sh.Peers()
    44  		servers := ch.Peers()
    45  		want := 0
    46  		if connected {
    47  			want = 1
    48  		}
    49  		if got := len(clients); got != want {
    50  			return fmt.Errorf("got %d clients on the server, want %d", got, want)
    51  		}
    52  		if got := len(servers); got != want {
    53  			return fmt.Errorf("got %d servers on the client, want %d", got, want)
    54  		}
    55  		if !connected {
    56  			return nil
    57  		}
    58  		client := clients[0]
    59  		server := servers[0]
    60  		if _, ok := client.Metadata["question"]; !ok {
    61  			return errors.New("no client metadata")
    62  		}
    63  		if _, ok := server.Metadata["answer"]; !ok {
    64  			return errors.New("no server metadata")
    65  		}
    66  		if client.LocalID != server.RemoteID {
    67  			return fmt.Errorf("client.LocalID == %d, server.PeerID == %d", client.LocalID, server.RemoteID)
    68  		}
    69  		if client.RemoteID != server.LocalID {
    70  			return fmt.Errorf("client.PeerID == %d, server.LocalID == %d", client.RemoteID, server.LocalID)
    71  		}
    72  		return nil
    73  	}
    74  
    75  	if err := check(false); err != nil {
    76  		t.Fatalf("before handshake: %v", err)
    77  	}
    78  	ch.ClientHandshake(ctx, conn)
    79  	if err := check(true); err != nil {
    80  		t.Fatalf("after handshake: %v", err)
    81  	}
    82  	conn.Close()
    83  	// Wait for up to ~2s for connections to get cleaned up.
    84  	delay := 25 * time.Millisecond
    85  	for retries := 3; retries >= 0; retries-- {
    86  		time.Sleep(delay)
    87  		err := check(false)
    88  		if err == nil {
    89  			return
    90  		}
    91  		if retries == 0 {
    92  			t.Fatalf("after closing connection: %v", err)
    93  		}
    94  		delay *= 4
    95  	}
    96  }
    97  
    98  // Handshaker handles both server and client handshaking over jsonrpc2 v2.
    99  // To instrument server-side handshaking, use Handshaker.Middleware.
   100  // To instrument client-side handshaking, call
   101  // Handshaker.ClientHandshake for any new client-side connections.
   102  type Handshaker struct {
   103  	// metadata will be shared with peers via handshaking.
   104  	metadata metadata
   105  
   106  	mu     sync.Mutex
   107  	prevID int64
   108  	peers  map[int64]PeerInfo
   109  }
   110  
   111  // metadata holds arbitrary data transferred between jsonrpc2 peers.
   112  type metadata map[string]any
   113  
   114  // PeerInfo holds information about a peering between jsonrpc2 servers.
   115  type PeerInfo struct {
   116  	// RemoteID is the identity of the current server on its peer.
   117  	RemoteID int64
   118  
   119  	// LocalID is the identity of the peer on the server.
   120  	LocalID int64
   121  
   122  	// IsClient reports whether the peer is a client. If false, the peer is a
   123  	// server.
   124  	IsClient bool
   125  
   126  	// Metadata holds arbitrary information provided by the peer.
   127  	Metadata metadata
   128  }
   129  
   130  // Peers returns the peer info this handshaker knows about by way of either the
   131  // server-side handshake middleware, or client-side handshakes.
   132  func (h *Handshaker) Peers() []PeerInfo {
   133  	h.mu.Lock()
   134  	defer h.mu.Unlock()
   135  
   136  	var c []PeerInfo
   137  	for _, v := range h.peers {
   138  		c = append(c, v)
   139  	}
   140  	return c
   141  }
   142  
   143  // Middleware is a jsonrpc2 middleware function to augment connection binding
   144  // to handle the handshake method, and record disconnections.
   145  func (h *Handshaker) Middleware(inner jsonrpc2_v2.Binder) jsonrpc2_v2.Binder {
   146  	return BinderFunc(func(ctx context.Context, conn *jsonrpc2_v2.Connection) jsonrpc2_v2.ConnectionOptions {
   147  		opts := inner.Bind(ctx, conn)
   148  
   149  		localID := h.nextID()
   150  		info := &PeerInfo{
   151  			RemoteID: localID,
   152  			Metadata: h.metadata,
   153  		}
   154  
   155  		// Wrap the delegated handler to accept the handshake.
   156  		delegate := opts.Handler
   157  		opts.Handler = jsonrpc2_v2.HandlerFunc(func(ctx context.Context, req *jsonrpc2_v2.Request) (interface{}, error) {
   158  			if req.Method == HandshakeMethod {
   159  				var peerInfo PeerInfo
   160  				if err := json.Unmarshal(req.Params, &peerInfo); err != nil {
   161  					return nil, fmt.Errorf("%w: unmarshaling client info: %v", jsonrpc2_v2.ErrInvalidParams, err)
   162  				}
   163  				peerInfo.LocalID = localID
   164  				peerInfo.IsClient = true
   165  				h.recordPeer(peerInfo)
   166  				return info, nil
   167  			}
   168  			return delegate.Handle(ctx, req)
   169  		})
   170  
   171  		// Record the dropped client.
   172  		go h.cleanupAtDisconnect(conn, localID)
   173  
   174  		return opts
   175  	})
   176  }
   177  
   178  // ClientHandshake performs a client-side handshake with the server at the
   179  // other end of conn, recording the server's peer info and watching for conn's
   180  // disconnection.
   181  func (h *Handshaker) ClientHandshake(ctx context.Context, conn *jsonrpc2_v2.Connection) {
   182  	localID := h.nextID()
   183  	info := &PeerInfo{
   184  		RemoteID: localID,
   185  		Metadata: h.metadata,
   186  	}
   187  
   188  	call := conn.Call(ctx, HandshakeMethod, info)
   189  	var serverInfo PeerInfo
   190  	if err := call.Await(ctx, &serverInfo); err != nil {
   191  		event.Error(ctx, "performing handshake", err)
   192  		return
   193  	}
   194  	serverInfo.LocalID = localID
   195  	h.recordPeer(serverInfo)
   196  
   197  	go h.cleanupAtDisconnect(conn, localID)
   198  }
   199  
   200  func (h *Handshaker) nextID() int64 {
   201  	h.mu.Lock()
   202  	defer h.mu.Unlock()
   203  
   204  	h.prevID++
   205  	return h.prevID
   206  }
   207  
   208  func (h *Handshaker) cleanupAtDisconnect(conn *jsonrpc2_v2.Connection, peerID int64) {
   209  	conn.Wait()
   210  
   211  	h.mu.Lock()
   212  	defer h.mu.Unlock()
   213  	delete(h.peers, peerID)
   214  }
   215  
   216  func (h *Handshaker) recordPeer(info PeerInfo) {
   217  	h.mu.Lock()
   218  	defer h.mu.Unlock()
   219  	if h.peers == nil {
   220  		h.peers = make(map[int64]PeerInfo)
   221  	}
   222  	h.peers[info.LocalID] = info
   223  }