github.com/keybase/client/go@v0.0.0-20241007131713-f10651d043c8/libkb/kex2_router_test.go (about)

     1  // Copyright 2015 Keybase, Inc. All rights reserved. Use of
     2  // this source code is governed by the included BSD license.
     3  
     4  package libkb
     5  
     6  import (
     7  	"sync"
     8  	"testing"
     9  	"time"
    10  
    11  	"crypto/rand"
    12  
    13  	"github.com/keybase/client/go/kex2"
    14  )
    15  
    16  type ktester struct {
    17  	sender   kex2.DeviceID
    18  	receiver kex2.DeviceID
    19  	I        kex2.SessionID
    20  	seqno    kex2.Seqno
    21  }
    22  
    23  func newKtester() *ktester {
    24  	kt := &ktester{}
    25  	if _, err := rand.Read(kt.sender[:]); err != nil {
    26  		panic(err)
    27  	}
    28  	if _, err := rand.Read(kt.receiver[:]); err != nil {
    29  		panic(err)
    30  	}
    31  	if _, err := rand.Read(kt.I[:]); err != nil {
    32  		panic(err)
    33  	}
    34  
    35  	return kt
    36  }
    37  
    38  func (k *ktester) post(mr kex2.MessageRouter, b []byte) error {
    39  	k.seqno++
    40  	return mr.Post(k.I, k.sender, k.seqno, b)
    41  }
    42  
    43  func (k *ktester) get(mr kex2.MessageRouter, low kex2.Seqno, poll time.Duration) ([][]byte, error) {
    44  	return mr.Get(k.I, k.receiver, low, poll)
    45  }
    46  
    47  func TestKex2Router(t *testing.T) {
    48  	tc := SetupTest(t, "kex2 router", 1)
    49  	defer tc.Cleanup()
    50  
    51  	mr := NewKexRouter(NewMetaContextTODO(tc.G))
    52  	kt := newKtester()
    53  
    54  	m1 := "hello everybody"
    55  	m2 := "goodbye everybody"
    56  	m3 := "plaid shirt"
    57  
    58  	// test send 2 messages
    59  	if err := kt.post(mr, []byte(m1)); err != nil {
    60  		t.Fatal(err)
    61  	}
    62  
    63  	if err := kt.post(mr, []byte(m2)); err != nil {
    64  		t.Fatal(err)
    65  	}
    66  
    67  	// test receive 2 messages
    68  	msgs, err := kt.get(mr, 0, 100*time.Millisecond)
    69  	if err != nil {
    70  		t.Fatal(err)
    71  	}
    72  	if len(msgs) != 2 {
    73  		t.Fatalf("number of messages: %d, expected 2", len(msgs))
    74  	}
    75  	if string(msgs[0]) != m1 {
    76  		t.Errorf("message 0: %q, expected %q", msgs[0], m1)
    77  	}
    78  	if string(msgs[1]) != m2 {
    79  		t.Errorf("message 1: %q, expected %q", msgs[1], m2)
    80  	}
    81  
    82  	// test calling receive before send
    83  	var wg sync.WaitGroup
    84  	wg.Add(1)
    85  	go func() {
    86  		defer wg.Done()
    87  		var merr error
    88  		// Very large timeout, for the benefit of CI, which may be slow
    89  		msgs, merr = kt.get(mr, 3, 10*time.Second)
    90  		if merr != nil {
    91  			t.Errorf("receive error: %s", merr)
    92  		}
    93  	}()
    94  
    95  	time.Sleep(3 * time.Millisecond)
    96  	if err := kt.post(mr, []byte(m3)); err != nil {
    97  		t.Fatal(err)
    98  	}
    99  
   100  	wg.Wait()
   101  	if len(msgs) != 1 {
   102  		t.Fatalf("number of messages: %d, expected 1", len(msgs))
   103  	}
   104  	if string(msgs[0]) != m3 {
   105  		t.Errorf("message: %q, expected %q", msgs[0], m3)
   106  		t.Errorf("Full message vector was: %v\n", msgs)
   107  	}
   108  
   109  	// test no messages ready
   110  	msgs, err = kt.get(mr, 4, 1*time.Millisecond)
   111  	if err != nil {
   112  		t.Fatal(err)
   113  	}
   114  	if len(msgs) != 0 {
   115  		t.Errorf("number of messages: %d, expected 0", len(msgs))
   116  	}
   117  }