github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/lsp/handlers.go (about) 1 package lsp 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "strings" 8 9 "github.com/jzelinskie/persistent" 10 "github.com/sourcegraph/go-lsp" 11 "github.com/sourcegraph/jsonrpc2" 12 13 log "github.com/authzed/spicedb/internal/logging" 14 "github.com/authzed/spicedb/pkg/development" 15 developerv1 "github.com/authzed/spicedb/pkg/proto/developer/v1" 16 "github.com/authzed/spicedb/pkg/schemadsl/compiler" 17 "github.com/authzed/spicedb/pkg/schemadsl/generator" 18 "github.com/authzed/spicedb/pkg/schemadsl/input" 19 ) 20 21 func (s *Server) textDocDiagnostic(ctx context.Context, r *jsonrpc2.Request) (FullDocumentDiagnosticReport, error) { 22 params, err := unmarshalParams[TextDocumentDiagnosticParams](r) 23 if err != nil { 24 return FullDocumentDiagnosticReport{}, err 25 } 26 27 log.Info(). 28 Str("method", "textDocument/diagnostic"). 29 Str("uri", string(params.TextDocument.URI)). 30 Msg("textDocDiagnostic") 31 32 diagnostics, err := s.computeDiagnostics(ctx, params.TextDocument.URI) 33 if err != nil { 34 return FullDocumentDiagnosticReport{}, err 35 } 36 37 log.Info(). 38 Str("uri", string(params.TextDocument.URI)). 39 Int("diagnostics", len(diagnostics)). 40 Msg("diagnostics complete") 41 42 return FullDocumentDiagnosticReport{ 43 Kind: "full", 44 Items: diagnostics, 45 }, nil 46 } 47 48 func (s *Server) computeDiagnostics(ctx context.Context, uri lsp.DocumentURI) ([]lsp.Diagnostic, error) { 49 diagnostics := make([]lsp.Diagnostic, 0) // Important: must not be nil for the consumer on the client side 50 if err := s.withFiles(func(files *persistent.Map[lsp.DocumentURI, trackedFile]) error { 51 file, ok := files.Get(uri) 52 if !ok { 53 log.Warn(). 54 Str("uri", string(uri)). 55 Msg("file not found for diagnostics") 56 57 return &jsonrpc2.Error{Code: jsonrpc2.CodeInternalError, Message: "file not found"} 58 } 59 60 devCtx, devErrs, err := development.NewDevContext(ctx, &developerv1.RequestContext{ 61 Schema: file.contents, 62 Relationships: nil, 63 }) 64 if err != nil { 65 return err 66 } 67 68 // Get errors. 69 for _, devErr := range devErrs.GetInputErrors() { 70 diagnostics = append(diagnostics, lsp.Diagnostic{ 71 Severity: lsp.Error, 72 Range: lsp.Range{ 73 Start: lsp.Position{Line: int(devErr.Line) - 1, Character: int(devErr.Column) - 1}, 74 End: lsp.Position{Line: int(devErr.Line) - 1, Character: int(devErr.Column) - 1}, 75 }, 76 Message: devErr.Message, 77 }) 78 } 79 80 // If there are no errors, we can also check for warnings. 81 if len(diagnostics) == 0 { 82 warnings, err := development.GetWarnings(ctx, devCtx) 83 if err != nil { 84 return err 85 } 86 87 for _, devWarning := range warnings { 88 diagnostics = append(diagnostics, lsp.Diagnostic{ 89 Severity: lsp.Warning, 90 Range: lsp.Range{ 91 Start: lsp.Position{Line: int(devWarning.Line) - 1, Character: int(devWarning.Column) - 1}, 92 End: lsp.Position{Line: int(devWarning.Line) - 1, Character: int(devWarning.Column) - 1}, 93 }, 94 Message: devWarning.Message, 95 }) 96 } 97 } 98 99 return nil 100 }); err != nil { 101 return nil, err 102 } 103 104 log.Info().Int("diagnostics", len(diagnostics)).Str("uri", string(uri)).Msg("computed diagnostics") 105 return diagnostics, nil 106 } 107 108 func (s *Server) textDocDidSave(ctx context.Context, r *jsonrpc2.Request, conn *jsonrpc2.Conn) (any, error) { 109 params, err := unmarshalParams[lsp.DidSaveTextDocumentParams](r) 110 if err != nil { 111 return nil, err 112 } 113 114 if err := s.publishDiagnosticsIfNecessary(ctx, conn, params.TextDocument.URI); err != nil { 115 return nil, err 116 } 117 118 return nil, nil 119 } 120 121 func (s *Server) textDocDidChange(ctx context.Context, r *jsonrpc2.Request, conn *jsonrpc2.Conn) (any, error) { 122 params, err := unmarshalParams[lsp.DidChangeTextDocumentParams](r) 123 if err != nil { 124 return nil, err 125 } 126 127 s.files.Set(params.TextDocument.URI, trackedFile{params.ContentChanges[0].Text, nil}, nil) 128 129 if err := s.publishDiagnosticsIfNecessary(ctx, conn, params.TextDocument.URI); err != nil { 130 return nil, err 131 } 132 133 return nil, nil 134 } 135 136 func (s *Server) textDocDidClose(_ context.Context, r *jsonrpc2.Request) (any, error) { 137 params, err := unmarshalParams[lsp.DidCloseTextDocumentParams](r) 138 if err != nil { 139 return nil, err 140 } 141 142 s.files.Delete(params.TextDocument.URI) 143 return nil, nil 144 } 145 146 func (s *Server) textDocDidOpen(ctx context.Context, r *jsonrpc2.Request, conn *jsonrpc2.Conn) (any, error) { 147 params, err := unmarshalParams[lsp.DidOpenTextDocumentParams](r) 148 if err != nil { 149 return nil, err 150 } 151 152 uri := params.TextDocument.URI 153 contents := params.TextDocument.Text 154 s.files.Set(uri, trackedFile{contents, nil}, nil) 155 156 if err := s.publishDiagnosticsIfNecessary(ctx, conn, uri); err != nil { 157 return nil, err 158 } 159 160 log.Debug(). 161 Str("uri", string(uri)). 162 Str("path", strings.TrimPrefix(string(uri), "file://")). 163 Msg("refreshed file") 164 165 return nil, nil 166 } 167 168 func (s *Server) publishDiagnosticsIfNecessary(ctx context.Context, conn *jsonrpc2.Conn, uri lsp.DocumentURI) error { 169 requestsDiagnostics := s.requestsDiagnostics 170 if requestsDiagnostics { 171 return nil 172 } 173 174 log.Debug(). 175 Str("uri", string(uri)). 176 Msg("publishing diagnostics") 177 178 diagnostics, err := s.computeDiagnostics(ctx, uri) 179 if err != nil { 180 return fmt.Errorf("failed to compute diagnostics: %w", err) 181 } 182 183 return conn.Notify(ctx, "textDocument/publishDiagnostics", lsp.PublishDiagnosticsParams{ 184 URI: uri, 185 Diagnostics: diagnostics, 186 }) 187 } 188 189 func (s *Server) getCompiledContents(path lsp.DocumentURI, files *persistent.Map[lsp.DocumentURI, trackedFile]) (*compiler.CompiledSchema, error) { 190 file, ok := files.Get(path) 191 if !ok { 192 return nil, &jsonrpc2.Error{Code: jsonrpc2.CodeInternalError, Message: "file not found"} 193 } 194 195 compiled := file.parsed 196 if compiled != nil { 197 return compiled, nil 198 } 199 200 justCompiled, derr, err := development.CompileSchema(file.contents) 201 if err != nil || derr != nil { 202 return nil, err 203 } 204 205 files.Set(path, trackedFile{file.contents, justCompiled}, nil) 206 return justCompiled, nil 207 } 208 209 func (s *Server) textDocHover(_ context.Context, r *jsonrpc2.Request) (*Hover, error) { 210 params, err := unmarshalParams[lsp.TextDocumentPositionParams](r) 211 if err != nil { 212 return nil, err 213 } 214 215 var hoverContents *Hover 216 err = s.withFiles(func(files *persistent.Map[lsp.DocumentURI, trackedFile]) error { 217 compiled, err := s.getCompiledContents(params.TextDocument.URI, files) 218 if err != nil { 219 return err 220 } 221 222 resolver, err := development.NewResolver(compiled) 223 if err != nil { 224 return err 225 } 226 227 position := input.Position{ 228 LineNumber: params.Position.Line, 229 ColumnPosition: params.Position.Character, 230 } 231 232 resolved, err := resolver.ReferenceAtPosition(input.Source("schema"), position) 233 if err != nil { 234 return err 235 } 236 237 if resolved == nil { 238 return nil 239 } 240 241 var lspRange *lsp.Range 242 if resolved.TargetPosition != nil { 243 lspRange = &lsp.Range{ 244 Start: lsp.Position{ 245 Line: resolved.TargetPosition.LineNumber, 246 Character: resolved.TargetPosition.ColumnPosition + resolved.TargetNamePositionOffset, 247 }, 248 End: lsp.Position{ 249 Line: resolved.TargetPosition.LineNumber, 250 Character: resolved.TargetPosition.ColumnPosition + resolved.TargetNamePositionOffset + len(resolved.Text), 251 }, 252 } 253 } 254 255 if resolved.TargetSourceCode != "" { 256 hoverContents = &Hover{ 257 Contents: MarkupContent{ 258 Language: "spicedb", 259 Value: resolved.TargetSourceCode, 260 }, 261 Range: lspRange, 262 } 263 } else { 264 hoverContents = &Hover{ 265 Contents: MarkupContent{ 266 Kind: "markdown", 267 Value: resolved.ReferenceMarkdown, 268 }, 269 Range: lspRange, 270 } 271 } 272 273 return nil 274 }) 275 if err != nil { 276 return nil, err 277 } 278 279 return hoverContents, nil 280 } 281 282 func (s *Server) textDocFormat(_ context.Context, r *jsonrpc2.Request) ([]lsp.TextEdit, error) { 283 params, err := unmarshalParams[lsp.DocumentFormattingParams](r) 284 if err != nil { 285 return nil, err 286 } 287 288 var formatted string 289 err = s.withFiles(func(files *persistent.Map[lsp.DocumentURI, trackedFile]) error { 290 compiled, err := s.getCompiledContents(params.TextDocument.URI, files) 291 if err != nil { 292 return err 293 } 294 295 formattedSchema, _, err := generator.GenerateSchema(compiled.OrderedDefinitions) 296 if err != nil { 297 return err 298 } 299 300 formatted = formattedSchema 301 return nil 302 }) 303 if err != nil { 304 return nil, err 305 } 306 307 if formatted == "" { 308 return nil, nil 309 } 310 311 return []lsp.TextEdit{ 312 { 313 Range: lsp.Range{ 314 Start: lsp.Position{Line: 0, Character: 0}, 315 End: lsp.Position{Line: 10000000, Character: 100000000}, // Replace the schema entirely 316 }, 317 NewText: formatted, 318 }, 319 }, nil 320 } 321 322 func (s *Server) initialized(_ context.Context, _ *jsonrpc2.Request) (any, error) { 323 if s.state != serverStateInitialized { 324 return nil, invalidRequest(errors.New("server not initialized")) 325 } 326 return nil, nil 327 } 328 329 func (s *Server) initialize(_ context.Context, r *jsonrpc2.Request) (any, error) { 330 ip, err := unmarshalParams[InitializeParams](r) 331 if err != nil { 332 return nil, err 333 } 334 335 s.requestsDiagnostics = ip.Capabilities.Diagnostics.RefreshSupport 336 log.Debug(). 337 Bool("requestsDiagnostics", s.requestsDiagnostics). 338 Msg("initialize") 339 340 if s.state != serverStateNotInitialized { 341 return nil, invalidRequest(errors.New("already initialized")) 342 } 343 344 syncKind := lsp.TDSKFull 345 s.state = serverStateInitialized 346 return InitializeResult{ 347 Capabilities: ServerCapabilities{ 348 TextDocumentSync: &lsp.TextDocumentSyncOptionsOrKind{Kind: &syncKind}, 349 CompletionProvider: &lsp.CompletionOptions{TriggerCharacters: []string{"."}}, 350 DocumentFormattingProvider: true, 351 DiagnosticProvider: &DiagnosticOptions{Identifier: "spicedb", InterFileDependencies: false, WorkspaceDiagnostics: false}, 352 HoverProvider: true, 353 }, 354 }, nil 355 } 356 357 func (s *Server) shutdown() error { 358 s.state = serverStateShuttingDown 359 log.Debug(). 360 Msg("shutting down LSP server") 361 return nil 362 } 363 364 type trackedFile struct { 365 contents string 366 parsed *compiler.CompiledSchema 367 } 368 369 func (s *Server) withFiles(fn func(*persistent.Map[lsp.DocumentURI, trackedFile]) error) error { 370 clone := s.files.Clone() 371 defer clone.Destroy() 372 return fn(clone) 373 }