github.com/april1989/origin-go-tools@v0.0.32/internal/lsp/lsprpc/lsprpc.go (about)

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package lsprpc implements a jsonrpc2.StreamServer that may be used to
     6  // serve the LSP on a jsonrpc2 channel.
     7  package lsprpc
     8  
     9  import (
    10  	"context"
    11  	"encoding/json"
    12  	"fmt"
    13  	"log"
    14  	"net"
    15  	"os"
    16  	"strconv"
    17  	"sync/atomic"
    18  	"time"
    19  
    20  	"github.com/april1989/origin-go-tools/internal/event"
    21  	"github.com/april1989/origin-go-tools/internal/gocommand"
    22  	"github.com/april1989/origin-go-tools/internal/jsonrpc2"
    23  	"github.com/april1989/origin-go-tools/internal/lsp"
    24  	"github.com/april1989/origin-go-tools/internal/lsp/cache"
    25  	"github.com/april1989/origin-go-tools/internal/lsp/debug"
    26  	"github.com/april1989/origin-go-tools/internal/lsp/debug/tag"
    27  	"github.com/april1989/origin-go-tools/internal/lsp/protocol"
    28  	errors "golang.org/x/xerrors"
    29  )
    30  
    31  // AutoNetwork is the pseudo network type used to signal that gopls should
    32  // derive the real network and address of the remote automatically.
    33  const AutoNetwork = "auto"
    34  
    35  // serverIndex is incremented atomically to derive a unique server ID for each forwarded stream.
    36  var serverIndex int64
    37  
    38  // The StreamServer type is a jsonrpc2.StreamServer that handles incoming
    39  // streams as a new LSP session, using a shared cache.
    40  type StreamServer struct {
    41  	cache *cache.Cache // shared cache on which each stream's session is created
    42  	// logConnections controls whether or not to log new connections.
    43  	logConnections bool
    44  
    45  	// serverForTest may be set to a test fake; when nil, a new lsp server is created per stream.
    46  	serverForTest protocol.Server
    47  }
    48  
    49  // NewStreamServer creates a StreamServer using the shared cache. If
    50  // withTelemetry is true, each session is instrumented with telemetry that
    51  // records RPC statistics.
    52  func NewStreamServer(cache *cache.Cache, logConnections bool) *StreamServer {
    53  	return &StreamServer{cache: cache, logConnections: logConnections}
    54  }
    55  
    56  // ServeStream implements the jsonrpc2.StreamServer interface, by handling
    57  // incoming streams using a new lsp server.
    58  func (s *StreamServer) ServeStream(ctx context.Context, conn jsonrpc2.Conn) error {
    59  	client := protocol.ClientDispatcher(conn)
    60  	session := s.cache.NewSession(ctx)
    61  	server := s.serverForTest
    62  	if server == nil {
    63  		server = lsp.NewServer(session, client)
    64  	}
    65  	// Clients may or may not send a shutdown message. Make sure the server is
    66  	// shut down.
    67  	// TODO(rFindley): this shutdown should perhaps be on a disconnected context.
    68  	defer func() {
    69  		if err := server.Shutdown(ctx); err != nil {
    70  			event.Error(ctx, "error shutting down", err)
    71  		}
    72  	}()
    73  	executable, err := os.Executable()
    74  	if err != nil {
    75  		log.Printf("error getting gopls path: %v", err)
    76  		executable = ""
    77  	}
    78  	ctx = protocol.WithClient(ctx, client)
    79  	conn.Go(ctx,
    80  		protocol.Handlers(
    81  			handshaker(session, executable, s.logConnections,
    82  				protocol.ServerHandler(server,
    83  					jsonrpc2.MethodNotFound))))
    84  	if s.logConnections {
    85  		log.Printf("Session %s: connected", session.ID())
    86  		defer log.Printf("Session %s: exited", session.ID())
    87  	}
    88  	<-conn.Done()
    89  	return conn.Err()
    90  }
    91  
    92  // A Forwarder is a jsonrpc2.StreamServer that handles an LSP stream by
    93  // forwarding it to a remote. This is used when the gopls process started by
    94  // the editor is in the `-remote` mode, which means it finds and connects to a
    95  // separate gopls daemon. In these cases, we still want the forwarder gopls to
    96  // be instrumented with telemetry, and want to be able to in some cases hijack
    97  // the jsonrpc2 connection with the daemon.
    98  type Forwarder struct {
    99  	network, addr string // where to reach the remote; network may be AutoNetwork
   100  
   101  	// goplsPath is the path to the currently executing gopls binary, or "" if unknown.
   102  	goplsPath string
   103  
   104  	// remoteConfig is the configuration for the auto-started gopls remote.
   105  	remoteConfig remoteConfig
   106  }
   107  
   108  type remoteConfig struct { // settings handed to the auto-started remote as command-line flags
   109  	debug         string // passed as -debug when non-empty
   110  	listenTimeout time.Duration // passed as -listen.timeout
   111  	logfile       string // passed as -logfile when non-empty
   112  }
   113  
   114  // A RemoteOption adjusts the configuration used for the auto-started remote.
   115  type RemoteOption interface {
   116  	set(*remoteConfig) // applies the option to the given configuration
   117  }
   118  
   119  // RemoteDebugAddress configures the address used by the auto-started gopls daemon
   120  // for serving debug information. It is passed to the daemon as its -debug flag.
   121  type RemoteDebugAddress string
   122  
   123  func (d RemoteDebugAddress) set(cfg *remoteConfig) {
   124  	cfg.debug = string(d)
   125  }
   126  
   127  // RemoteListenTimeout configures the amount of time the auto-started gopls
   128  // daemon will wait with no client connections before shutting down (-listen.timeout).
   129  type RemoteListenTimeout time.Duration
   130  
   131  func (d RemoteListenTimeout) set(cfg *remoteConfig) {
   132  	cfg.listenTimeout = time.Duration(d)
   133  }
   134  
   135  // RemoteLogfile configures the logfile location for the auto-started gopls
   136  // daemon. It is passed to the daemon as its -logfile flag.
   137  type RemoteLogfile string
   138  
   139  func (l RemoteLogfile) set(cfg *remoteConfig) {
   140  	cfg.logfile = string(l)
   141  }
   142  
   143  func defaultRemoteConfig() remoteConfig {
   144  	return remoteConfig{
   145  		listenTimeout: 1 * time.Minute,
   146  	}
   147  }
   148  
   149  // NewForwarder creates a new Forwarder, ready to forward connections to the
   150  // remote server specified by network and addr.
   151  func NewForwarder(network, addr string, opts ...RemoteOption) *Forwarder {
   152  	gp, err := os.Executable()
   153  	if err != nil {
   154  		log.Printf("error getting gopls path for forwarder: %v", err)
   155  		gp = ""
   156  	}
   157  
   158  	rcfg := defaultRemoteConfig()
   159  	for _, opt := range opts {
   160  		opt.set(&rcfg)
   161  	}
   162  
   163  	fwd := &Forwarder{
   164  		network:      network,
   165  		addr:         addr,
   166  		goplsPath:    gp,
   167  		remoteConfig: rcfg,
   168  	}
   169  	return fwd
   170  }
   171  
   172  // QueryServerState queries the server state of the current server.
   173  func QueryServerState(ctx context.Context, network, address string) (*ServerState, error) {
   174  	if network == AutoNetwork {
   175  		gp, err := os.Executable()
   176  		if err != nil {
   177  			return nil, errors.Errorf("getting gopls path: %w", err)
   178  		}
   179  		network, address = autoNetworkAddress(gp, address)
   180  	}
   181  	netConn, err := net.DialTimeout(network, address, 5*time.Second)
   182  	if err != nil {
   183  		return nil, errors.Errorf("dialing remote: %w", err)
   184  	}
   185  	serverConn := jsonrpc2.NewConn(jsonrpc2.NewHeaderStream(netConn))
   186  	serverConn.Go(ctx, jsonrpc2.MethodNotFound)
   187  	var state ServerState
   188  	if err := protocol.Call(ctx, serverConn, sessionsMethod, nil, &state); err != nil {
   189  		return nil, errors.Errorf("querying server state: %w", err)
   190  	}
   191  	return &state, nil
   192  }
   193  
   194  // ServeStream dials the forwarder remote and binds the remote to serve the LSP
   195  // on the incoming stream.
   196  func (f *Forwarder) ServeStream(ctx context.Context, clientConn jsonrpc2.Conn) error {
   197  	client := protocol.ClientDispatcher(clientConn)
   198  
   199  	netConn, err := f.connectToRemote(ctx)
   200  	if err != nil {
   201  		return errors.Errorf("forwarder: connecting to remote: %w", err)
   202  	}
   203  	serverConn := jsonrpc2.NewConn(jsonrpc2.NewHeaderStream(netConn))
   204  	server := protocol.ServerDispatcher(serverConn)
   205  
   206  	// Forward between connections.
   207  	serverConn.Go(ctx,
   208  		protocol.Handlers(
   209  			protocol.ClientHandler(client,
   210  				jsonrpc2.MethodNotFound)))
   211  	// Don't run the clientConn yet, so that we can complete the handshake before
   212  	// processing any client messages.
   213  
   214  	// Do a handshake with the server instance to exchange debug information.
   215  	index := atomic.AddInt64(&serverIndex, 1)
   216  	serverID := strconv.FormatInt(index, 10)
   217  	var (
   218  		hreq = handshakeRequest{
   219  			ServerID:  serverID,
   220  			GoplsPath: f.goplsPath,
   221  		}
   222  		hresp handshakeResponse
   223  	)
   224  	if di := debug.GetInstance(ctx); di != nil {
   225  		hreq.Logfile = di.Logfile
   226  		hreq.DebugAddr = di.ListenedDebugAddress
   227  	}
   228  	if err := protocol.Call(ctx, serverConn, handshakeMethod, hreq, &hresp); err != nil {
   229  		event.Error(ctx, "forwarder: gopls handshake failed", err)
   230  	}
   231  	if hresp.GoplsPath != f.goplsPath {
   232  		event.Error(ctx, "", fmt.Errorf("forwarder: gopls path mismatch: forwarder is %q, remote is %q", f.goplsPath, hresp.GoplsPath))
   233  	}
   234  	event.Log(ctx, "New server",
   235  		tag.NewServer.Of(serverID),
   236  		tag.Logfile.Of(hresp.Logfile),
   237  		tag.DebugAddress.Of(hresp.DebugAddr),
   238  		tag.GoplsPath.Of(hresp.GoplsPath),
   239  		tag.ClientID.Of(hresp.SessionID),
   240  	)
   241  	clientConn.Go(ctx,
   242  		protocol.Handlers(
   243  			forwarderHandler(
   244  				protocol.ServerHandler(server,
   245  					jsonrpc2.MethodNotFound))))
   246  
   247  	select {
   248  	case <-serverConn.Done():
   249  		clientConn.Close()
   250  	case <-clientConn.Done():
   251  		serverConn.Close()
   252  	}
   253  
   254  	err = nil
   255  	if serverConn.Err() != nil {
   256  		err = errors.Errorf("remote disconnected: %v", err)
   257  	} else if clientConn.Err() != nil {
   258  		err = errors.Errorf("client disconnected: %v", err)
   259  	}
   260  	event.Log(ctx, fmt.Sprintf("forwarder: exited with error: %v", err))
   261  	return err
   262  }
   263  
   264  func (f *Forwarder) connectToRemote(ctx context.Context) (net.Conn, error) {
   265  	return connectToRemote(ctx, f.network, f.addr, f.goplsPath, f.remoteConfig)
   266  }
   267  
   268  func ConnectToRemote(ctx context.Context, network, addr string, opts ...RemoteOption) (net.Conn, error) {
   269  	rcfg := defaultRemoteConfig()
   270  	for _, opt := range opts {
   271  		opt.set(&rcfg)
   272  	}
   273  	// This is not strictly necessary, as it won't be used if not connecting to
   274  	// the 'auto' remote.
   275  	goplsPath, err := os.Executable()
   276  	if err != nil {
   277  		return nil, fmt.Errorf("unable to resolve gopls path: %v", err)
   278  	}
   279  	return connectToRemote(ctx, network, addr, goplsPath, rcfg)
   280  }
   281  
   282  func connectToRemote(ctx context.Context, inNetwork, inAddr, goplsPath string, rcfg remoteConfig) (net.Conn, error) {
   283  	var (
   284  		netConn          net.Conn
   285  		err              error
   286  		network, address = inNetwork, inAddr
   287  	)
   288  	if inNetwork == AutoNetwork {
   289  		// f.network is overloaded to support a concept of 'automatic' addresses,
   290  		// which signals that the gopls remote address should be automatically
   291  		// derived.
   292  		// So we need to resolve a real network and address here.
   293  		network, address = autoNetworkAddress(goplsPath, inAddr)
   294  	}
   295  	// Attempt to verify that we own the remote. This is imperfect, but if we can
   296  	// determine that the remote is owned by a different user, we should fail.
   297  	ok, err := verifyRemoteOwnership(network, address)
   298  	if err != nil {
   299  		// If the ownership check itself failed, we fail open but log an error to
   300  		// the user.
   301  		event.Error(ctx, "unable to check daemon socket owner, failing open", err)
   302  	} else if !ok {
   303  		// We succesfully checked that the socket is not owned by us, we fail
   304  		// closed.
   305  		return nil, fmt.Errorf("socket %q is owned by a different user", address)
   306  	}
   307  	const dialTimeout = 1 * time.Second
   308  	// Try dialing our remote once, in case it is already running.
   309  	netConn, err = net.DialTimeout(network, address, dialTimeout)
   310  	if err == nil {
   311  		return netConn, nil
   312  	}
   313  	// If our remote is on the 'auto' network, start it if it doesn't exist.
   314  	if inNetwork == AutoNetwork {
   315  		if goplsPath == "" {
   316  			return nil, fmt.Errorf("cannot auto-start remote: gopls path is unknown")
   317  		}
   318  		if network == "unix" {
   319  			// Sometimes the socketfile isn't properly cleaned up when gopls shuts
   320  			// down. Since we have already tried and failed to dial this address, it
   321  			// should *usually* be safe to remove the socket before binding to the
   322  			// address.
   323  			// TODO(rfindley): there is probably a race here if multiple gopls
   324  			// instances are simultaneously starting up.
   325  			if _, err := os.Stat(address); err == nil {
   326  				if err := os.Remove(address); err != nil {
   327  					return nil, errors.Errorf("removing remote socket file: %w", err)
   328  				}
   329  			}
   330  		}
   331  		args := []string{"serve",
   332  			"-listen", fmt.Sprintf(`%s;%s`, network, address),
   333  			"-listen.timeout", rcfg.listenTimeout.String(),
   334  		}
   335  		if rcfg.logfile != "" {
   336  			args = append(args, "-logfile", rcfg.logfile)
   337  		}
   338  		if rcfg.debug != "" {
   339  			args = append(args, "-debug", rcfg.debug)
   340  		}
   341  		if err := startRemote(goplsPath, args...); err != nil {
   342  			return nil, errors.Errorf("startRemote(%q, %v): %w", goplsPath, args, err)
   343  		}
   344  	}
   345  
   346  	const retries = 5
   347  	// It can take some time for the newly started server to bind to our address,
   348  	// so we retry for a bit.
   349  	for retry := 0; retry < retries; retry++ {
   350  		startDial := time.Now()
   351  		netConn, err = net.DialTimeout(network, address, dialTimeout)
   352  		if err == nil {
   353  			return netConn, nil
   354  		}
   355  		event.Log(ctx, fmt.Sprintf("failed attempt #%d to connect to remote: %v\n", retry+2, err))
   356  		// In case our failure was a fast-failure, ensure we wait at least
   357  		// f.dialTimeout before trying again.
   358  		if retry != retries-1 {
   359  			time.Sleep(dialTimeout - time.Since(startDial))
   360  		}
   361  	}
   362  	return nil, errors.Errorf("dialing remote: %w", err)
   363  }
   364  
   365  // forwarderHandler intercepts 'exit' messages to prevent the shared gopls
   366  // instance from exiting. In the future it may also intercept 'shutdown' to
   367  // provide more graceful shutdown of the client connection.
   368  func forwarderHandler(handler jsonrpc2.Handler) jsonrpc2.Handler {
   369  	return func(ctx context.Context, reply jsonrpc2.Replier, r jsonrpc2.Request) error {
   370  		// The gopls workspace environment defaults to the process environment in
   371  		// which gopls daemon was started. To avoid discrepancies in Go environment
   372  		// between the editor and daemon, inject any unset variables in `go env`
   373  		// into the options sent by initialize.
   374  		//
   375  		// See also golang.org/issue/37830.
   376  		if r.Method() == "initialize" {
   377  			if newr, err := addGoEnvToInitializeRequest(ctx, r); err == nil {
   378  				r = newr
   379  			} else {
   380  				log.Printf("unable to add local env to initialize request: %v", err)
   381  			}
   382  		}
   383  		return handler(ctx, reply, r)
   384  	}
   385  }
   386  
   387  // addGoEnvToInitializeRequest builds a new initialize request in which we set
   388  // any environment variables output by `go env` and not already present in the
   389  // request.
   390  //
   391  // It returns an error if r is not an initialize requst, or is otherwise
   392  // malformed.
   393  func addGoEnvToInitializeRequest(ctx context.Context, r jsonrpc2.Request) (jsonrpc2.Request, error) {
   394  	var params protocol.ParamInitialize
   395  	if err := json.Unmarshal(r.Params(), &params); err != nil {
   396  		return nil, err
   397  	}
   398  	var opts map[string]interface{}
   399  	switch v := params.InitializationOptions.(type) {
   400  	case nil:
   401  		opts = make(map[string]interface{})
   402  	case map[string]interface{}:
   403  		opts = v
   404  	default:
   405  		return nil, fmt.Errorf("unexpected type for InitializationOptions: %T", v)
   406  	}
   407  	envOpt, ok := opts["env"]
   408  	if !ok {
   409  		envOpt = make(map[string]interface{})
   410  	}
   411  	env, ok := envOpt.(map[string]interface{})
   412  	if !ok {
   413  		return nil, fmt.Errorf(`env option is %T, expected a map`, envOpt)
   414  	}
   415  	goenv, err := getGoEnv(ctx, env)
   416  	if err != nil {
   417  		return nil, err
   418  	}
   419  	for govar, value := range goenv {
   420  		env[govar] = value
   421  	}
   422  	opts["env"] = env
   423  	params.InitializationOptions = opts
   424  	call, ok := r.(*jsonrpc2.Call)
   425  	if !ok {
   426  		return nil, fmt.Errorf("%T is not a *jsonrpc2.Call", r)
   427  	}
   428  	return jsonrpc2.NewCall(call.ID(), "initialize", params)
   429  }
   430  
   431  func getGoEnv(ctx context.Context, env map[string]interface{}) (map[string]string, error) {
   432  	var runEnv []string
   433  	for k, v := range env {
   434  		runEnv = append(runEnv, fmt.Sprintf("%s=%s", k, v))
   435  	}
   436  	runner := gocommand.Runner{}
   437  	output, err := runner.Run(ctx, gocommand.Invocation{
   438  		Verb: "env",
   439  		Args: []string{"-json"},
   440  		Env:  runEnv,
   441  	})
   442  	if err != nil {
   443  		return nil, err
   444  	}
   445  	envmap := make(map[string]string)
   446  	if err := json.Unmarshal(output.Bytes(), &envmap); err != nil {
   447  		return nil, err
   448  	}
   449  	return envmap, nil
   450  }
   451  
   452  // A handshakeRequest identifies a client to the LSP server.
   453  type handshakeRequest struct {
   454  	// ServerID is the ID of the server on the client; the Forwarder fills it from a process-wide counter.
   455  	ServerID string `json:"serverID"`
   456  	// Logfile is the location of the client's log file.
   457  	Logfile string `json:"logfile"`
   458  	// DebugAddr is the client debug address.
   459  	DebugAddr string `json:"debugAddr"`
   460  	// GoplsPath is the path to the gopls binary running the current client
   461  	// process.
   462  	GoplsPath string `json:"goplsPath"`
   463  }
   464  
   465  // A handshakeResponse is returned by the LSP server to tell the LSP client
   466  // information about its session.
   467  type handshakeResponse struct {
   468  	// SessionID is the ID of the server session associated with the client.
   469  	SessionID string `json:"sessionID"`
   470  	// Logfile is the location of the server logs.
   471  	Logfile string `json:"logfile"`
   472  	// DebugAddr is the server debug address.
   473  	DebugAddr string `json:"debugAddr"`
   474  	// GoplsPath is the path to the gopls binary running the current server
   475  	// process.
   476  	GoplsPath string `json:"goplsPath"`
   477  }
   478  
   479  // ClientSession identifies a current client LSP session on the server. Note
   480  // that it looks similar to handshakeResponse, but in fact 'Logfile' and
   481  // 'DebugAddr' now refer to the client.
   482  type ClientSession struct {
   483  	SessionID string `json:"sessionID"`
   484  	Logfile   string `json:"logfile"`
   485  	DebugAddr string `json:"debugAddr"`
   486  }
   487  
   488  // ServerState holds information about the gopls daemon process, including its
   489  // debug information and that of all of its currently connected clients.
   490  // CurrentClientID is the ID of the session that answered the query.
   491  type ServerState struct {
   492  	Logfile         string          `json:"logfile"`
   493  	DebugAddr       string          `json:"debugAddr"`
   494  	GoplsPath       string          `json:"goplsPath"`
   495  	CurrentClientID string          `json:"currentClientID"`
   496  	Clients         []ClientSession `json:"clients"`
   497  }
   498  
   499  const ( // methods answered by handshaker itself instead of being passed to the wrapped handler
   500  	handshakeMethod = "gopls/handshake" // exchanges debug information between a client and the server
   501  	sessionsMethod  = "gopls/sessions" // reports the server's ServerState
   502  )
   503  
   504  func handshaker(session *cache.Session, goplsPath string, logHandshakes bool, handler jsonrpc2.Handler) jsonrpc2.Handler {
   505  	return func(ctx context.Context, reply jsonrpc2.Replier, r jsonrpc2.Request) error {
   506  		switch r.Method() {
   507  		case handshakeMethod:
   508  			// We log.Printf in this handler, rather than event.Log when we want logs
   509  			// to go to the daemon log rather than being reflected back to the
   510  			// client.
   511  			var req handshakeRequest
   512  			if err := json.Unmarshal(r.Params(), &req); err != nil {
   513  				if logHandshakes {
   514  					log.Printf("Error processing handshake for session %s: %v", session.ID(), err)
   515  				}
   516  				sendError(ctx, reply, err)
   517  				return nil
   518  			}
   519  			if logHandshakes {
   520  				log.Printf("Session %s: got handshake. Logfile: %q, Debug addr: %q", session.ID(), req.Logfile, req.DebugAddr)
   521  			}
   522  			event.Log(ctx, "Handshake session update",
   523  				cache.KeyUpdateSession.Of(session),
   524  				tag.DebugAddress.Of(req.DebugAddr),
   525  				tag.Logfile.Of(req.Logfile),
   526  				tag.ServerID.Of(req.ServerID),
   527  				tag.GoplsPath.Of(req.GoplsPath),
   528  			)
   529  			resp := handshakeResponse{
   530  				SessionID: session.ID(),
   531  				GoplsPath: goplsPath,
   532  			}
   533  			if di := debug.GetInstance(ctx); di != nil {
   534  				resp.Logfile = di.Logfile
   535  				resp.DebugAddr = di.ListenedDebugAddress
   536  			}
   537  
   538  			return reply(ctx, resp, nil)
   539  		case sessionsMethod:
   540  			resp := ServerState{
   541  				GoplsPath:       goplsPath,
   542  				CurrentClientID: session.ID(),
   543  			}
   544  			if di := debug.GetInstance(ctx); di != nil {
   545  				resp.Logfile = di.Logfile
   546  				resp.DebugAddr = di.ListenedDebugAddress
   547  				for _, c := range di.State.Clients() {
   548  					resp.Clients = append(resp.Clients, ClientSession{
   549  						SessionID: c.Session.ID(),
   550  						Logfile:   c.Logfile,
   551  						DebugAddr: c.DebugAddress,
   552  					})
   553  				}
   554  			}
   555  			return reply(ctx, resp, nil)
   556  		}
   557  		return handler(ctx, reply, r)
   558  	}
   559  }
   560  
   561  func sendError(ctx context.Context, reply jsonrpc2.Replier, err error) {
   562  	err = errors.Errorf("%v: %w", err, jsonrpc2.ErrParse)
   563  	if err := reply(ctx, nil, err); err != nil {
   564  		event.Error(ctx, "", err)
   565  	}
   566  }