github.com/jd-ly/tools@v0.5.7/internal/lsp/lsprpc/lsprpc.go (about)

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package lsprpc implements a jsonrpc2.StreamServer that may be used to
     6  // serve the LSP on a jsonrpc2 channel.
     7  package lsprpc
     8  
     9  import (
    10  	"context"
    11  	"encoding/json"
    12  	"fmt"
    13  	"log"
    14  	"net"
    15  	"os"
    16  	"strconv"
    17  	"sync/atomic"
    18  	"time"
    19  
    20  	"github.com/jd-ly/tools/internal/event"
    21  	"github.com/jd-ly/tools/internal/gocommand"
    22  	"github.com/jd-ly/tools/internal/jsonrpc2"
    23  	"github.com/jd-ly/tools/internal/lsp"
    24  	"github.com/jd-ly/tools/internal/lsp/cache"
    25  	"github.com/jd-ly/tools/internal/lsp/debug"
    26  	"github.com/jd-ly/tools/internal/lsp/debug/tag"
    27  	"github.com/jd-ly/tools/internal/lsp/protocol"
    28  	errors "golang.org/x/xerrors"
    29  )
    30  
    31  // AutoNetwork is the pseudo network type used to signal that gopls should use
    32  // automatic discovery to resolve a remote address.
    33  const AutoNetwork = "auto"
    34  
    35  // Unique identifiers for client/server.
    36  var serverIndex int64
    37  
    38  // The StreamServer type is a jsonrpc2.StreamServer that handles incoming
    39  // streams as a new LSP session, using a shared cache.
    40  type StreamServer struct {
    41  	cache *cache.Cache
    42  	// daemon controls whether or not to log new connections.
    43  	daemon bool
    44  
    45  	// serverForTest may be set to a test fake for testing.
    46  	serverForTest protocol.Server
    47  }
    48  
    49  // NewStreamServer creates a StreamServer using the shared cache. If
    50  // withTelemetry is true, each session is instrumented with telemetry that
    51  // records RPC statistics.
    52  func NewStreamServer(cache *cache.Cache, daemon bool) *StreamServer {
    53  	return &StreamServer{cache: cache, daemon: daemon}
    54  }
    55  
    56  // ServeStream implements the jsonrpc2.StreamServer interface, by handling
    57  // incoming streams using a new lsp server.
    58  func (s *StreamServer) ServeStream(ctx context.Context, conn jsonrpc2.Conn) error {
    59  	client := protocol.ClientDispatcher(conn)
    60  	session := s.cache.NewSession(ctx)
    61  	server := s.serverForTest
    62  	if server == nil {
    63  		server = lsp.NewServer(session, client)
    64  	}
    65  	// Clients may or may not send a shutdown message. Make sure the server is
    66  	// shut down.
    67  	// TODO(rFindley): this shutdown should perhaps be on a disconnected context.
    68  	defer func() {
    69  		if err := server.Shutdown(ctx); err != nil {
    70  			event.Error(ctx, "error shutting down", err)
    71  		}
    72  	}()
    73  	executable, err := os.Executable()
    74  	if err != nil {
    75  		log.Printf("error getting gopls path: %v", err)
    76  		executable = ""
    77  	}
    78  	ctx = protocol.WithClient(ctx, client)
    79  	conn.Go(ctx,
    80  		protocol.Handlers(
    81  			handshaker(session, executable, s.daemon,
    82  				protocol.ServerHandler(server,
    83  					jsonrpc2.MethodNotFound))))
    84  	if s.daemon {
    85  		log.Printf("Session %s: connected", session.ID())
    86  		defer log.Printf("Session %s: exited", session.ID())
    87  	}
    88  	<-conn.Done()
    89  	return conn.Err()
    90  }
    91  
    92  // A Forwarder is a jsonrpc2.StreamServer that handles an LSP stream by
    93  // forwarding it to a remote. This is used when the gopls process started by
    94  // the editor is in the `-remote` mode, which means it finds and connects to a
    95  // separate gopls daemon. In these cases, we still want the forwarder gopls to
    96  // be instrumented with telemetry, and want to be able to in some cases hijack
    97  // the jsonrpc2 connection with the daemon.
    98  type Forwarder struct {
    99  	network, addr string
   100  
   101  	// goplsPath is the path to the current executing gopls binary.
   102  	goplsPath string
   103  
   104  	// configuration for the auto-started gopls remote.
   105  	remoteConfig remoteConfig
   106  }
   107  
   108  type remoteConfig struct {
   109  	debug         string
   110  	listenTimeout time.Duration
   111  	logfile       string
   112  }
   113  
   114  // A RemoteOption configures the behavior of the auto-started remote.
   115  type RemoteOption interface {
   116  	set(*remoteConfig)
   117  }
   118  
   119  // RemoteDebugAddress configures the address used by the auto-started Gopls daemon
   120  // for serving debug information.
   121  type RemoteDebugAddress string
   122  
   123  func (d RemoteDebugAddress) set(cfg *remoteConfig) {
   124  	cfg.debug = string(d)
   125  }
   126  
   127  // RemoteListenTimeout configures the amount of time the auto-started gopls
   128  // daemon will wait with no client connections before shutting down.
   129  type RemoteListenTimeout time.Duration
   130  
   131  func (d RemoteListenTimeout) set(cfg *remoteConfig) {
   132  	cfg.listenTimeout = time.Duration(d)
   133  }
   134  
   135  // RemoteLogfile configures the logfile location for the auto-started gopls
   136  // daemon.
   137  type RemoteLogfile string
   138  
   139  func (l RemoteLogfile) set(cfg *remoteConfig) {
   140  	cfg.logfile = string(l)
   141  }
   142  
   143  func defaultRemoteConfig() remoteConfig {
   144  	return remoteConfig{
   145  		listenTimeout: 1 * time.Minute,
   146  	}
   147  }
   148  
   149  // NewForwarder creates a new Forwarder, ready to forward connections to the
   150  // remote server specified by network and addr.
   151  func NewForwarder(network, addr string, opts ...RemoteOption) *Forwarder {
   152  	gp, err := os.Executable()
   153  	if err != nil {
   154  		log.Printf("error getting gopls path for forwarder: %v", err)
   155  		gp = ""
   156  	}
   157  
   158  	rcfg := defaultRemoteConfig()
   159  	for _, opt := range opts {
   160  		opt.set(&rcfg)
   161  	}
   162  
   163  	fwd := &Forwarder{
   164  		network:      network,
   165  		addr:         addr,
   166  		goplsPath:    gp,
   167  		remoteConfig: rcfg,
   168  	}
   169  	return fwd
   170  }
   171  
   172  // QueryServerState queries the server state of the current server.
   173  func QueryServerState(ctx context.Context, network, address string) (*ServerState, error) {
   174  	if network == AutoNetwork {
   175  		gp, err := os.Executable()
   176  		if err != nil {
   177  			return nil, errors.Errorf("getting gopls path: %w", err)
   178  		}
   179  		network, address = autoNetworkAddress(gp, address)
   180  	}
   181  	netConn, err := net.DialTimeout(network, address, 5*time.Second)
   182  	if err != nil {
   183  		return nil, errors.Errorf("dialing remote: %w", err)
   184  	}
   185  	serverConn := jsonrpc2.NewConn(jsonrpc2.NewHeaderStream(netConn))
   186  	serverConn.Go(ctx, jsonrpc2.MethodNotFound)
   187  	var state ServerState
   188  	if err := protocol.Call(ctx, serverConn, sessionsMethod, nil, &state); err != nil {
   189  		return nil, errors.Errorf("querying server state: %w", err)
   190  	}
   191  	return &state, nil
   192  }
   193  
   194  // ServeStream dials the forwarder remote and binds the remote to serve the LSP
   195  // on the incoming stream.
   196  func (f *Forwarder) ServeStream(ctx context.Context, clientConn jsonrpc2.Conn) error {
   197  	client := protocol.ClientDispatcher(clientConn)
   198  
   199  	netConn, err := f.connectToRemote(ctx)
   200  	if err != nil {
   201  		return errors.Errorf("forwarder: connecting to remote: %w", err)
   202  	}
   203  	serverConn := jsonrpc2.NewConn(jsonrpc2.NewHeaderStream(netConn))
   204  	server := protocol.ServerDispatcher(serverConn)
   205  
   206  	// Forward between connections.
   207  	serverConn.Go(ctx,
   208  		protocol.Handlers(
   209  			protocol.ClientHandler(client,
   210  				jsonrpc2.MethodNotFound)))
   211  	// Don't run the clientConn yet, so that we can complete the handshake before
   212  	// processing any client messages.
   213  
   214  	// Do a handshake with the server instance to exchange debug information.
   215  	index := atomic.AddInt64(&serverIndex, 1)
   216  	serverID := strconv.FormatInt(index, 10)
   217  	var (
   218  		hreq = handshakeRequest{
   219  			ServerID:  serverID,
   220  			GoplsPath: f.goplsPath,
   221  		}
   222  		hresp handshakeResponse
   223  	)
   224  	if di := debug.GetInstance(ctx); di != nil {
   225  		hreq.Logfile = di.Logfile
   226  		hreq.DebugAddr = di.ListenedDebugAddress
   227  	}
   228  	if err := protocol.Call(ctx, serverConn, handshakeMethod, hreq, &hresp); err != nil {
   229  		// TODO(rfindley): at some point in the future we should return an error
   230  		// here.  Handshakes have become functional in nature.
   231  		event.Error(ctx, "forwarder: gopls handshake failed", err)
   232  	}
   233  	if hresp.GoplsPath != f.goplsPath {
   234  		event.Error(ctx, "", fmt.Errorf("forwarder: gopls path mismatch: forwarder is %q, remote is %q", f.goplsPath, hresp.GoplsPath))
   235  	}
   236  	event.Log(ctx, "New server",
   237  		tag.NewServer.Of(serverID),
   238  		tag.Logfile.Of(hresp.Logfile),
   239  		tag.DebugAddress.Of(hresp.DebugAddr),
   240  		tag.GoplsPath.Of(hresp.GoplsPath),
   241  		tag.ClientID.Of(hresp.SessionID),
   242  	)
   243  	clientConn.Go(ctx,
   244  		protocol.Handlers(
   245  			forwarderHandler(
   246  				protocol.ServerHandler(server,
   247  					jsonrpc2.MethodNotFound))))
   248  
   249  	select {
   250  	case <-serverConn.Done():
   251  		clientConn.Close()
   252  	case <-clientConn.Done():
   253  		serverConn.Close()
   254  	}
   255  
   256  	err = nil
   257  	if serverConn.Err() != nil {
   258  		err = errors.Errorf("remote disconnected: %v", err)
   259  	} else if clientConn.Err() != nil {
   260  		err = errors.Errorf("client disconnected: %v", err)
   261  	}
   262  	event.Log(ctx, fmt.Sprintf("forwarder: exited with error: %v", err))
   263  	return err
   264  }
   265  
   266  func (f *Forwarder) connectToRemote(ctx context.Context) (net.Conn, error) {
   267  	return connectToRemote(ctx, f.network, f.addr, f.goplsPath, f.remoteConfig)
   268  }
   269  
   270  func ConnectToRemote(ctx context.Context, network, addr string, opts ...RemoteOption) (net.Conn, error) {
   271  	rcfg := defaultRemoteConfig()
   272  	for _, opt := range opts {
   273  		opt.set(&rcfg)
   274  	}
   275  	// This is not strictly necessary, as it won't be used if not connecting to
   276  	// the 'auto' remote.
   277  	goplsPath, err := os.Executable()
   278  	if err != nil {
   279  		return nil, fmt.Errorf("unable to resolve gopls path: %v", err)
   280  	}
   281  	return connectToRemote(ctx, network, addr, goplsPath, rcfg)
   282  }
   283  
   284  func connectToRemote(ctx context.Context, inNetwork, inAddr, goplsPath string, rcfg remoteConfig) (net.Conn, error) {
   285  	var (
   286  		netConn          net.Conn
   287  		err              error
   288  		network, address = inNetwork, inAddr
   289  	)
   290  	if inNetwork == AutoNetwork {
   291  		// f.network is overloaded to support a concept of 'automatic' addresses,
   292  		// which signals that the gopls remote address should be automatically
   293  		// derived.
   294  		// So we need to resolve a real network and address here.
   295  		network, address = autoNetworkAddress(goplsPath, inAddr)
   296  	}
   297  	// Attempt to verify that we own the remote. This is imperfect, but if we can
   298  	// determine that the remote is owned by a different user, we should fail.
   299  	ok, err := verifyRemoteOwnership(network, address)
   300  	if err != nil {
   301  		// If the ownership check itself failed, we fail open but log an error to
   302  		// the user.
   303  		event.Error(ctx, "unable to check daemon socket owner, failing open", err)
   304  	} else if !ok {
   305  		// We successfully checked that the socket is not owned by us, we fail
   306  		// closed.
   307  		return nil, fmt.Errorf("socket %q is owned by a different user", address)
   308  	}
   309  	const dialTimeout = 1 * time.Second
   310  	// Try dialing our remote once, in case it is already running.
   311  	netConn, err = net.DialTimeout(network, address, dialTimeout)
   312  	if err == nil {
   313  		return netConn, nil
   314  	}
   315  	// If our remote is on the 'auto' network, start it if it doesn't exist.
   316  	if inNetwork == AutoNetwork {
   317  		if goplsPath == "" {
   318  			return nil, fmt.Errorf("cannot auto-start remote: gopls path is unknown")
   319  		}
   320  		if network == "unix" {
   321  			// Sometimes the socketfile isn't properly cleaned up when gopls shuts
   322  			// down. Since we have already tried and failed to dial this address, it
   323  			// should *usually* be safe to remove the socket before binding to the
   324  			// address.
   325  			// TODO(rfindley): there is probably a race here if multiple gopls
   326  			// instances are simultaneously starting up.
   327  			if _, err := os.Stat(address); err == nil {
   328  				if err := os.Remove(address); err != nil {
   329  					return nil, errors.Errorf("removing remote socket file: %w", err)
   330  				}
   331  			}
   332  		}
   333  		args := []string{"serve",
   334  			"-listen", fmt.Sprintf(`%s;%s`, network, address),
   335  			"-listen.timeout", rcfg.listenTimeout.String(),
   336  		}
   337  		if rcfg.logfile != "" {
   338  			args = append(args, "-logfile", rcfg.logfile)
   339  		}
   340  		if rcfg.debug != "" {
   341  			args = append(args, "-debug", rcfg.debug)
   342  		}
   343  		if err := startRemote(goplsPath, args...); err != nil {
   344  			return nil, errors.Errorf("startRemote(%q, %v): %w", goplsPath, args, err)
   345  		}
   346  	}
   347  
   348  	const retries = 5
   349  	// It can take some time for the newly started server to bind to our address,
   350  	// so we retry for a bit.
   351  	for retry := 0; retry < retries; retry++ {
   352  		startDial := time.Now()
   353  		netConn, err = net.DialTimeout(network, address, dialTimeout)
   354  		if err == nil {
   355  			return netConn, nil
   356  		}
   357  		event.Log(ctx, fmt.Sprintf("failed attempt #%d to connect to remote: %v\n", retry+2, err))
   358  		// In case our failure was a fast-failure, ensure we wait at least
   359  		// f.dialTimeout before trying again.
   360  		if retry != retries-1 {
   361  			time.Sleep(dialTimeout - time.Since(startDial))
   362  		}
   363  	}
   364  	return nil, errors.Errorf("dialing remote: %w", err)
   365  }
   366  
   367  // forwarderHandler intercepts 'exit' messages to prevent the shared gopls
   368  // instance from exiting. In the future it may also intercept 'shutdown' to
   369  // provide more graceful shutdown of the client connection.
   370  func forwarderHandler(handler jsonrpc2.Handler) jsonrpc2.Handler {
   371  	return func(ctx context.Context, reply jsonrpc2.Replier, r jsonrpc2.Request) error {
   372  		// The gopls workspace environment defaults to the process environment in
   373  		// which gopls daemon was started. To avoid discrepancies in Go environment
   374  		// between the editor and daemon, inject any unset variables in `go env`
   375  		// into the options sent by initialize.
   376  		//
   377  		// See also golang.org/issue/37830.
   378  		if r.Method() == "initialize" {
   379  			if newr, err := addGoEnvToInitializeRequest(ctx, r); err == nil {
   380  				r = newr
   381  			} else {
   382  				log.Printf("unable to add local env to initialize request: %v", err)
   383  			}
   384  		}
   385  		return handler(ctx, reply, r)
   386  	}
   387  }
   388  
   389  // addGoEnvToInitializeRequest builds a new initialize request in which we set
   390  // any environment variables output by `go env` and not already present in the
   391  // request.
   392  //
   393  // It returns an error if r is not an initialize requst, or is otherwise
   394  // malformed.
   395  func addGoEnvToInitializeRequest(ctx context.Context, r jsonrpc2.Request) (jsonrpc2.Request, error) {
   396  	var params protocol.ParamInitialize
   397  	if err := json.Unmarshal(r.Params(), &params); err != nil {
   398  		return nil, err
   399  	}
   400  	var opts map[string]interface{}
   401  	switch v := params.InitializationOptions.(type) {
   402  	case nil:
   403  		opts = make(map[string]interface{})
   404  	case map[string]interface{}:
   405  		opts = v
   406  	default:
   407  		return nil, fmt.Errorf("unexpected type for InitializationOptions: %T", v)
   408  	}
   409  	envOpt, ok := opts["env"]
   410  	if !ok {
   411  		envOpt = make(map[string]interface{})
   412  	}
   413  	env, ok := envOpt.(map[string]interface{})
   414  	if !ok {
   415  		return nil, fmt.Errorf(`env option is %T, expected a map`, envOpt)
   416  	}
   417  	goenv, err := getGoEnv(ctx, env)
   418  	if err != nil {
   419  		return nil, err
   420  	}
   421  	for govar, value := range goenv {
   422  		env[govar] = value
   423  	}
   424  	opts["env"] = env
   425  	params.InitializationOptions = opts
   426  	call, ok := r.(*jsonrpc2.Call)
   427  	if !ok {
   428  		return nil, fmt.Errorf("%T is not a *jsonrpc2.Call", r)
   429  	}
   430  	return jsonrpc2.NewCall(call.ID(), "initialize", params)
   431  }
   432  
   433  func getGoEnv(ctx context.Context, env map[string]interface{}) (map[string]string, error) {
   434  	var runEnv []string
   435  	for k, v := range env {
   436  		runEnv = append(runEnv, fmt.Sprintf("%s=%s", k, v))
   437  	}
   438  	runner := gocommand.Runner{}
   439  	output, err := runner.Run(ctx, gocommand.Invocation{
   440  		Verb: "env",
   441  		Args: []string{"-json"},
   442  		Env:  runEnv,
   443  	})
   444  	if err != nil {
   445  		return nil, err
   446  	}
   447  	envmap := make(map[string]string)
   448  	if err := json.Unmarshal(output.Bytes(), &envmap); err != nil {
   449  		return nil, err
   450  	}
   451  	return envmap, nil
   452  }
   453  
   454  // A handshakeRequest identifies a client to the LSP server.
   455  type handshakeRequest struct {
   456  	// ServerID is the ID of the server on the client. This should usually be 0.
   457  	ServerID string `json:"serverID"`
   458  	// Logfile is the location of the clients log file.
   459  	Logfile string `json:"logfile"`
   460  	// DebugAddr is the client debug address.
   461  	DebugAddr string `json:"debugAddr"`
   462  	// GoplsPath is the path to the Gopls binary running the current client
   463  	// process.
   464  	GoplsPath string `json:"goplsPath"`
   465  }
   466  
   467  // A handshakeResponse is returned by the LSP server to tell the LSP client
   468  // information about its session.
   469  type handshakeResponse struct {
   470  	// SessionID is the server session associated with the client.
   471  	SessionID string `json:"sessionID"`
   472  	// Logfile is the location of the server logs.
   473  	Logfile string `json:"logfile"`
   474  	// DebugAddr is the server debug address.
   475  	DebugAddr string `json:"debugAddr"`
   476  	// GoplsPath is the path to the Gopls binary running the current server
   477  	// process.
   478  	GoplsPath string `json:"goplsPath"`
   479  }
   480  
   481  // ClientSession identifies a current client LSP session on the server. Note
   482  // that it looks similar to handshakeResposne, but in fact 'Logfile' and
   483  // 'DebugAddr' now refer to the client.
   484  type ClientSession struct {
   485  	SessionID string `json:"sessionID"`
   486  	Logfile   string `json:"logfile"`
   487  	DebugAddr string `json:"debugAddr"`
   488  }
   489  
   490  // ServerState holds information about the gopls daemon process, including its
   491  // debug information and debug information of all of its current connected
   492  // clients.
   493  type ServerState struct {
   494  	Logfile         string          `json:"logfile"`
   495  	DebugAddr       string          `json:"debugAddr"`
   496  	GoplsPath       string          `json:"goplsPath"`
   497  	CurrentClientID string          `json:"currentClientID"`
   498  	Clients         []ClientSession `json:"clients"`
   499  }
   500  
   501  const (
   502  	handshakeMethod = "gopls/handshake"
   503  	sessionsMethod  = "gopls/sessions"
   504  )
   505  
   506  func handshaker(session *cache.Session, goplsPath string, logHandshakes bool, handler jsonrpc2.Handler) jsonrpc2.Handler {
   507  	return func(ctx context.Context, reply jsonrpc2.Replier, r jsonrpc2.Request) error {
   508  		switch r.Method() {
   509  		case handshakeMethod:
   510  			// We log.Printf in this handler, rather than event.Log when we want logs
   511  			// to go to the daemon log rather than being reflected back to the
   512  			// client.
   513  			var req handshakeRequest
   514  			if err := json.Unmarshal(r.Params(), &req); err != nil {
   515  				if logHandshakes {
   516  					log.Printf("Error processing handshake for session %s: %v", session.ID(), err)
   517  				}
   518  				sendError(ctx, reply, err)
   519  				return nil
   520  			}
   521  			if logHandshakes {
   522  				log.Printf("Session %s: got handshake. Logfile: %q, Debug addr: %q", session.ID(), req.Logfile, req.DebugAddr)
   523  			}
   524  			event.Log(ctx, "Handshake session update",
   525  				cache.KeyUpdateSession.Of(session),
   526  				tag.DebugAddress.Of(req.DebugAddr),
   527  				tag.Logfile.Of(req.Logfile),
   528  				tag.ServerID.Of(req.ServerID),
   529  				tag.GoplsPath.Of(req.GoplsPath),
   530  			)
   531  			resp := handshakeResponse{
   532  				SessionID: session.ID(),
   533  				GoplsPath: goplsPath,
   534  			}
   535  			if di := debug.GetInstance(ctx); di != nil {
   536  				resp.Logfile = di.Logfile
   537  				resp.DebugAddr = di.ListenedDebugAddress
   538  			}
   539  
   540  			return reply(ctx, resp, nil)
   541  		case sessionsMethod:
   542  			resp := ServerState{
   543  				GoplsPath:       goplsPath,
   544  				CurrentClientID: session.ID(),
   545  			}
   546  			if di := debug.GetInstance(ctx); di != nil {
   547  				resp.Logfile = di.Logfile
   548  				resp.DebugAddr = di.ListenedDebugAddress
   549  				for _, c := range di.State.Clients() {
   550  					resp.Clients = append(resp.Clients, ClientSession{
   551  						SessionID: c.Session.ID(),
   552  						Logfile:   c.Logfile,
   553  						DebugAddr: c.DebugAddress,
   554  					})
   555  				}
   556  			}
   557  			return reply(ctx, resp, nil)
   558  		}
   559  		return handler(ctx, reply, r)
   560  	}
   561  }
   562  
   563  func sendError(ctx context.Context, reply jsonrpc2.Replier, err error) {
   564  	err = errors.Errorf("%v: %w", err, jsonrpc2.ErrParse)
   565  	if err := reply(ctx, nil, err); err != nil {
   566  		event.Error(ctx, "", err)
   567  	}
   568  }