golang.org/x/tools/gopls@v0.15.3/internal/lsprpc/middleware_test.go (about) 1 // Copyright 2021 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package lsprpc_test 6 7 import ( 8 "context" 9 "encoding/json" 10 "errors" 11 "fmt" 12 "sync" 13 "testing" 14 "time" 15 16 . "golang.org/x/tools/gopls/internal/lsprpc" 17 "golang.org/x/tools/internal/event" 18 jsonrpc2_v2 "golang.org/x/tools/internal/jsonrpc2_v2" 19 ) 20 21 var noopBinder = BinderFunc(func(context.Context, *jsonrpc2_v2.Connection) jsonrpc2_v2.ConnectionOptions { 22 return jsonrpc2_v2.ConnectionOptions{} 23 }) 24 25 func TestHandshakeMiddleware(t *testing.T) { 26 sh := &Handshaker{ 27 metadata: metadata{ 28 "answer": 42, 29 }, 30 } 31 ctx := context.Background() 32 env := new(TestEnv) 33 defer env.Shutdown(t) 34 l, _ := env.serve(ctx, t, sh.Middleware(noopBinder)) 35 conn := env.dial(ctx, t, l.Dialer(), noopBinder, false) 36 ch := &Handshaker{ 37 metadata: metadata{ 38 "question": 6 * 9, 39 }, 40 } 41 42 check := func(connected bool) error { 43 clients := sh.Peers() 44 servers := ch.Peers() 45 want := 0 46 if connected { 47 want = 1 48 } 49 if got := len(clients); got != want { 50 return fmt.Errorf("got %d clients on the server, want %d", got, want) 51 } 52 if got := len(servers); got != want { 53 return fmt.Errorf("got %d servers on the client, want %d", got, want) 54 } 55 if !connected { 56 return nil 57 } 58 client := clients[0] 59 server := servers[0] 60 if _, ok := client.Metadata["question"]; !ok { 61 return errors.New("no client metadata") 62 } 63 if _, ok := server.Metadata["answer"]; !ok { 64 return errors.New("no server metadata") 65 } 66 if client.LocalID != server.RemoteID { 67 return fmt.Errorf("client.LocalID == %d, server.PeerID == %d", client.LocalID, server.RemoteID) 68 } 69 if client.RemoteID != server.LocalID { 70 return fmt.Errorf("client.PeerID == %d, server.LocalID == %d", client.RemoteID, server.LocalID) 71 } 72 return nil 73 } 74 75 if err := check(false); err != nil { 76 t.Fatalf("before handshake: %v", err) 77 } 78 ch.ClientHandshake(ctx, conn) 79 if err := check(true); err != nil { 80 t.Fatalf("after handshake: %v", err) 81 } 82 conn.Close() 83 // Wait for up to ~2s for connections to get cleaned up. 84 delay := 25 * time.Millisecond 85 for retries := 3; retries >= 0; retries-- { 86 time.Sleep(delay) 87 err := check(false) 88 if err == nil { 89 return 90 } 91 if retries == 0 { 92 t.Fatalf("after closing connection: %v", err) 93 } 94 delay *= 4 95 } 96 } 97 98 // Handshaker handles both server and client handshaking over jsonrpc2 v2. 99 // To instrument server-side handshaking, use Handshaker.Middleware. 100 // To instrument client-side handshaking, call 101 // Handshaker.ClientHandshake for any new client-side connections. 102 type Handshaker struct { 103 // metadata will be shared with peers via handshaking. 104 metadata metadata 105 106 mu sync.Mutex 107 prevID int64 108 peers map[int64]PeerInfo 109 } 110 111 // metadata holds arbitrary data transferred between jsonrpc2 peers. 112 type metadata map[string]any 113 114 // PeerInfo holds information about a peering between jsonrpc2 servers. 115 type PeerInfo struct { 116 // RemoteID is the identity of the current server on its peer. 117 RemoteID int64 118 119 // LocalID is the identity of the peer on the server. 120 LocalID int64 121 122 // IsClient reports whether the peer is a client. If false, the peer is a 123 // server. 124 IsClient bool 125 126 // Metadata holds arbitrary information provided by the peer. 127 Metadata metadata 128 } 129 130 // Peers returns the peer info this handshaker knows about by way of either the 131 // server-side handshake middleware, or client-side handshakes. 132 func (h *Handshaker) Peers() []PeerInfo { 133 h.mu.Lock() 134 defer h.mu.Unlock() 135 136 var c []PeerInfo 137 for _, v := range h.peers { 138 c = append(c, v) 139 } 140 return c 141 } 142 143 // Middleware is a jsonrpc2 middleware function to augment connection binding 144 // to handle the handshake method, and record disconnections. 145 func (h *Handshaker) Middleware(inner jsonrpc2_v2.Binder) jsonrpc2_v2.Binder { 146 return BinderFunc(func(ctx context.Context, conn *jsonrpc2_v2.Connection) jsonrpc2_v2.ConnectionOptions { 147 opts := inner.Bind(ctx, conn) 148 149 localID := h.nextID() 150 info := &PeerInfo{ 151 RemoteID: localID, 152 Metadata: h.metadata, 153 } 154 155 // Wrap the delegated handler to accept the handshake. 156 delegate := opts.Handler 157 opts.Handler = jsonrpc2_v2.HandlerFunc(func(ctx context.Context, req *jsonrpc2_v2.Request) (interface{}, error) { 158 if req.Method == HandshakeMethod { 159 var peerInfo PeerInfo 160 if err := json.Unmarshal(req.Params, &peerInfo); err != nil { 161 return nil, fmt.Errorf("%w: unmarshaling client info: %v", jsonrpc2_v2.ErrInvalidParams, err) 162 } 163 peerInfo.LocalID = localID 164 peerInfo.IsClient = true 165 h.recordPeer(peerInfo) 166 return info, nil 167 } 168 return delegate.Handle(ctx, req) 169 }) 170 171 // Record the dropped client. 172 go h.cleanupAtDisconnect(conn, localID) 173 174 return opts 175 }) 176 } 177 178 // ClientHandshake performs a client-side handshake with the server at the 179 // other end of conn, recording the server's peer info and watching for conn's 180 // disconnection. 181 func (h *Handshaker) ClientHandshake(ctx context.Context, conn *jsonrpc2_v2.Connection) { 182 localID := h.nextID() 183 info := &PeerInfo{ 184 RemoteID: localID, 185 Metadata: h.metadata, 186 } 187 188 call := conn.Call(ctx, HandshakeMethod, info) 189 var serverInfo PeerInfo 190 if err := call.Await(ctx, &serverInfo); err != nil { 191 event.Error(ctx, "performing handshake", err) 192 return 193 } 194 serverInfo.LocalID = localID 195 h.recordPeer(serverInfo) 196 197 go h.cleanupAtDisconnect(conn, localID) 198 } 199 200 func (h *Handshaker) nextID() int64 { 201 h.mu.Lock() 202 defer h.mu.Unlock() 203 204 h.prevID++ 205 return h.prevID 206 } 207 208 func (h *Handshaker) cleanupAtDisconnect(conn *jsonrpc2_v2.Connection, peerID int64) { 209 conn.Wait() 210 211 h.mu.Lock() 212 defer h.mu.Unlock() 213 delete(h.peers, peerID) 214 } 215 216 func (h *Handshaker) recordPeer(info PeerInfo) { 217 h.mu.Lock() 218 defer h.mu.Unlock() 219 if h.peers == nil { 220 h.peers = make(map[int64]PeerInfo) 221 } 222 h.peers[info.LocalID] = info 223 }