github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/lsp/lsp.go (about)

     1  // Package lsp implements the Language Server Protocol for SpiceDB schema
     2  // development.
     3  package lsp
     4  
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net"

	"github.com/jzelinskie/persistent"
	"github.com/sourcegraph/go-lsp"
	"github.com/sourcegraph/jsonrpc2"
	"golang.org/x/sync/errgroup"

	log "github.com/authzed/spicedb/internal/logging"
)
    18  
    19  type serverState int
    20  
    21  const (
    22  	serverStateNotInitialized serverState = iota
    23  	serverStateInitialized
    24  	serverStateShuttingDown
    25  )
    26  
    27  // Server is a Language Server Protocol server for SpiceDB schema development.
    28  type Server struct {
    29  	files *persistent.Map[lsp.DocumentURI, trackedFile]
    30  	state serverState
    31  
    32  	requestsDiagnostics bool
    33  }
    34  
    35  // NewServer returns a new Server.
    36  func NewServer() *Server {
    37  	return &Server{
    38  		state: serverStateNotInitialized,
    39  		files: persistent.NewMap[lsp.DocumentURI, trackedFile](func(x, y lsp.DocumentURI) bool {
    40  			return string(x) < string(y)
    41  		}),
    42  	}
    43  }
    44  
    45  func (s *Server) Handle(ctx context.Context, conn *jsonrpc2.Conn, r *jsonrpc2.Request) {
    46  	jsonrpc2.HandlerWithError(s.handle).Handle(ctx, conn, r)
    47  }
    48  
    49  func logJSONPtr(msg *json.RawMessage) string {
    50  	if msg == nil {
    51  		return "nil"
    52  	}
    53  	return string(*msg)
    54  }
    55  
    56  func (s *Server) handle(ctx context.Context, conn *jsonrpc2.Conn, r *jsonrpc2.Request) (result any, err error) {
    57  	log.Ctx(ctx).Debug().
    58  		Stringer("id", r.ID).
    59  		Str("method", r.Method).
    60  		Str("params", logJSONPtr(r.Params)).
    61  		Msg("received LSP request")
    62  
    63  	if s.state == serverStateShuttingDown {
    64  		log.Ctx(ctx).Warn().
    65  			Str("method", r.Method).
    66  			Msg("ignoring request during shutdown")
    67  		return nil, nil
    68  	}
    69  
    70  	// Reference: https://microsoft.github.io/language-server-protocol/specifications/lsp/3.17/specification/#lifeCycleMessages
    71  	if r.Method != "initialize" && s.state != serverStateInitialized {
    72  		return nil, &jsonrpc2.Error{
    73  			Code:    codeUninitialized,
    74  			Message: "server not initialized",
    75  		}
    76  	}
    77  
    78  	switch r.Method {
    79  	case "initialize":
    80  		result, err = s.initialize(ctx, r)
    81  	case "initialized":
    82  		result, err = s.initialized(ctx, r)
    83  	case "shutdown":
    84  		result, err = nil, s.shutdown()
    85  	case "exit":
    86  		result, err = nil, conn.Close()
    87  	case "textDocument/didOpen":
    88  		result, err = s.textDocDidOpen(ctx, r, conn)
    89  	case "textDocument/didClose":
    90  		result, err = s.textDocDidClose(ctx, r)
    91  	case "textDocument/didChange":
    92  		result, err = s.textDocDidChange(ctx, r, conn)
    93  	case "textDocument/didSave":
    94  		result, err = s.textDocDidSave(ctx, r, conn)
    95  	case "textDocument/diagnostic":
    96  		result, err = s.textDocDiagnostic(ctx, r)
    97  	case "textDocument/formatting":
    98  		result, err = s.textDocFormat(ctx, r)
    99  	case "textDocument/hover":
   100  		result, err = s.textDocHover(ctx, r)
   101  	default:
   102  		log.Ctx(ctx).Warn().
   103  			Str("method", r.Method).
   104  			Msg("unsupported LSP method")
   105  		return nil, nil
   106  	}
   107  	log.Ctx(ctx).Info().
   108  		Stringer("id", r.ID).
   109  		Str("method", r.Method).
   110  		Str("params", logJSONPtr(r.Params)).
   111  		Interface("response", result).
   112  		Msg("responded to LSP request")
   113  	return result, err
   114  }
   115  
   116  func (s *Server) listenStdin(ctx context.Context) error {
   117  	log.Ctx(ctx).Info().
   118  		Msg("listening for LSP connections on stdin")
   119  
   120  	var connOpts []jsonrpc2.ConnOpt
   121  	stream := jsonrpc2.NewBufferedStream(stdrwc{}, jsonrpc2.VSCodeObjectCodec{})
   122  	conn := jsonrpc2.NewConn(ctx, stream, jsonrpc2.AsyncHandler(s), connOpts...)
   123  	defer conn.Close()
   124  
   125  	select {
   126  	case <-ctx.Done():
   127  	case <-conn.DisconnectNotify():
   128  	}
   129  	return nil
   130  }
   131  
   132  func (s *Server) listenTCP(ctx context.Context, addr string) error {
   133  	log.Ctx(ctx).Info().
   134  		Str("addr", addr).
   135  		Msg("listening for LSP connections")
   136  
   137  	l, err := net.Listen("tcp", addr)
   138  	if err != nil {
   139  		return err
   140  	}
   141  	defer l.Close()
   142  
   143  	var g errgroup.Group
   144  
   145  serving:
   146  	for {
   147  		select {
   148  		case <-ctx.Done():
   149  			break serving
   150  		default:
   151  			conn, err := l.Accept()
   152  			if err != nil {
   153  				continue
   154  			}
   155  
   156  			g.Go(func() error {
   157  				stream := jsonrpc2.NewBufferedStream(conn, jsonrpc2.VSCodeObjectCodec{})
   158  				jconn := jsonrpc2.NewConn(ctx, stream, s)
   159  				defer jconn.Close()
   160  				<-jconn.DisconnectNotify()
   161  				return nil
   162  			})
   163  		}
   164  	}
   165  
   166  	return g.Wait()
   167  }
   168  
   169  // Run binds to the provided address and concurrently serves Language Server
   170  // Protocol requests.
   171  func (s *Server) Run(ctx context.Context, addr string, stdio bool) error {
   172  	log.Ctx(ctx).Info().
   173  		Str("addr", addr).
   174  		Msg("starting LSP server")
   175  
   176  	if addr == "-" && !stdio {
   177  		return fmt.Errorf("cannot use stdin with stdio disabled")
   178  	}
   179  
   180  	if addr == "-" || stdio {
   181  		return s.listenStdin(ctx)
   182  	}
   183  	return s.listenTCP(ctx, addr)
   184  }