github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/internal/lsp/testutil.go (about)

     1  package lsp
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io"
     8  	"strings"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/sourcegraph/go-lsp"
    13  	"github.com/sourcegraph/jsonrpc2"
    14  	"github.com/stretchr/testify/require"
    15  )
    16  
    17  type testStream struct {
    18  	toRead   chan string
    19  	received chan string
    20  }
    21  
    22  var _ io.ReadWriteCloser = (*testStream)(nil)
    23  
    24  func (ts *testStream) Read(p []byte) (int, error) {
    25  	read := <-ts.toRead
    26  	if len(read) == 0 {
    27  		return 0, io.EOF
    28  	}
    29  
    30  	copy(p, read)
    31  	return len(read), nil
    32  }
    33  
    34  func (ts *testStream) Write(p []byte) (int, error) {
    35  	ts.received <- string(p)
    36  	return len(p), nil
    37  }
    38  
    39  func (ts *testStream) Close() error {
    40  	return nil
    41  }
    42  
// lspTester drives a Server under test through an in-memory testStream.
// Fields are set positionally by newLSPTester, so their order matters.
type lspTester struct {
	t      *testing.T  // test that owns this tester; helper failures are reported on it
	ts     *testStream // in-memory transport the JSON-RPC connection reads from and writes to
	server *Server     // server under test; its state is read directly by the helpers
}
    48  
    49  func (lt *lspTester) initialize() {
    50  	resp, serverState := sendAndReceive[lsp.InitializeResult](lt, "initialize", InitializeParams{
    51  		Capabilities: ClientCapabilities{
    52  			Diagnostics: DiagnosticWorkspaceClientCapabilities{
    53  				RefreshSupport: true,
    54  			},
    55  		},
    56  	})
    57  	require.Equal(lt.t, serverStateInitialized, serverState)
    58  	require.True(lt.t, resp.Capabilities.DocumentFormattingProvider)
    59  }
    60  
    61  func (lt *lspTester) setFileContents(path string, contents string) {
    62  	sendAndReceive[any](lt, "textDocument/didChange", lsp.DidChangeTextDocumentParams{
    63  		TextDocument: lsp.VersionedTextDocumentIdentifier{
    64  			TextDocumentIdentifier: lsp.TextDocumentIdentifier{URI: lsp.DocumentURI(path)},
    65  			Version:                1,
    66  		},
    67  		ContentChanges: []lsp.TextDocumentContentChangeEvent{
    68  			{
    69  				Text: contents,
    70  			},
    71  		},
    72  	})
    73  }
    74  
    75  func sendAndExpectError(lt *lspTester, method string, params interface{}) (*jsonrpc2.Error, serverState) {
    76  	paramsBytes, err := json.Marshal(params)
    77  	require.NoError(lt.t, err)
    78  
    79  	paramsMsg := json.RawMessage(paramsBytes)
    80  
    81  	r := &jsonrpc2.Request{
    82  		Method: method,
    83  		ID:     jsonrpc2.ID{Num: 1},
    84  		Params: &paramsMsg,
    85  	}
    86  	message, err := r.MarshalJSON()
    87  	require.NoError(lt.t, err)
    88  
    89  	lt.ts.toRead <- fmt.Sprintf("Content-Length: %d\r\n", len(message))
    90  	lt.ts.toRead <- "\r\n"
    91  	lt.ts.toRead <- string(message)
    92  
    93  	select {
    94  	case received := <-lt.ts.received:
    95  		lines := strings.Split(received, "\r\n")
    96  		require.Greater(lt.t, len(lines), 2)
    97  
    98  		var resp jsonrpc2.Response
    99  		err := json.Unmarshal([]byte(lines[2]), &resp)
   100  		require.NoError(lt.t, err)
   101  		require.NotNil(lt.t, resp.Error)
   102  
   103  		return resp.Error, lt.server.state
   104  
   105  	case <-time.After(1 * time.Second):
   106  		lt.t.Fatal("timed out waiting for response")
   107  	}
   108  
   109  	return nil, serverStateNotInitialized
   110  }
   111  
   112  func sendAndReceive[T any](lt *lspTester, method string, params interface{}) (T, serverState) {
   113  	paramsBytes, err := json.Marshal(params)
   114  	require.NoError(lt.t, err)
   115  
   116  	paramsMsg := json.RawMessage(paramsBytes)
   117  
   118  	r := &jsonrpc2.Request{
   119  		Method: method,
   120  		ID:     jsonrpc2.ID{Num: 1},
   121  		Params: &paramsMsg,
   122  	}
   123  	message, err := r.MarshalJSON()
   124  	require.NoError(lt.t, err)
   125  
   126  	lt.ts.toRead <- fmt.Sprintf("Content-Length: %d\r\n", len(message))
   127  	lt.ts.toRead <- "\r\n"
   128  	lt.ts.toRead <- string(message)
   129  
   130  	select {
   131  	case received := <-lt.ts.received:
   132  		lines := strings.Split(received, "\r\n")
   133  		require.Greater(lt.t, len(lines), 2)
   134  
   135  		var resp jsonrpc2.Response
   136  		err := json.Unmarshal([]byte(lines[2]), &resp)
   137  		require.NoError(lt.t, err)
   138  		require.Nil(lt.t, resp.Error)
   139  
   140  		var result T
   141  		err = json.Unmarshal(*resp.Result, &result)
   142  		require.NoError(lt.t, err)
   143  
   144  		return result, lt.server.state
   145  
   146  	case <-time.After(1 * time.Second):
   147  		lt.t.Fatal("timed out waiting for response")
   148  	}
   149  
   150  	var empty T
   151  	return empty, serverStateNotInitialized
   152  }
   153  
   154  func newLSPTester(t *testing.T) *lspTester {
   155  	ctx := context.Background()
   156  	ts := newTestStream()
   157  
   158  	var connOpts []jsonrpc2.ConnOpt
   159  	stream := jsonrpc2.NewBufferedStream(ts, jsonrpc2.VSCodeObjectCodec{})
   160  
   161  	server := NewServer()
   162  	conn := jsonrpc2.NewConn(ctx, stream, server, connOpts...)
   163  	t.Cleanup(func() {
   164  		_ = conn.Close()
   165  	})
   166  
   167  	return &lspTester{t, ts, server}
   168  }
   169  
   170  func newTestStream() *testStream {
   171  	return &testStream{make(chan string), make(chan string)}
   172  }