github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/lsp/handlers.go (about)

     1  package lsp
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"strings"
     8  
     9  	"github.com/jzelinskie/persistent"
    10  	"github.com/sourcegraph/go-lsp"
    11  	"github.com/sourcegraph/jsonrpc2"
    12  
    13  	log "github.com/authzed/spicedb/internal/logging"
    14  	"github.com/authzed/spicedb/pkg/development"
    15  	developerv1 "github.com/authzed/spicedb/pkg/proto/developer/v1"
    16  	"github.com/authzed/spicedb/pkg/schemadsl/compiler"
    17  	"github.com/authzed/spicedb/pkg/schemadsl/generator"
    18  	"github.com/authzed/spicedb/pkg/schemadsl/input"
    19  )
    20  
    21  func (s *Server) textDocDiagnostic(ctx context.Context, r *jsonrpc2.Request) (FullDocumentDiagnosticReport, error) {
    22  	params, err := unmarshalParams[TextDocumentDiagnosticParams](r)
    23  	if err != nil {
    24  		return FullDocumentDiagnosticReport{}, err
    25  	}
    26  
    27  	log.Info().
    28  		Str("method", "textDocument/diagnostic").
    29  		Str("uri", string(params.TextDocument.URI)).
    30  		Msg("textDocDiagnostic")
    31  
    32  	diagnostics, err := s.computeDiagnostics(ctx, params.TextDocument.URI)
    33  	if err != nil {
    34  		return FullDocumentDiagnosticReport{}, err
    35  	}
    36  
    37  	log.Info().
    38  		Str("uri", string(params.TextDocument.URI)).
    39  		Int("diagnostics", len(diagnostics)).
    40  		Msg("diagnostics complete")
    41  
    42  	return FullDocumentDiagnosticReport{
    43  		Kind:  "full",
    44  		Items: diagnostics,
    45  	}, nil
    46  }
    47  
    48  func (s *Server) computeDiagnostics(ctx context.Context, uri lsp.DocumentURI) ([]lsp.Diagnostic, error) {
    49  	diagnostics := make([]lsp.Diagnostic, 0) // Important: must not be nil for the consumer on the client side
    50  	if err := s.withFiles(func(files *persistent.Map[lsp.DocumentURI, trackedFile]) error {
    51  		file, ok := files.Get(uri)
    52  		if !ok {
    53  			log.Warn().
    54  				Str("uri", string(uri)).
    55  				Msg("file not found for diagnostics")
    56  
    57  			return &jsonrpc2.Error{Code: jsonrpc2.CodeInternalError, Message: "file not found"}
    58  		}
    59  
    60  		devCtx, devErrs, err := development.NewDevContext(ctx, &developerv1.RequestContext{
    61  			Schema:        file.contents,
    62  			Relationships: nil,
    63  		})
    64  		if err != nil {
    65  			return err
    66  		}
    67  
    68  		// Get errors.
    69  		for _, devErr := range devErrs.GetInputErrors() {
    70  			diagnostics = append(diagnostics, lsp.Diagnostic{
    71  				Severity: lsp.Error,
    72  				Range: lsp.Range{
    73  					Start: lsp.Position{Line: int(devErr.Line) - 1, Character: int(devErr.Column) - 1},
    74  					End:   lsp.Position{Line: int(devErr.Line) - 1, Character: int(devErr.Column) - 1},
    75  				},
    76  				Message: devErr.Message,
    77  			})
    78  		}
    79  
    80  		// If there are no errors, we can also check for warnings.
    81  		if len(diagnostics) == 0 {
    82  			warnings, err := development.GetWarnings(ctx, devCtx)
    83  			if err != nil {
    84  				return err
    85  			}
    86  
    87  			for _, devWarning := range warnings {
    88  				diagnostics = append(diagnostics, lsp.Diagnostic{
    89  					Severity: lsp.Warning,
    90  					Range: lsp.Range{
    91  						Start: lsp.Position{Line: int(devWarning.Line) - 1, Character: int(devWarning.Column) - 1},
    92  						End:   lsp.Position{Line: int(devWarning.Line) - 1, Character: int(devWarning.Column) - 1},
    93  					},
    94  					Message: devWarning.Message,
    95  				})
    96  			}
    97  		}
    98  
    99  		return nil
   100  	}); err != nil {
   101  		return nil, err
   102  	}
   103  
   104  	log.Info().Int("diagnostics", len(diagnostics)).Str("uri", string(uri)).Msg("computed diagnostics")
   105  	return diagnostics, nil
   106  }
   107  
   108  func (s *Server) textDocDidSave(ctx context.Context, r *jsonrpc2.Request, conn *jsonrpc2.Conn) (any, error) {
   109  	params, err := unmarshalParams[lsp.DidSaveTextDocumentParams](r)
   110  	if err != nil {
   111  		return nil, err
   112  	}
   113  
   114  	if err := s.publishDiagnosticsIfNecessary(ctx, conn, params.TextDocument.URI); err != nil {
   115  		return nil, err
   116  	}
   117  
   118  	return nil, nil
   119  }
   120  
   121  func (s *Server) textDocDidChange(ctx context.Context, r *jsonrpc2.Request, conn *jsonrpc2.Conn) (any, error) {
   122  	params, err := unmarshalParams[lsp.DidChangeTextDocumentParams](r)
   123  	if err != nil {
   124  		return nil, err
   125  	}
   126  
   127  	s.files.Set(params.TextDocument.URI, trackedFile{params.ContentChanges[0].Text, nil}, nil)
   128  
   129  	if err := s.publishDiagnosticsIfNecessary(ctx, conn, params.TextDocument.URI); err != nil {
   130  		return nil, err
   131  	}
   132  
   133  	return nil, nil
   134  }
   135  
   136  func (s *Server) textDocDidClose(_ context.Context, r *jsonrpc2.Request) (any, error) {
   137  	params, err := unmarshalParams[lsp.DidCloseTextDocumentParams](r)
   138  	if err != nil {
   139  		return nil, err
   140  	}
   141  
   142  	s.files.Delete(params.TextDocument.URI)
   143  	return nil, nil
   144  }
   145  
   146  func (s *Server) textDocDidOpen(ctx context.Context, r *jsonrpc2.Request, conn *jsonrpc2.Conn) (any, error) {
   147  	params, err := unmarshalParams[lsp.DidOpenTextDocumentParams](r)
   148  	if err != nil {
   149  		return nil, err
   150  	}
   151  
   152  	uri := params.TextDocument.URI
   153  	contents := params.TextDocument.Text
   154  	s.files.Set(uri, trackedFile{contents, nil}, nil)
   155  
   156  	if err := s.publishDiagnosticsIfNecessary(ctx, conn, uri); err != nil {
   157  		return nil, err
   158  	}
   159  
   160  	log.Debug().
   161  		Str("uri", string(uri)).
   162  		Str("path", strings.TrimPrefix(string(uri), "file://")).
   163  		Msg("refreshed file")
   164  
   165  	return nil, nil
   166  }
   167  
   168  func (s *Server) publishDiagnosticsIfNecessary(ctx context.Context, conn *jsonrpc2.Conn, uri lsp.DocumentURI) error {
   169  	requestsDiagnostics := s.requestsDiagnostics
   170  	if requestsDiagnostics {
   171  		return nil
   172  	}
   173  
   174  	log.Debug().
   175  		Str("uri", string(uri)).
   176  		Msg("publishing diagnostics")
   177  
   178  	diagnostics, err := s.computeDiagnostics(ctx, uri)
   179  	if err != nil {
   180  		return fmt.Errorf("failed to compute diagnostics: %w", err)
   181  	}
   182  
   183  	return conn.Notify(ctx, "textDocument/publishDiagnostics", lsp.PublishDiagnosticsParams{
   184  		URI:         uri,
   185  		Diagnostics: diagnostics,
   186  	})
   187  }
   188  
   189  func (s *Server) getCompiledContents(path lsp.DocumentURI, files *persistent.Map[lsp.DocumentURI, trackedFile]) (*compiler.CompiledSchema, error) {
   190  	file, ok := files.Get(path)
   191  	if !ok {
   192  		return nil, &jsonrpc2.Error{Code: jsonrpc2.CodeInternalError, Message: "file not found"}
   193  	}
   194  
   195  	compiled := file.parsed
   196  	if compiled != nil {
   197  		return compiled, nil
   198  	}
   199  
   200  	justCompiled, derr, err := development.CompileSchema(file.contents)
   201  	if err != nil || derr != nil {
   202  		return nil, err
   203  	}
   204  
   205  	files.Set(path, trackedFile{file.contents, justCompiled}, nil)
   206  	return justCompiled, nil
   207  }
   208  
   209  func (s *Server) textDocHover(_ context.Context, r *jsonrpc2.Request) (*Hover, error) {
   210  	params, err := unmarshalParams[lsp.TextDocumentPositionParams](r)
   211  	if err != nil {
   212  		return nil, err
   213  	}
   214  
   215  	var hoverContents *Hover
   216  	err = s.withFiles(func(files *persistent.Map[lsp.DocumentURI, trackedFile]) error {
   217  		compiled, err := s.getCompiledContents(params.TextDocument.URI, files)
   218  		if err != nil {
   219  			return err
   220  		}
   221  
   222  		resolver, err := development.NewResolver(compiled)
   223  		if err != nil {
   224  			return err
   225  		}
   226  
   227  		position := input.Position{
   228  			LineNumber:     params.Position.Line,
   229  			ColumnPosition: params.Position.Character,
   230  		}
   231  
   232  		resolved, err := resolver.ReferenceAtPosition(input.Source("schema"), position)
   233  		if err != nil {
   234  			return err
   235  		}
   236  
   237  		if resolved == nil {
   238  			return nil
   239  		}
   240  
   241  		var lspRange *lsp.Range
   242  		if resolved.TargetPosition != nil {
   243  			lspRange = &lsp.Range{
   244  				Start: lsp.Position{
   245  					Line:      resolved.TargetPosition.LineNumber,
   246  					Character: resolved.TargetPosition.ColumnPosition + resolved.TargetNamePositionOffset,
   247  				},
   248  				End: lsp.Position{
   249  					Line:      resolved.TargetPosition.LineNumber,
   250  					Character: resolved.TargetPosition.ColumnPosition + resolved.TargetNamePositionOffset + len(resolved.Text),
   251  				},
   252  			}
   253  		}
   254  
   255  		if resolved.TargetSourceCode != "" {
   256  			hoverContents = &Hover{
   257  				Contents: MarkupContent{
   258  					Language: "spicedb",
   259  					Value:    resolved.TargetSourceCode,
   260  				},
   261  				Range: lspRange,
   262  			}
   263  		} else {
   264  			hoverContents = &Hover{
   265  				Contents: MarkupContent{
   266  					Kind:  "markdown",
   267  					Value: resolved.ReferenceMarkdown,
   268  				},
   269  				Range: lspRange,
   270  			}
   271  		}
   272  
   273  		return nil
   274  	})
   275  	if err != nil {
   276  		return nil, err
   277  	}
   278  
   279  	return hoverContents, nil
   280  }
   281  
   282  func (s *Server) textDocFormat(_ context.Context, r *jsonrpc2.Request) ([]lsp.TextEdit, error) {
   283  	params, err := unmarshalParams[lsp.DocumentFormattingParams](r)
   284  	if err != nil {
   285  		return nil, err
   286  	}
   287  
   288  	var formatted string
   289  	err = s.withFiles(func(files *persistent.Map[lsp.DocumentURI, trackedFile]) error {
   290  		compiled, err := s.getCompiledContents(params.TextDocument.URI, files)
   291  		if err != nil {
   292  			return err
   293  		}
   294  
   295  		formattedSchema, _, err := generator.GenerateSchema(compiled.OrderedDefinitions)
   296  		if err != nil {
   297  			return err
   298  		}
   299  
   300  		formatted = formattedSchema
   301  		return nil
   302  	})
   303  	if err != nil {
   304  		return nil, err
   305  	}
   306  
   307  	if formatted == "" {
   308  		return nil, nil
   309  	}
   310  
   311  	return []lsp.TextEdit{
   312  		{
   313  			Range: lsp.Range{
   314  				Start: lsp.Position{Line: 0, Character: 0},
   315  				End:   lsp.Position{Line: 10000000, Character: 100000000}, // Replace the schema entirely
   316  			},
   317  			NewText: formatted,
   318  		},
   319  	}, nil
   320  }
   321  
   322  func (s *Server) initialized(_ context.Context, _ *jsonrpc2.Request) (any, error) {
   323  	if s.state != serverStateInitialized {
   324  		return nil, invalidRequest(errors.New("server not initialized"))
   325  	}
   326  	return nil, nil
   327  }
   328  
   329  func (s *Server) initialize(_ context.Context, r *jsonrpc2.Request) (any, error) {
   330  	ip, err := unmarshalParams[InitializeParams](r)
   331  	if err != nil {
   332  		return nil, err
   333  	}
   334  
   335  	s.requestsDiagnostics = ip.Capabilities.Diagnostics.RefreshSupport
   336  	log.Debug().
   337  		Bool("requestsDiagnostics", s.requestsDiagnostics).
   338  		Msg("initialize")
   339  
   340  	if s.state != serverStateNotInitialized {
   341  		return nil, invalidRequest(errors.New("already initialized"))
   342  	}
   343  
   344  	syncKind := lsp.TDSKFull
   345  	s.state = serverStateInitialized
   346  	return InitializeResult{
   347  		Capabilities: ServerCapabilities{
   348  			TextDocumentSync:           &lsp.TextDocumentSyncOptionsOrKind{Kind: &syncKind},
   349  			CompletionProvider:         &lsp.CompletionOptions{TriggerCharacters: []string{"."}},
   350  			DocumentFormattingProvider: true,
   351  			DiagnosticProvider:         &DiagnosticOptions{Identifier: "spicedb", InterFileDependencies: false, WorkspaceDiagnostics: false},
   352  			HoverProvider:              true,
   353  		},
   354  	}, nil
   355  }
   356  
   357  func (s *Server) shutdown() error {
   358  	s.state = serverStateShuttingDown
   359  	log.Debug().
   360  		Msg("shutting down LSP server")
   361  	return nil
   362  }
   363  
// trackedFile is the server's record of a single document, stored in the
// server's file map under the document's URI.
//
// Values are also built with positional literals, so the field order is
// significant.
type trackedFile struct {
	// contents is the document text most recently stored for the file.
	contents string
	// parsed is the compiled schema for the file, or nil when it has not been
	// compiled; getCompiledContents compiles on demand when it is nil.
	parsed   *compiler.CompiledSchema
}
   368  
   369  func (s *Server) withFiles(fn func(*persistent.Map[lsp.DocumentURI, trackedFile]) error) error {
   370  	clone := s.files.Clone()
   371  	defer clone.Destroy()
   372  	return fn(clone)
   373  }