github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/kex2/transport_test.go (about)

     1  // Copyright 2015 Keybase, Inc. All rights reserved. Use of
     2  // this source code is governed by the included BSD license.
     3  
     4  package kex2
     5  
     6  import (
     7  	"bytes"
     8  	"crypto/rand"
     9  	"io"
    10  	"net"
    11  	"runtime"
    12  	"strings"
    13  	"sync"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/stretchr/testify/require"
    18  	"golang.org/x/net/context"
    19  )
    20  
    21  type message struct {
    22  	seqno Seqno
    23  	msg   []byte
    24  }
    25  
    26  type simplexSession struct {
    27  	ch chan message
    28  }
    29  
    30  var zeroDeviceID DeviceID
    31  
    32  func (d DeviceID) isZero() bool {
    33  	return d.Eq(zeroDeviceID)
    34  }
    35  
    36  func newSimplexSession() *simplexSession {
    37  	return &simplexSession{
    38  		ch: make(chan message, 100),
    39  	}
    40  }
    41  
    42  type session struct {
    43  	id              SessionID
    44  	devices         [2]DeviceID
    45  	simplexSessions [2](*simplexSession)
    46  }
    47  
    48  func newSession(i SessionID) *session {
    49  	sess := &session{id: i}
    50  	for j := 0; j < 2; j++ {
    51  		sess.simplexSessions[j] = newSimplexSession()
    52  	}
    53  	return sess
    54  }
    55  
    56  func (s *session) getDeviceNumber(d DeviceID) int {
    57  	if s.devices[0].Eq(d) {
    58  		return 0
    59  	}
    60  	if s.devices[0].isZero() {
    61  		s.devices[0] = d
    62  		return 0
    63  	}
    64  	s.devices[1] = d
    65  	return 1
    66  }
    67  
    68  type mockRouter struct {
    69  	behavior int
    70  	maxPoll  time.Duration
    71  
    72  	sessionMutex sync.Mutex
    73  	sessions     map[SessionID]*session
    74  }
    75  
    76  const (
    77  	GoodRouter                   = 0
    78  	BadRouterCorruptedSession    = 1 << iota
    79  	BadRouterCorruptedSender     = 1 << iota
    80  	BadRouterCorruptedCiphertext = 1 << iota
    81  	BadRouterReorder             = 1 << iota
    82  	BadRouterDrop                = 1 << iota
    83  )
    84  
    85  func corruptMessage(behavior int, msg []byte) {
    86  	if (behavior & BadRouterCorruptedSession) != 0 {
    87  		msg[23] ^= 0x80
    88  	}
    89  	if (behavior & BadRouterCorruptedSender) != 0 {
    90  		msg[10] ^= 0x40
    91  	}
    92  	if (behavior & BadRouterCorruptedCiphertext) != 0 {
    93  		msg[len(msg)-10] ^= 0x01
    94  	}
    95  }
    96  
    97  func newMockRouterWithBehavior(b int) *mockRouter {
    98  	return &mockRouter{
    99  		behavior: b,
   100  		sessions: make(map[SessionID]*session),
   101  	}
   102  }
   103  
   104  func newMockRouterWithBehaviorAndMaxPoll(b int, mp time.Duration) *mockRouter {
   105  	return &mockRouter{
   106  		behavior: b,
   107  		maxPoll:  mp,
   108  		sessions: make(map[SessionID]*session),
   109  	}
   110  }
   111  
   112  func (ss *simplexSession) post(seqno Seqno, msg []byte) error {
   113  	ss.ch <- message{seqno, msg}
   114  	return nil
   115  }
   116  
   117  type lookupType int
   118  
   119  const (
   120  	bySender   lookupType = 0
   121  	byReceiver lookupType = 1
   122  )
   123  
   124  func (s *session) findOrMakeSimplexSession(sender DeviceID, lt lookupType) *simplexSession {
   125  	i := s.getDeviceNumber(sender)
   126  	if lt == byReceiver {
   127  		i = 1 - i
   128  	}
   129  	return s.simplexSessions[i]
   130  }
   131  
   132  func (mr *mockRouter) findOrMakeSimplexSession(i SessionID, sender DeviceID, lt lookupType) *simplexSession {
   133  	mr.sessionMutex.Lock()
   134  	defer mr.sessionMutex.Unlock()
   135  
   136  	sess, ok := mr.sessions[i]
   137  	if !ok {
   138  		sess = newSession(i)
   139  		mr.sessions[i] = sess
   140  	}
   141  	return sess.findOrMakeSimplexSession(sender, lt)
   142  }
   143  
   144  func (mr *mockRouter) Post(i SessionID, sender DeviceID, seqno Seqno, msg []byte) error {
   145  	ss := mr.findOrMakeSimplexSession(i, sender, bySender)
   146  	corruptMessage(mr.behavior, msg)
   147  	return ss.post(seqno, msg)
   148  }
   149  
   150  func (ss *simplexSession) get(seqno Seqno, poll time.Duration, behavior int) (ret [][]byte, err error) {
   151  	timeout := false
   152  	handleMessage := func(msg message) {
   153  		ret = append(ret, msg.msg)
   154  	}
   155  	if poll.Nanoseconds() > 0 {
   156  		select {
   157  		case msg := <-ss.ch:
   158  			handleMessage(msg)
   159  		case <-time.After(poll):
   160  			timeout = true
   161  		}
   162  	}
   163  	if !timeout {
   164  	loopMessages:
   165  		for {
   166  			select {
   167  			case msg := <-ss.ch:
   168  				handleMessage(msg)
   169  			default:
   170  				break loopMessages
   171  			}
   172  		}
   173  	}
   174  
   175  	if (behavior&BadRouterReorder) != 0 && len(ret) > 1 {
   176  		ret[0], ret[1] = ret[1], ret[0]
   177  	}
   178  	if (behavior&BadRouterDrop) != 0 && len(ret) > 1 {
   179  		ret = ret[1:]
   180  	}
   181  
   182  	return ret, err
   183  }
   184  
   185  func (mr *mockRouter) Get(i SessionID, receiver DeviceID, seqno Seqno, poll time.Duration) ([][]byte, error) {
   186  	ss := mr.findOrMakeSimplexSession(i, receiver, byReceiver)
   187  	if mr.maxPoll > time.Duration(0) && poll > mr.maxPoll {
   188  		poll = mr.maxPoll
   189  	}
   190  	return ss.get(seqno, poll, mr.behavior)
   191  }
   192  
   193  func genSecret(t *testing.T) (ret Secret) {
   194  	_, err := rand.Read(ret[:])
   195  	if err != nil {
   196  		t.Fatal(err)
   197  	}
   198  	return ret
   199  }
   200  
   201  func genDeviceID(t *testing.T) (ret DeviceID) {
   202  	_, err := rand.Read(ret[:])
   203  	if err != nil {
   204  		t.Fatal(err)
   205  	}
   206  	return ret
   207  }
   208  
   209  type testLogCtx struct {
   210  	sync.Mutex
   211  	t *testing.T
   212  }
   213  
   214  func newTestLogCtx(t *testing.T) (ret *testLogCtx, closer func()) {
   215  	ret = &testLogCtx{t: t}
   216  	closer = func() {
   217  		ret.Lock()
   218  		defer ret.Unlock()
   219  		ret.t = nil
   220  	}
   221  	return ret, closer
   222  }
   223  
   224  func (t *testLogCtx) Debug(format string, args ...interface{}) {
   225  	t.Lock()
   226  	if t.t != nil {
   227  		t.t.Logf(format, args...)
   228  	}
   229  	t.Unlock()
   230  }
   231  
   232  func genNewConn(t *testLogCtx, mr MessageRouter, s Secret, d DeviceID, rt time.Duration) net.Conn {
   233  	ret, err := NewConn(context.TODO(), t, mr, s, d, rt)
   234  	if err != nil {
   235  		t.t.Fatal(err)
   236  	}
   237  	return ret
   238  }
   239  
   240  func genConnPair(t *testLogCtx, behavior int, readTimeout time.Duration) (c1 net.Conn, c2 net.Conn, d1 DeviceID, d2 DeviceID) {
   241  	r := newMockRouterWithBehavior(behavior)
   242  	s := genSecret(t.t)
   243  	d1 = genDeviceID(t.t)
   244  	d2 = genDeviceID(t.t)
   245  	c1 = genNewConn(t, r, s, d1, readTimeout)
   246  	c2 = genNewConn(t, r, s, d2, readTimeout)
   247  	return
   248  }
   249  
   250  func maybeDisableTest(t *testing.T) {
   251  	if runtime.GOOS == "windows" {
   252  		t.Skip()
   253  	}
   254  }
   255  
   256  func TestHello(t *testing.T) {
   257  	testLogCtx, cleanup := newTestLogCtx(t)
   258  	defer cleanup()
   259  	c1, c2, _, _ := genConnPair(testLogCtx, GoodRouter, time.Duration(0))
   260  	txt := []byte("hello friend")
   261  	if _, err := c1.Write(txt); err != nil {
   262  		t.Fatal(err)
   263  	}
   264  	buf := make([]byte, 100)
   265  	if n, err := c2.Read(buf); err != nil {
   266  		t.Fatal(err)
   267  	} else if n != len(txt) {
   268  		t.Fatal("bad read len")
   269  	} else if !bytes.Equal(buf[0:n], txt) {
   270  		t.Fatal("wrong message back")
   271  	}
   272  	txt2 := []byte("pong PONG pong PONG pong PONG")
   273  	if _, err := c2.Write(txt2); err != nil {
   274  		t.Fatal(err)
   275  	} else if n, err := c1.Read(buf); err != nil {
   276  		t.Fatal(err)
   277  	} else if n != len(txt2) {
   278  		t.Fatal("bad read len")
   279  	} else if !bytes.Equal(buf[0:n], txt2) {
   280  		t.Fatal("wrong ponged text")
   281  	}
   282  }
   283  
   284  func TestBadMetadata(t *testing.T) {
   285  	testLogCtx, cleanup := newTestLogCtx(t)
   286  	defer cleanup()
   287  
   288  	testBehavior := func(b int, wanted error) {
   289  		c1, c2, _, _ := genConnPair(testLogCtx, b, time.Duration(0))
   290  		txt := []byte("hello friend")
   291  		if _, err := c1.Write(txt); err != nil {
   292  			t.Fatal(err)
   293  		}
   294  		buf := make([]byte, 100)
   295  		if _, err := c2.Read(buf); err == nil {
   296  			t.Fatalf("behavior %d: wanted an error, didn't get one", b)
   297  		} else if err != wanted {
   298  			t.Fatalf("behavior %d: wanted error '%v', got '%v'", b, err, wanted)
   299  		}
   300  	}
   301  	testBehavior(BadRouterCorruptedSession, ErrBadMetadata)
   302  	testBehavior(BadRouterCorruptedSender, ErrBadMetadata)
   303  	testBehavior(BadRouterCorruptedCiphertext, ErrDecryption)
   304  }
   305  
   306  func TestReadDeadline(t *testing.T) {
   307  	testLogCtx, cleanup := newTestLogCtx(t)
   308  	defer cleanup()
   309  	c1, c2, _, _ := genConnPair(testLogCtx, GoodRouter, time.Duration(0))
   310  	wait := time.Duration(10) * time.Millisecond
   311  	err := c2.SetReadDeadline(time.Now().Add(wait))
   312  	require.NoError(t, err)
   313  	go func() {
   314  		time.Sleep(wait * 2)
   315  		_, _ = c1.Write([]byte("hello friend"))
   316  	}()
   317  	buf := make([]byte, 100)
   318  	_, err = c2.Read(buf)
   319  	if err != ErrTimedOut {
   320  		t.Fatalf("wanted a read timeout")
   321  	}
   322  }
   323  
   324  func TestReadTimeout(t *testing.T) {
   325  	testLogCtx, cleanup := newTestLogCtx(t)
   326  	defer cleanup()
   327  	wait := time.Duration(10) * time.Millisecond
   328  	c1, c2, _, _ := genConnPair(testLogCtx, GoodRouter, wait)
   329  	go func() {
   330  		time.Sleep(wait * 2)
   331  		_, _ = c1.Write([]byte("hello friend"))
   332  	}()
   333  	buf := make([]byte, 100)
   334  	_, err := c2.Read(buf)
   335  	if err != ErrTimedOut {
   336  		t.Fatalf("wanted a read timeout")
   337  	}
   338  }
   339  
   340  func TestReadDelayedWrite(t *testing.T) {
   341  	maybeDisableTest(t)
   342  	testLogCtx, cleanup := newTestLogCtx(t)
   343  	defer cleanup()
   344  	c1, c2, _, _ := genConnPair(testLogCtx, GoodRouter, time.Duration(0))
   345  	wait := time.Duration(50) * time.Millisecond
   346  	err := c2.SetReadDeadline(time.Now().Add(wait))
   347  	require.NoError(t, err)
   348  	text := "hello friend"
   349  	go func() {
   350  		time.Sleep(wait / 32)
   351  		_, _ = c1.Write([]byte(text))
   352  	}()
   353  	buf := make([]byte, 100)
   354  	n, err := c2.Read(buf)
   355  	if err != nil {
   356  		t.Fatal(err)
   357  	}
   358  	if n != len(text) {
   359  		t.Fatalf("wrong read length")
   360  	}
   361  }
   362  
   363  func TestMultipleWritesOneRead(t *testing.T) {
   364  	maybeDisableTest(t)
   365  	testLogCtx, cleanup := newTestLogCtx(t)
   366  	defer cleanup()
   367  	c1, c2, _, _ := genConnPair(testLogCtx, GoodRouter, time.Duration(0))
   368  	msgs := []string{
   369  		"Alas, poor Yorick! I knew him, Horatio: a fellow",
   370  		"of infinite jest, of most excellent fancy: he hath",
   371  		"borne me on his back a thousand times; and now, how",
   372  		"abhorred in my imagination it is! my gorge rims at",
   373  		"it.",
   374  	}
   375  	for i, m := range msgs {
   376  		if i > 0 {
   377  			m = "\n" + m
   378  		}
   379  		if _, err := c1.Write([]byte(m)); err != nil {
   380  			t.Fatal(err)
   381  		}
   382  	}
   383  	buf := make([]byte, 1000)
   384  	if n, err := c2.Read(buf); err != nil {
   385  		t.Fatal(err)
   386  	} else if strings.Join(msgs, "\n") != string(buf[0:n]) {
   387  		t.Fatal("string mismatch")
   388  	}
   389  }
   390  
   391  func TestOneWriteMultipleReads(t *testing.T) {
   392  	testLogCtx, cleanup := newTestLogCtx(t)
   393  	defer cleanup()
   394  	c1, c2, _, _ := genConnPair(testLogCtx, GoodRouter, time.Duration(0))
   395  	msg := `Crows maunder on the petrified fairway.
   396  Absence! My heart grows tense
   397  as though a harpoon were sparring for the kill.`
   398  	if _, err := c1.Write([]byte(msg)); err != nil {
   399  		return
   400  	}
   401  	small := make([]byte, 3)
   402  	var buf []byte
   403  	for {
   404  		if n, err := c2.Read(small); err != nil && err != ErrAgain {
   405  			t.Fatal(err)
   406  		} else if n == 0 {
   407  			if err != ErrAgain {
   408  				t.Fatalf("exepcted ErrAgain if we read 0 bytes, but got %v", err)
   409  			}
   410  			break
   411  		} else {
   412  			buf = append(buf, small[0:n]...)
   413  		}
   414  	}
   415  	if string(buf) != msg {
   416  		t.Fatal("message mismatch")
   417  	}
   418  }
   419  
   420  func TestReorder(t *testing.T) {
   421  	testLogCtx, cleanup := newTestLogCtx(t)
   422  	defer cleanup()
   423  	c1, c2, _, _ := genConnPair(testLogCtx, BadRouterReorder, time.Duration(0))
   424  	msgs := []string{
   425  		"Alas, poor Yorick! I knew him, Horatio: a fellow",
   426  		"of infinite jest, of most excellent fancy: he hath",
   427  		"borne me on his back a thousand times; and now, how",
   428  		"abhorred in my imagination it is! my gorge rims at",
   429  		"it.",
   430  	}
   431  	for i, m := range msgs {
   432  		if i > 0 {
   433  			m = "\n" + m
   434  		}
   435  		if _, err := c1.Write([]byte(m)); err != nil {
   436  			t.Fatal(err)
   437  		}
   438  	}
   439  	buf := make([]byte, 1000)
   440  	_, err := c2.Read(buf)
   441  	if _, ok := err.(ErrBadPacketSequence); !ok {
   442  		t.Fatalf("expected an ErrBadPacketSequence; got %v", err)
   443  	}
   444  }
   445  
   446  func TestDrop(t *testing.T) {
   447  	testLogCtx, cleanup := newTestLogCtx(t)
   448  	defer cleanup()
   449  	c1, c2, _, _ := genConnPair(testLogCtx, BadRouterDrop, time.Duration(0))
   450  	msgs := []string{
   451  		"Alas, poor Yorick! I knew him, Horatio: a fellow",
   452  		"of infinite jest, of most excellent fancy: he hath",
   453  		"borne me on his back a thousand times; and now, how",
   454  		"abhorred in my imagination it is! my gorge rims at",
   455  		"it.",
   456  	}
   457  	for i, m := range msgs {
   458  		if i > 0 {
   459  			m = "\n" + m
   460  		}
   461  		if _, err := c1.Write([]byte(m)); err != nil {
   462  			t.Fatal(err)
   463  		}
   464  	}
   465  	buf := make([]byte, 1000)
   466  	_, err := c2.Read(buf)
   467  	if _, ok := err.(ErrBadPacketSequence); !ok {
   468  		t.Fatalf("expected an ErrBadPacketSequence; got %v", err)
   469  	}
   470  }
   471  
   472  func TestClose(t *testing.T) {
   473  	testLogCtx, cleanup := newTestLogCtx(t)
   474  	defer cleanup()
   475  	c1, c2, _, _ := genConnPair(testLogCtx, GoodRouter, time.Duration(4)*time.Second)
   476  	msg := "Hello friend. I'm going to mic drop."
   477  	if _, err := c1.Write([]byte(msg)); err != nil {
   478  		t.Fatal(err)
   479  	}
   480  	if err := c1.Close(); err != nil {
   481  		t.Fatal(err)
   482  	}
   483  	buf := make([]byte, 1000)
   484  	if n, err := c2.Read(buf); err != nil {
   485  		t.Fatal(err)
   486  	} else if n != len(msg) {
   487  		t.Fatalf("short read: %d v %d: %v", n, len(msg), msg)
   488  	} else if string(buf[0:n]) != msg {
   489  		t.Fatal("wrong msg")
   490  	}
   491  
   492  	// Assert we get an EOF now and forever...
   493  	for i := 0; i < 3; i++ {
   494  		if n, err := c2.Read(buf); err != io.EOF {
   495  			t.Fatalf("expected EOF, but got err = %v", err)
   496  		} else if n != 0 {
   497  			t.Fatalf("Expected 0-length read, but got %d", n)
   498  		}
   499  	}
   500  }
   501  
   502  func TestErrAgain(t *testing.T) {
   503  	testLogCtx, cleanup := newTestLogCtx(t)
   504  	defer cleanup()
   505  	_, c2, _, _ := genConnPair(testLogCtx, GoodRouter, time.Duration(0))
   506  	buf := make([]byte, 100)
   507  	if n, err := c2.Read(buf); err != ErrAgain {
   508  		t.Fatalf("wanted ErrAgain, but got err = %v", err)
   509  	} else if n != 0 {
   510  		t.Fatalf("Wanted 0 bytes back; got %d", n)
   511  	}
   512  }
   513  
   514  func TestPollLoopSuccess(t *testing.T) {
   515  	maybeDisableTest(t)
   516  
   517  	testLogCtx, cleanup := newTestLogCtx(t)
   518  	defer cleanup()
   519  
   520  	wait := time.Duration(100) * time.Millisecond
   521  	r := newMockRouterWithBehaviorAndMaxPoll(GoodRouter, wait/128)
   522  	s := genSecret(t)
   523  	d1 := genDeviceID(t)
   524  	d2 := genDeviceID(t)
   525  	c1 := genNewConn(testLogCtx, r, s, d1, wait)
   526  	c2 := genNewConn(testLogCtx, r, s, d2, wait)
   527  
   528  	text := "poll for this, will you?"
   529  
   530  	go func() {
   531  		time.Sleep(wait / 32)
   532  		_, _ = c1.Write([]byte(text))
   533  	}()
   534  	buf := make([]byte, 100)
   535  	n, err := c2.Read(buf)
   536  	if err != nil {
   537  		t.Fatal(err)
   538  	}
   539  	if n != len(text) {
   540  		t.Fatalf("wrong read length")
   541  	}
   542  }
   543  
   544  func TestPollLoopTimeout(t *testing.T) {
   545  	maybeDisableTest(t)
   546  
   547  	testLogCtx, cleanup := newTestLogCtx(t)
   548  	defer cleanup()
   549  
   550  	wait := time.Duration(8) * time.Millisecond
   551  	r := newMockRouterWithBehaviorAndMaxPoll(GoodRouter, wait/32)
   552  	s := genSecret(t)
   553  	d1 := genDeviceID(t)
   554  	d2 := genDeviceID(t)
   555  	c1 := genNewConn(testLogCtx, r, s, d1, wait)
   556  	c2 := genNewConn(testLogCtx, r, s, d2, wait)
   557  
   558  	text := "poll for this, will you?"
   559  
   560  	go func() {
   561  		time.Sleep(wait * 2)
   562  		_, _ = c1.Write([]byte(text))
   563  	}()
   564  	buf := make([]byte, 100)
   565  	if _, err := c2.Read(buf); err != ErrTimedOut {
   566  		t.Fatalf("Wanted ErrTimedOut; got %v", err)
   567  	}
   568  }