github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/kex2/rpc_test.go (about)

     1  // Copyright 2015 Keybase, Inc. All rights reserved. Use of
     2  // this source code is governed by the included BSD license.
     3  
     4  package kex2
     5  
     6  import (
     7  	"crypto/rand"
     8  	"encoding/hex"
     9  	"errors"
    10  	"io"
    11  	"strings"
    12  	"testing"
    13  	"time"
    14  
    15  	keybase1 "github.com/keybase/client/go/protocol/keybase1"
    16  	"github.com/keybase/go-framed-msgpack-rpc/rpc"
    17  	"golang.org/x/net/context"
    18  )
    19  
    20  const (
    21  	GoodProvisionee                  = 0
    22  	BadProvisioneeFailHello          = 1 << iota
    23  	BadProvisioneeFailDidCounterSign = 1 << iota
    24  	BadProvisioneeSlowHello          = 1 << iota
    25  	BadProvisioneeSlowDidCounterSign = 1 << iota
    26  	BadProvisioneeCancel             = 1 << iota
    27  )
    28  
    29  type mockProvisioner struct {
    30  	uid keybase1.UID
    31  }
    32  
    33  type mockProvisionee struct {
    34  	behavior int
    35  }
    36  
    37  func newMockProvisioner(t *testing.T) *mockProvisioner {
    38  	return &mockProvisioner{
    39  		uid: genUID(t),
    40  	}
    41  }
    42  
    43  type nullLogOutput struct {
    44  }
    45  
    46  func (n *nullLogOutput) Error(s string, args ...interface{})   {}
    47  func (n *nullLogOutput) Warning(s string, args ...interface{}) {}
    48  func (n *nullLogOutput) Info(s string, args ...interface{})    {}
    49  func (n *nullLogOutput) Debug(s string, args ...interface{})   {}
    50  func (n *nullLogOutput) Profile(s string, args ...interface{}) {}
    51  
    52  var _ rpc.LogOutput = (*nullLogOutput)(nil)
    53  
    54  func makeLogFactory() rpc.LogFactory {
    55  	if testing.Verbose() {
    56  		return nil
    57  	}
    58  	return rpc.NewSimpleLogFactory(&nullLogOutput{}, nil)
    59  }
    60  
    61  func genUID(t *testing.T) keybase1.UID {
    62  	uid := make([]byte, 8)
    63  	if _, err := rand.Read(uid); err != nil {
    64  		t.Fatalf("rand failed: %v\n", err)
    65  	}
    66  	return keybase1.UID(hex.EncodeToString(uid))
    67  }
    68  
    69  func genKeybase1DeviceID(t *testing.T) keybase1.DeviceID {
    70  	did := make([]byte, 16)
    71  	if _, err := rand.Read(did); err != nil {
    72  		t.Fatalf("rand failed: %v\n", err)
    73  	}
    74  	return keybase1.DeviceID(hex.EncodeToString(did))
    75  }
    76  
    77  func newMockProvisionee(t *testing.T, behavior int) *mockProvisionee {
    78  	return &mockProvisionee{behavior}
    79  }
    80  
    81  func (mp *mockProvisioner) GetLogFactory() rpc.LogFactory {
    82  	return makeLogFactory()
    83  }
    84  
    85  func (mp *mockProvisioner) GetNetworkInstrumenter() rpc.NetworkInstrumenterStorage {
    86  	return &rpc.DummyInstrumentationStorage{}
    87  }
    88  
    89  func (mp *mockProvisioner) CounterSign(input keybase1.HelloRes) (output []byte, err error) {
    90  	output = []byte(string(input))
    91  	return
    92  }
    93  
    94  func (mp *mockProvisioner) CounterSign2(input keybase1.Hello2Res) (output keybase1.DidCounterSign2Arg, err error) {
    95  	output.Sig, err = mp.CounterSign(input.SigPayload)
    96  	return
    97  }
    98  
    99  func (mp *mockProvisioner) GetHelloArg() (res keybase1.HelloArg, err error) {
   100  	res.Uid = mp.uid
   101  	return res, err
   102  }
   103  func (mp *mockProvisioner) GetHello2Arg() (res keybase1.Hello2Arg, err error) {
   104  	res.Uid = mp.uid
   105  	return res, err
   106  }
   107  
   108  func (mp *mockProvisionee) GetLogFactory() rpc.LogFactory {
   109  	return makeLogFactory()
   110  }
   111  
   112  func (mp *mockProvisionee) GetNetworkInstrumenter() rpc.NetworkInstrumenterStorage {
   113  	return &rpc.DummyInstrumentationStorage{}
   114  }
   115  
   116  var ErrHandleHello = errors.New("handle hello failure")
   117  var ErrHandleDidCounterSign = errors.New("handle didCounterSign failure")
   118  var testTimeout = time.Duration(500) * time.Millisecond
   119  
   120  func (mp *mockProvisionee) HandleHello2(ctx context.Context, arg2 keybase1.Hello2Arg) (res keybase1.Hello2Res, err error) {
   121  	arg1 := keybase1.HelloArg{
   122  		Uid:     arg2.Uid,
   123  		SigBody: arg2.SigBody,
   124  	}
   125  	res.SigPayload, err = mp.HandleHello(ctx, arg1)
   126  	return res, err
   127  }
   128  
   129  func (mp *mockProvisionee) HandleHello(_ context.Context, arg keybase1.HelloArg) (res keybase1.HelloRes, err error) {
   130  	if (mp.behavior & BadProvisioneeSlowHello) != 0 {
   131  		time.Sleep(testTimeout * 8)
   132  	}
   133  	if (mp.behavior & BadProvisioneeFailHello) != 0 {
   134  		err = ErrHandleHello
   135  		return
   136  	}
   137  	res = keybase1.HelloRes(arg.SigBody)
   138  	return
   139  }
   140  
   141  func (mp *mockProvisionee) HandleDidCounterSign(_ context.Context, _ []byte) error {
   142  	if (mp.behavior & BadProvisioneeSlowDidCounterSign) != 0 {
   143  		time.Sleep(testTimeout * 8)
   144  	}
   145  	if (mp.behavior & BadProvisioneeFailDidCounterSign) != 0 {
   146  		return ErrHandleDidCounterSign
   147  	}
   148  	return nil
   149  }
   150  
   151  func (mp *mockProvisionee) HandleDidCounterSign2(ctx context.Context, arg keybase1.DidCounterSign2Arg) error {
   152  	return mp.HandleDidCounterSign(ctx, arg.Sig)
   153  }
   154  
   155  func testProtocolXWithBehavior(t *testing.T, provisioneeBehavior int) (results [2]error) {
   156  
   157  	timeout := testTimeout
   158  	router := newMockRouterWithBehaviorAndMaxPoll(GoodRouter, timeout)
   159  
   160  	s2 := genSecret(t)
   161  
   162  	ch := make(chan error, 3)
   163  
   164  	secretCh := make(chan Secret)
   165  
   166  	ctx, cancelFn := context.WithCancel(context.Background())
   167  
   168  	testLogCtx, cleanup := newTestLogCtx(t)
   169  	defer cleanup()
   170  
   171  	// Run the provisioner
   172  	go func() {
   173  		err := RunProvisioner(ProvisionerArg{
   174  			KexBaseArg: KexBaseArg{
   175  				Ctx:           ctx,
   176  				LogCtx:        testLogCtx,
   177  				Mr:            router,
   178  				Secret:        genSecret(t),
   179  				DeviceID:      genKeybase1DeviceID(t),
   180  				SecretChannel: secretCh,
   181  				Timeout:       timeout,
   182  			},
   183  			Provisioner: newMockProvisioner(t),
   184  		})
   185  		ch <- err
   186  	}()
   187  
   188  	// Run the privisionee
   189  	go func() {
   190  		err := RunProvisionee(ProvisioneeArg{
   191  			KexBaseArg: KexBaseArg{
   192  				Ctx:           context.Background(),
   193  				LogCtx:        testLogCtx,
   194  				Mr:            router,
   195  				Secret:        s2,
   196  				DeviceID:      genKeybase1DeviceID(t),
   197  				SecretChannel: make(chan Secret),
   198  				Timeout:       timeout,
   199  			},
   200  			Provisionee: newMockProvisionee(t, provisioneeBehavior),
   201  		})
   202  		ch <- err
   203  	}()
   204  
   205  	if (provisioneeBehavior & BadProvisioneeCancel) != 0 {
   206  		go func() {
   207  			time.Sleep(testTimeout / 20)
   208  			cancelFn()
   209  		}()
   210  	}
   211  
   212  	secretCh <- s2
   213  
   214  	for i := 0; i < 2; i++ {
   215  		if e, eof := <-ch; !eof {
   216  			t.Fatalf("got unexpected channel close (try %d)", i)
   217  		} else if e != nil {
   218  			results[i] = e
   219  		}
   220  	}
   221  
   222  	return results
   223  }
   224  
   225  func TestFullProtocolXSuccess(t *testing.T) {
   226  	results := testProtocolXWithBehavior(t, GoodProvisionee)
   227  	for i, e := range results {
   228  		if e != nil {
   229  			t.Fatalf("Bad error %d: %v", i, e)
   230  		}
   231  	}
   232  }
   233  
   234  // Since errors are exported as strings, then we should just test that the
   235  // right kind of error was specified
   236  func eeq(e1, e2 error) bool {
   237  	return e1 != nil && e1.Error() == e2.Error()
   238  }
   239  
   240  // errHasSuffix makes sure that err's string has errSuffix's string as
   241  // a suffix. This is necessary as go-codec prepends stuff to any
   242  // errors it catches.
   243  func errHasSuffix(err, errSuffix error) bool {
   244  	return err != nil && strings.HasSuffix(err.Error(), errSuffix.Error())
   245  }
   246  
   247  func TestFullProtocolXProvisioneeFailHello(t *testing.T) {
   248  	results := testProtocolXWithBehavior(t, BadProvisioneeFailHello)
   249  	if !eeq(results[0], ErrHandleHello) {
   250  		t.Fatalf("Bad error 0: %v", results[0])
   251  	}
   252  	if !eeq(results[1], ErrHandleHello) {
   253  		t.Fatalf("Bad error 1: %v", results[1])
   254  	}
   255  }
   256  
   257  func TestFullProtocolXProvisioneeFailDidCounterSign(t *testing.T) {
   258  	results := testProtocolXWithBehavior(t, BadProvisioneeFailDidCounterSign)
   259  	if !eeq(results[0], ErrHandleDidCounterSign) {
   260  		t.Fatalf("Bad error 0: %v", results[0])
   261  	}
   262  	if !eeq(results[1], ErrHandleDidCounterSign) {
   263  		t.Fatalf("Bad error 1: %v", results[1])
   264  	}
   265  }
   266  
   267  func TestFullProtocolXProvisioneeSlowHello(t *testing.T) {
   268  	results := testProtocolXWithBehavior(t, BadProvisioneeSlowHello)
   269  	for i, e := range results {
   270  		if !errHasSuffix(e, ErrTimedOut) && !errHasSuffix(e, io.EOF) && !errHasSuffix(e, ErrHelloTimeout) {
   271  			t.Fatalf("Bad error %d: %v", i, e)
   272  		}
   273  	}
   274  }
   275  
   276  func TestFullProtocolXProvisioneeSlowHelloWithCancel(t *testing.T) {
   277  	results := testProtocolXWithBehavior(t, BadProvisioneeSlowHello|BadProvisioneeCancel)
   278  	for i, e := range results {
   279  		if !eeq(e, ErrCanceled) && !eeq(e, io.EOF) {
   280  			t.Fatalf("Bad error %d: %v", i, e)
   281  		}
   282  	}
   283  }
   284  
   285  func TestFullProtocolY(t *testing.T) {
   286  
   287  	timeout := time.Duration(60) * time.Second
   288  	router := newMockRouterWithBehaviorAndMaxPoll(GoodRouter, timeout)
   289  
   290  	s1 := genSecret(t)
   291  
   292  	ch := make(chan error, 3)
   293  
   294  	secretCh := make(chan Secret)
   295  	testLogCtx, cleanup := newTestLogCtx(t)
   296  	defer cleanup()
   297  
   298  	// Run the provisioner
   299  	go func() {
   300  		err := RunProvisioner(ProvisionerArg{
   301  			KexBaseArg: KexBaseArg{
   302  				Ctx:           context.TODO(),
   303  				LogCtx:        testLogCtx,
   304  				Mr:            router,
   305  				Secret:        s1,
   306  				DeviceID:      genKeybase1DeviceID(t),
   307  				SecretChannel: make(chan Secret),
   308  				Timeout:       timeout,
   309  			},
   310  			Provisioner: newMockProvisioner(t),
   311  		})
   312  		ch <- err
   313  	}()
   314  
   315  	// Run the provisionee
   316  	go func() {
   317  		err := RunProvisionee(ProvisioneeArg{
   318  			KexBaseArg: KexBaseArg{
   319  				Ctx:           context.TODO(),
   320  				LogCtx:        testLogCtx,
   321  				Mr:            router,
   322  				Secret:        genSecret(t),
   323  				DeviceID:      genKeybase1DeviceID(t),
   324  				SecretChannel: secretCh,
   325  				Timeout:       timeout,
   326  			},
   327  			Provisionee: newMockProvisionee(t, GoodProvisionee),
   328  		})
   329  		ch <- err
   330  	}()
   331  
   332  	secretCh <- s1
   333  
   334  	for i := 0; i < 2; i++ {
   335  		if e, eof := <-ch; !eof {
   336  			t.Fatalf("got unexpected channel close (try %d)", i)
   337  		} else if e != nil {
   338  			t.Fatalf("Unexpected error (receive %d): %v", i, e)
   339  		}
   340  	}
   341  
   342  }