github.com/tilt-dev/tilt@v0.33.15-0.20240515162809-0a22ed45d8a0/internal/hud/server/websocket_reader_test.go (about)

     1  package server
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/golang/protobuf/jsonpb"
    11  	"github.com/stretchr/testify/assert"
    12  
    13  	"github.com/tilt-dev/tilt/internal/testutils/bufsync"
    14  
    15  	"github.com/tilt-dev/tilt/internal/testutils"
    16  	proto_webview "github.com/tilt-dev/tilt/pkg/webview"
    17  )
    18  
    19  func TestViewsHandled(t *testing.T) {
    20  	f := newWebsocketReaderFixture(t)
    21  	f.start()
    22  
    23  	v := &proto_webview.View{Log: "hello world"}
    24  	f.sendView(v)
    25  	f.assertHandlerCallCount(1)
    26  	assert.Equal(t, "hello world", f.handler.lastViewLog)
    27  
    28  	v = &proto_webview.View{Log: "goodbye world"}
    29  	f.sendView(v)
    30  	f.assertHandlerCallCount(2)
    31  	assert.Equal(t, "goodbye world", f.handler.lastViewLog)
    32  }
    33  
    34  func TestHandlerErrorDoesntStopLoop(t *testing.T) {
    35  	f := newWebsocketReaderFixture(t)
    36  	f.start()
    37  	f.handler.nextErr = fmt.Errorf("aw nerts")
    38  
    39  	v := &proto_webview.View{Log: "hello world"}
    40  	f.sendView(v)
    41  	f.assertHandlerCallCount(1)
    42  	f.assertLogs("aw nerts")
    43  
    44  	// should still be running!
    45  	v = &proto_webview.View{Log: "goodbye world"}
    46  	f.sendView(v)
    47  	f.assertHandlerCallCount(2)
    48  	assert.Equal(t, "goodbye world", f.handler.lastViewLog)
    49  }
    50  
    51  func TestNonPersistentReaderExistsAfterHandling(t *testing.T) {
    52  	f := newWebsocketReaderFixture(t).withPersistent(false)
    53  	f.start()
    54  
    55  	v := &proto_webview.View{Log: "hello world"}
    56  	f.sendView(v)
    57  	f.assertHandlerCallCount(1)
    58  	assert.Equal(t, "hello world", f.handler.lastViewLog)
    59  	f.assertDone()
    60  }
    61  
    62  func TestWebsocketCloseOnNextReaderError(t *testing.T) {
    63  	f := newWebsocketReaderFixture(t)
    64  	f.start()
    65  
    66  	f.conn.readCh <- readerOrErr{err: fmt.Errorf("read error")}
    67  
    68  	time.Sleep(10 * time.Millisecond)
    69  	f.assertDone()
    70  }
    71  
    72  type websocketReaderFixture struct {
    73  	t       *testing.T
    74  	ctx     context.Context
    75  	cancel  context.CancelFunc
    76  	out     *bufsync.ThreadSafeBuffer
    77  	conn    *fakeConn
    78  	handler *fakeViewHandler
    79  	wsr     *WebsocketReader
    80  	done    chan error
    81  }
    82  
    83  func newWebsocketReaderFixture(t *testing.T) *websocketReaderFixture {
    84  	out := bufsync.NewThreadSafeBuffer()
    85  	baseCtx, _, _ := testutils.ForkedCtxAndAnalyticsForTest(out)
    86  	ctx, cancel := context.WithCancel(baseCtx)
    87  	conn := newFakeConn()
    88  	handler := &fakeViewHandler{}
    89  
    90  	ret := &websocketReaderFixture{
    91  		t:       t,
    92  		ctx:     ctx,
    93  		cancel:  cancel,
    94  		out:     out,
    95  		conn:    conn,
    96  		handler: handler,
    97  		wsr:     newWebsocketReader(conn, true, handler),
    98  		done:    make(chan error),
    99  	}
   100  
   101  	t.Cleanup(ret.tearDown)
   102  	return ret
   103  }
   104  
   105  func (f *websocketReaderFixture) withPersistent(persistent bool) *websocketReaderFixture {
   106  	f.wsr.persistent = persistent
   107  	return f
   108  }
   109  
   110  func (f *websocketReaderFixture) start() {
   111  	go func() {
   112  		err := f.wsr.Listen(f.ctx)
   113  		f.done <- err
   114  		close(f.done)
   115  	}()
   116  }
   117  
   118  func (f *websocketReaderFixture) sendView(v *proto_webview.View) {
   119  	buf := &bytes.Buffer{}
   120  	err := (&jsonpb.Marshaler{}).Marshal(buf, v)
   121  	assert.NoError(f.t, err)
   122  
   123  	f.conn.newMessageToRead(buf)
   124  }
   125  
   126  func (f *websocketReaderFixture) assertHandlerCallCount(n int) {
   127  	ctx, cancel := context.WithTimeout(f.ctx, time.Millisecond*10)
   128  	defer cancel()
   129  	isCanceled := false
   130  
   131  	for {
   132  		if f.handler.callCount == n {
   133  			return
   134  		}
   135  		if isCanceled {
   136  			f.t.Fatalf("Timed out waiting for handler.callCount = %d (got: %d)",
   137  				n, f.handler.callCount)
   138  		}
   139  
   140  		select {
   141  		case <-ctx.Done():
   142  			// Let the loop run the check one more time
   143  			isCanceled = true
   144  		case <-time.After(time.Millisecond):
   145  		}
   146  	}
   147  }
   148  
   149  func (f *websocketReaderFixture) assertLogs(msg string) {
   150  	f.out.AssertEventuallyContains(f.t, msg, time.Second)
   151  }
   152  
   153  func (f *websocketReaderFixture) tearDown() {
   154  	f.cancel()
   155  	f.assertDone()
   156  }
   157  
   158  func (f *websocketReaderFixture) assertDone() {
   159  	select {
   160  	case <-time.After(100 * time.Millisecond):
   161  		f.t.Fatal("timed out waiting for close")
   162  	case err := <-f.done:
   163  		assert.NoError(f.t, err)
   164  	}
   165  }
   166  
   167  type fakeViewHandler struct {
   168  	callCount   int
   169  	lastViewLog string // use the Log field to differentiate the views we send, cuz why not
   170  	nextErr     error
   171  }
   172  
   173  func (fvh *fakeViewHandler) Handle(v *proto_webview.View) error {
   174  	fvh.callCount += 1
   175  	if fvh.nextErr != nil {
   176  		err := fvh.nextErr
   177  		fvh.nextErr = nil
   178  		return err
   179  	}
   180  	fvh.lastViewLog = v.Log
   181  	return nil
   182  }