github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/systests/device_common_test.go (about)

     1  package systests
     2  
     3  import (
     4  	"encoding/hex"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/keybase/client/go/client"
    12  	"github.com/keybase/client/go/libkb"
    13  	"github.com/keybase/client/go/logger"
    14  	keybase1 "github.com/keybase/client/go/protocol/keybase1"
    15  	"github.com/keybase/client/go/service"
    16  	"github.com/keybase/clockwork"
    17  	"github.com/keybase/go-framed-msgpack-rpc/rpc"
    18  	"github.com/stretchr/testify/require"
    19  	context "golang.org/x/net/context"
    20  )
    21  
    22  //
// device_common_test contains common utilities for testing multiple devices,
// for instance to test rekey reminders on device revoke.
    25  //
    26  
// testUI is the libkb.TerminalUI used for provisioning users and new devices
// within our testing framework. Terminal output (Write, Output, Printf) is
// sent to the debug log.
//
// OutputDesc delegates to outputDescHook when it is set. Prompt and
// PromptYesNo delegate to promptHook / promptYesNoHook and return an
// "unhandled ..." error when the corresponding hook is nil. sessionID is
// assigned from sessionCounter by newTestUI.
type testUI struct {
	libkb.Contextified
	baseNullUI
	sessionID       int
	outputDescHook  func(libkb.OutputDescriptor, string) error
	promptHook      func(libkb.PromptDescriptor, string) (string, error)
	promptYesNoHook func(libkb.PromptDescriptor, string, libkb.PromptDefault) (bool, error)
}
    36  
    37  var sessionCounter = 1
    38  
    39  func newTestUI(g *libkb.GlobalContext) *testUI {
    40  	x := sessionCounter
    41  	sessionCounter++
    42  	return &testUI{Contextified: libkb.NewContextified(g), sessionID: x}
    43  }
    44  
    45  func (t *testUI) GetTerminalUI() libkb.TerminalUI {
    46  	return t
    47  }
    48  
    49  func (t *testUI) Write(b []byte) (int, error) {
    50  	t.G().Log.Debug("Terminal write: %s", string(b))
    51  	return len(b), nil
    52  }
    53  
    54  func (t *testUI) ErrorWriter() io.Writer {
    55  	return t
    56  }
    57  
    58  func (t *testUI) Output(s string) error {
    59  	t.G().Log.Debug("Terminal Output: %s", s)
    60  	return nil
    61  }
    62  
    63  func (t *testUI) OutputDesc(d libkb.OutputDescriptor, s string) error {
    64  	if t.outputDescHook != nil {
    65  		return t.outputDescHook(d, s)
    66  	}
    67  	return t.Output(s)
    68  }
    69  
    70  func (t *testUI) OutputWriter() io.Writer {
    71  	return t
    72  }
    73  
    74  func (t *testUI) UnescapedOutputWriter() io.Writer {
    75  	return t
    76  }
    77  
    78  func (t *testUI) Printf(f string, args ...interface{}) (int, error) {
    79  	return t.PrintfUnescaped(f, args...)
    80  }
    81  
    82  func (t *testUI) PrintfUnescaped(f string, args ...interface{}) (int, error) {
    83  	s := fmt.Sprintf(f, args...)
    84  	t.G().Log.Debug("Terminal Printf: %s", s)
    85  	return len(s), nil
    86  }
    87  
    88  func (t *testUI) Prompt(d libkb.PromptDescriptor, s string) (string, error) {
    89  	if t.promptHook != nil {
    90  		return t.promptHook(d, s)
    91  	}
    92  	return "", fmt.Errorf("Unhandled prompt: %q (%d)", s, d)
    93  }
    94  
    95  func (t *testUI) PromptForConfirmation(p string) error {
    96  	return fmt.Errorf("unhandled prompt for confirmation: %q", p)
    97  }
    98  
    99  func (t *testUI) PromptPassword(d libkb.PromptDescriptor, s string) (string, error) {
   100  	return "", fmt.Errorf("unhandled prompt for password: %q (%d)", s, d)
   101  }
   102  
   103  func (t *testUI) PromptYesNo(d libkb.PromptDescriptor, s string, def libkb.PromptDefault) (bool, error) {
   104  	if t.promptYesNoHook != nil {
   105  		return t.promptYesNoHook(d, s, def)
   106  	}
   107  	return false, fmt.Errorf("unhandled yes/no prompt: %q (%d)", s, d)
   108  }
   109  
   110  func (t *testUI) Tablify(headings []string, rowfunc func() []string) {
   111  	libkb.Tablify(t.OutputWriter(), headings, rowfunc)
   112  }
   113  
   114  func (t *testUI) PromptPasswordMaybeScripted(_ libkb.PromptDescriptor, _ string) (string, error) {
   115  	return "", nil
   116  }
   117  
   118  func (t *testUI) TerminalSize() (width int, height int) {
   119  	return 80, 24
   120  }
   121  
// backupKey describes a paper (backup) key found for the test user.
//
// KID is the key's own KID. deviceID is the ID of the paper device that is
// its parent. secret is the paper key phrase, when known; signup fills it in
// from the phrase displayed during signup.
type backupKey struct {
	KID      keybase1.KID
	deviceID keybase1.DeviceID
	secret   string
}
   127  
// testDevice wraps a mock "device", meaning an independent running service and
// some connected clients. It's forked from deviceWrapper in rekey_test.
//
// The fields are used as follows:
//   - tctx is the context the service runs on.
//   - clones holds cloned contexts not yet handed out; usedClones holds
//     those already taken by popClone (kept so cleanup can release them).
//   - stopCh receives the value returned by the service's Run.
//   - cli, srv and userClient are set by startClient.
//   - deviceName is set at creation; deviceID and deviceKey.KID are filled
//     in by signup or provisioning.
type testDevice struct {
	t                  *testing.T
	tctx               *libkb.TestContext
	clones, usedClones []*libkb.TestContext
	stopCh             chan error
	service            *service.Service
	testUI             *testUI
	deviceID           keybase1.DeviceID
	deviceName         string
	deviceKey          keybase1.PublicKey
	cli                *rpc.Client
	srv                *rpc.Server
	userClient         keybase1.UserClient
}
   144  
// testDeviceSet is a collection of devices belonging to one test user.
//
// fakeClock is installed on every device created through newDevice. log is
// the logger of the first device created. backupKeys, username and uid are
// filled in when the user is signed up, and provision reads username and the
// first backup key.
type testDeviceSet struct {
	t          *testing.T
	log        logger.Logger
	devices    []*testDevice
	fakeClock  clockwork.FakeClock
	backupKeys []backupKey
	username   string
	uid        keybase1.UID
}
   154  
   155  func (d *testDevice) startService(numClones int) {
   156  	for i := 0; i < numClones; i++ {
   157  		d.clones = append(d.clones, cloneContext(d.tctx))
   158  	}
   159  	d.stopCh = make(chan error)
   160  	svc := service.NewService(d.tctx.G, false)
   161  	d.service = svc
   162  	startCh := svc.GetStartChannel()
   163  	go func() {
   164  		d.stopCh <- svc.Run()
   165  	}()
   166  	<-startCh
   167  }
   168  
   169  func (d *testDevice) KID() keybase1.KID { return d.deviceKey.KID }
   170  
   171  func (d *testDevice) startClient() {
   172  	tctx := d.popClone()
   173  	g := tctx.G
   174  	ui := newTestUI(g)
   175  	d.testUI = ui
   176  	launch := func() error {
   177  		cli, xp, err := client.GetRPCClientWithContext(g)
   178  		if err != nil {
   179  			return err
   180  		}
   181  		srv := rpc.NewServer(xp, nil)
   182  		d.cli = cli
   183  		d.srv = srv
   184  		d.userClient = keybase1.UserClient{Cli: cli}
   185  		return nil
   186  	}
   187  
   188  	if err := launch(); err != nil {
   189  		d.t.Fatalf("Failed to launch rekey UI: %s", err)
   190  	}
   191  }
   192  
   193  func (d *testDevice) start(numClones int) *testDevice {
   194  	d.startService(numClones)
   195  	d.startClient()
   196  	return d
   197  }
   198  
   199  func (d *testDevice) stop() error {
   200  	return <-d.stopCh
   201  }
   202  
   203  func (d *testDevice) popClone() *libkb.TestContext {
   204  	if len(d.clones) == 0 {
   205  		panic("ran out of cloned environments")
   206  	}
   207  	ret := d.clones[0]
   208  	// Hold a reference to this clone for cleanup
   209  	d.usedClones = append(d.usedClones, ret)
   210  	d.clones = d.clones[1:]
   211  	return ret
   212  }
   213  
   214  func newTestDeviceSet(t *testing.T, cl clockwork.FakeClock) *testDeviceSet {
   215  	if cl == nil {
   216  		cl = clockwork.NewFakeClockAt(time.Now())
   217  	}
   218  	return &testDeviceSet{
   219  		t:         t,
   220  		fakeClock: cl,
   221  	}
   222  }
   223  
   224  func (s *testDeviceSet) cleanup() {
   225  	for _, od := range s.devices {
   226  		od.tctx.Cleanup()
   227  		if od.service != nil {
   228  			od.service.Stop(0)
   229  			err := od.stop()
   230  			require.NoError(s.t, err)
   231  		}
   232  		for _, cl := range od.clones {
   233  			cl.Cleanup()
   234  		}
   235  		for _, cl := range od.usedClones {
   236  			cl.Cleanup()
   237  		}
   238  	}
   239  }
   240  
   241  func (s *testDeviceSet) newDevice(nm string) *testDevice {
   242  	tctx := setupTest(s.t, nm)
   243  	tctx.G.SetClock(s.fakeClock)
   244  
   245  	// Opportunistically take a log as soon as we make one.
   246  	if s.log == nil {
   247  		s.log = tctx.G.Log
   248  	}
   249  
   250  	installInsecureTriplesec(tctx.G)
   251  
   252  	ret := &testDevice{t: s.t, tctx: tctx, deviceName: nm}
   253  	s.devices = append(s.devices, ret)
   254  	return ret
   255  }
   256  
   257  func (d *testDevice) loadEncryptionKIDs() (devices []keybase1.KID, backups []backupKey) {
   258  	keyMap := make(map[keybase1.KID]keybase1.PublicKey)
   259  	keys, err := d.userClient.LoadMyPublicKeys(context.TODO(), 0)
   260  	if err != nil {
   261  		d.t.Fatalf("Failed to LoadMyPublicKeys: %s", err)
   262  	}
   263  	for _, key := range keys {
   264  		keyMap[key.KID] = key
   265  	}
   266  
   267  	for _, key := range keys {
   268  		if key.IsSibkey {
   269  			continue
   270  		}
   271  		parent, found := keyMap[keybase1.KID(key.ParentID)]
   272  		if !found {
   273  			continue
   274  		}
   275  
   276  		switch parent.DeviceType {
   277  		case keybase1.DeviceTypeV2_PAPER:
   278  			backups = append(backups, backupKey{KID: key.KID, deviceID: parent.DeviceID})
   279  		case keybase1.DeviceTypeV2_DESKTOP:
   280  			devices = append(devices, key.KID)
   281  		default:
   282  		}
   283  	}
   284  	return devices, backups
   285  }
   286  
   287  func (d *testDevice) loadDeviceList() []keybase1.Device {
   288  	cli := keybase1.DeviceClient{Cli: d.cli}
   289  	devices, err := cli.DeviceList(context.TODO(), 0)
   290  	if err != nil {
   291  		d.t.Fatalf("devices: %s", err)
   292  	}
   293  	var ret []keybase1.Device
   294  	for _, device := range devices {
   295  		if device.Type == keybase1.DeviceTypeV2_DESKTOP {
   296  			ret = append(ret, device)
   297  		}
   298  
   299  	}
   300  	return ret
   301  }
   302  
   303  func (s *testDeviceSet) signupUser(dev *testDevice) {
   304  	s.signupUserWithRandomPassphrase(dev, false)
   305  }
   306  
   307  func (s *testDeviceSet) signupUserWithRandomPassphrase(dev *testDevice, randomPassphrase bool) {
   308  	userInfo := randomUser("rekey")
   309  	tctx := dev.popClone()
   310  	g := tctx.G
   311  	signupUI := signupUI{
   312  		info:         userInfo,
   313  		Contextified: libkb.NewContextified(g),
   314  	}
   315  	g.SetUI(&signupUI)
   316  	signup := client.NewCmdSignupRunner(g)
   317  	signup.SetTest()
   318  	if randomPassphrase {
   319  		signup.SetNoPassphrasePrompt()
   320  	}
   321  	if err := signup.Run(); err != nil {
   322  		s.t.Fatal(err)
   323  	}
   324  	s.t.Logf("signed up %s", userInfo.username)
   325  	s.username = userInfo.username
   326  	s.uid = libkb.UsernameToUID(s.username)
   327  	var backupKey backupKey
   328  	deviceKeys, backups := dev.loadEncryptionKIDs()
   329  	if len(deviceKeys) != 1 {
   330  		s.t.Fatalf("Expected 1 device back; got %d", len(deviceKeys))
   331  	}
   332  	if len(backups) != 1 {
   333  		s.t.Fatalf("Expected 1 backup back; got %d", len(backups))
   334  	}
   335  	dev.deviceKey.KID = deviceKeys[0]
   336  	backupKey = backups[0]
   337  	backupKey.secret = signupUI.info.displayedPaperKey
   338  	s.backupKeys = append(s.backupKeys, backupKey)
   339  
   340  	devices := dev.loadDeviceList()
   341  	if len(devices) != 1 {
   342  		s.t.Fatalf("Expected 1 device back; got %d", len(devices))
   343  	}
   344  	dev.deviceID = devices[0].DeviceID
   345  }
   346  
// testProvisionUI scripts the login/provisioning flow for a new device:
//   - username is returned as the login name;
//   - deviceName is returned as the new device name;
//   - backupKey's paper device is chosen, and its secret is returned as the
//     passphrase.
//
// All other callbacks return fixed default answers with a nil error.
type testProvisionUI struct {
	baseNullUI
	username   string
	deviceName string
	backupKey  backupKey
}

// Compile-time check that *testProvisionUI implements libkb.LoginUI.
var _ libkb.LoginUI = (*testProvisionUI)(nil)
   355  
   356  func (r *testProvisionUI) GetEmailOrUsername(context.Context, int) (string, error) {
   357  	return r.username, nil
   358  }
   359  func (r *testProvisionUI) PromptRevokePaperKeys(context.Context, keybase1.PromptRevokePaperKeysArg) (ret bool, err error) {
   360  	return false, nil
   361  }
   362  func (r *testProvisionUI) DisplayPaperKeyPhrase(context.Context, keybase1.DisplayPaperKeyPhraseArg) error {
   363  	return nil
   364  }
   365  func (r *testProvisionUI) DisplayPrimaryPaperKey(context.Context, keybase1.DisplayPrimaryPaperKeyArg) error {
   366  	return nil
   367  }
   368  func (r *testProvisionUI) ChooseProvisioningMethod(context.Context, keybase1.ChooseProvisioningMethodArg) (ret keybase1.ProvisionMethod, err error) {
   369  	return ret, nil
   370  }
   371  func (r *testProvisionUI) ChooseGPGMethod(context.Context, keybase1.ChooseGPGMethodArg) (ret keybase1.GPGMethod, err error) {
   372  	return ret, nil
   373  }
   374  func (r *testProvisionUI) SwitchToGPGSignOK(context.Context, keybase1.SwitchToGPGSignOKArg) (ret bool, err error) {
   375  	return ret, nil
   376  }
   377  func (r *testProvisionUI) ChooseDeviceType(context.Context, keybase1.ChooseDeviceTypeArg) (ret keybase1.DeviceType, err error) {
   378  	return ret, nil
   379  }
   380  func (r *testProvisionUI) DisplayAndPromptSecret(context.Context, keybase1.DisplayAndPromptSecretArg) (ret keybase1.SecretResponse, err error) {
   381  	return ret, nil
   382  }
   383  func (r *testProvisionUI) DisplaySecretExchanged(context.Context, int) error {
   384  	return nil
   385  }
   386  func (r *testProvisionUI) PromptNewDeviceName(context.Context, keybase1.PromptNewDeviceNameArg) (ret string, err error) {
   387  	return r.deviceName, nil
   388  }
   389  func (r *testProvisionUI) ProvisioneeSuccess(context.Context, keybase1.ProvisioneeSuccessArg) error {
   390  	return nil
   391  }
   392  func (r *testProvisionUI) ProvisionerSuccess(context.Context, keybase1.ProvisionerSuccessArg) error {
   393  	return nil
   394  }
   395  func (r *testProvisionUI) ChooseDevice(context.Context, keybase1.ChooseDeviceArg) (ret keybase1.DeviceID, err error) {
   396  	return r.backupKey.deviceID, nil
   397  }
   398  func (r *testProvisionUI) GetPassphrase(context.Context, keybase1.GetPassphraseArg) (ret keybase1.GetPassphraseRes, err error) {
   399  	ret.Passphrase = r.backupKey.secret
   400  	return ret, nil
   401  }
   402  func (r *testProvisionUI) PromptResetAccount(_ context.Context, arg keybase1.PromptResetAccountArg) (keybase1.ResetPromptResponse, error) {
   403  	return keybase1.ResetPromptResponse_NOTHING, nil
   404  }
   405  func (r *testProvisionUI) DisplayResetProgress(_ context.Context, arg keybase1.DisplayResetProgressArg) error {
   406  	return nil
   407  }
   408  func (r *testProvisionUI) ExplainDeviceRecovery(_ context.Context, arg keybase1.ExplainDeviceRecoveryArg) error {
   409  	return nil
   410  }
   411  func (r *testProvisionUI) PromptPassphraseRecovery(_ context.Context, arg keybase1.PromptPassphraseRecoveryArg) (bool, error) {
   412  	return false, nil
   413  }
   414  func (r *testProvisionUI) ChooseDeviceToRecoverWith(_ context.Context, arg keybase1.ChooseDeviceToRecoverWithArg) (keybase1.DeviceID, error) {
   415  	return "", nil
   416  }
   417  func (r *testProvisionUI) DisplayResetMessage(_ context.Context, arg keybase1.DisplayResetMessageArg) error {
   418  	return nil
   419  }
   420  
   421  func (s *testDeviceSet) findNewKIDs(newList []keybase1.KID) []keybase1.KID {
   422  	var ret []keybase1.KID
   423  	for _, newKID := range newList {
   424  		tmpFound := false
   425  		for _, device := range s.devices {
   426  			if !device.KID().IsNil() && newKID.Equal(device.KID()) {
   427  				tmpFound = true
   428  				break
   429  			}
   430  		}
   431  		if !tmpFound {
   432  			ret = append(ret, newKID)
   433  		}
   434  	}
   435  	return ret
   436  }
   437  
   438  func (s *testDeviceSet) findNewDevices(newList []keybase1.Device) []keybase1.Device {
   439  	var ret []keybase1.Device
   440  	for _, newDevice := range newList {
   441  		tmpFound := false
   442  		for _, device := range s.devices {
   443  			if !device.deviceID.IsNil() && newDevice.DeviceID.Eq(device.deviceID) {
   444  				tmpFound = true
   445  				break
   446  			}
   447  		}
   448  		if !tmpFound {
   449  			ret = append(ret, newDevice)
   450  		}
   451  	}
   452  	return ret
   453  }
   454  
   455  func (s *testDeviceSet) provision(d *testDevice) {
   456  	tctx := d.popClone()
   457  	g := tctx.G
   458  	var loginClient keybase1.LoginClient
   459  	ui := &testProvisionUI{username: s.username, backupKey: s.backupKeys[0], deviceName: d.deviceName}
   460  
   461  	launch := func() error {
   462  		cli, xp, err := client.GetRPCClientWithContext(g)
   463  		if err != nil {
   464  			return err
   465  		}
   466  		srv := rpc.NewServer(xp, nil)
   467  		protocols := []rpc.Protocol{
   468  			keybase1.LoginUiProtocol(ui),
   469  			keybase1.SecretUiProtocol(ui),
   470  			keybase1.ProvisionUiProtocol(ui),
   471  		}
   472  		for _, prot := range protocols {
   473  			if err = srv.Register(prot); err != nil {
   474  				return err
   475  			}
   476  		}
   477  		loginClient = keybase1.LoginClient{Cli: cli}
   478  		_ = loginClient
   479  		return nil
   480  	}
   481  
   482  	if err := launch(); err != nil {
   483  		s.t.Fatalf("Failed to login rekey UI: %s", err)
   484  	}
   485  	cmd := client.NewCmdLoginRunner(g)
   486  	if err := cmd.Run(); err != nil {
   487  		s.t.Fatalf("Login failed: %s\n", err)
   488  	}
   489  
   490  	deviceKeys, backups := d.loadEncryptionKIDs()
   491  	deviceKeys = s.findNewKIDs(deviceKeys)
   492  	if len(deviceKeys) != 1 {
   493  		s.t.Fatalf("expected 1 new device encryption key")
   494  	}
   495  	d.deviceKey.KID = deviceKeys[0]
   496  	if len(backups) != 1 {
   497  		s.t.Fatalf("expected 1 backup key only")
   498  	}
   499  	devices := s.findNewDevices(d.loadDeviceList())
   500  	if len(devices) != 1 {
   501  		s.t.Fatalf("expected 1 device ID; got %d", len(devices))
   502  	}
   503  	d.deviceID = devices[0].DeviceID
   504  }
   505  
   506  func (s *testDeviceSet) provisionNewStandaloneDevice(name string, numClones int) *testDevice {
   507  	ret := s.newDevice(name)
   508  	_ = ret.tctx.G.Env.GetConfigWriter().SetBoolAtPath("push.disabled", true)
   509  	ret.start(numClones + 1)
   510  	s.provision(ret)
   511  	return ret
   512  }
   513  
   514  func (s *testDeviceSet) provisionNewDevice(name string, numClones int) *testDevice {
   515  	ret := s.newDevice(name)
   516  	ret.start(numClones + 1)
   517  	s.provision(ret)
   518  	return ret
   519  }
   520  
   521  func newTLFID() keybase1.TLFID {
   522  	var b []byte
   523  	b, err := libkb.RandBytes(16)
   524  	if err != nil {
   525  		return ""
   526  	}
   527  	b[15] = 0x16
   528  	return keybase1.TLFID(hex.EncodeToString(b))
   529  }
   530  
// fakeTLF is a test-only TLF with a random ID (see newTLFID). Each call to
// keyTLF increments revision before posting an update.
type fakeTLF struct {
	id       keybase1.TLFID
	revision int
}
   535  
   536  func newFakeTLF() *fakeTLF {
   537  	return &fakeTLF{
   538  		id:       newTLFID(),
   539  		revision: 0,
   540  	}
   541  }
   542  
// tlfUser is one writer or reader entry in a tlfUpdate: a user's UID and a
// list of their KIDs, serialized under "uid" and "encryptKeys".
type tlfUser struct {
	UID  keybase1.UID   `json:"uid"`
	Keys []keybase1.KID `json:"encryptKeys"`
}
   547  
// tlfUpdate is the JSON document keyTLF posts as "tlf_info" to the
// test/fake_generic_tlf endpoint.
//
// NOTE(review): the Revision tag is spelled "folderREvision", which looks
// like a typo for "folderRevision". The tag is part of the wire format, so
// confirm what the server actually reads before changing it.
type tlfUpdate struct {
	ID        keybase1.TLFID `json:"tlfid"`
	UID       keybase1.UID   `json:"uid"`
	KID       keybase1.KID   `json:"kid"`
	Revision  int            `json:"folderREvision"`
	Writers   []tlfUser      `json:"resolvedWriters"`
	Readers   []tlfUser      `json:"resolvedReaders"`
	IsPrivate bool           `json:"is_private"`
}
   557  
   558  func (d *testDevice) keyNewTLF(uid keybase1.UID, writers []tlfUser, readers []tlfUser) *fakeTLF {
   559  	ret := newFakeTLF()
   560  	d.keyTLF(ret, uid, writers, readers)
   561  	return ret
   562  }
   563  
   564  func (d *testDevice) keyTLF(tlf *fakeTLF, uid keybase1.UID, writers []tlfUser, readers []tlfUser) {
   565  	tlf.revision++
   566  	up := tlfUpdate{
   567  		ID:        tlf.id,
   568  		UID:       uid,
   569  		KID:       d.KID(),
   570  		Revision:  tlf.revision,
   571  		Writers:   writers,
   572  		Readers:   readers,
   573  		IsPrivate: true,
   574  	}
   575  	g := d.tctx.G
   576  	b, err := json.Marshal(up)
   577  	if err != nil {
   578  		d.t.Fatalf("error marshalling: %s", err)
   579  	}
   580  	mctx := libkb.NewMetaContextTODO(g)
   581  	apiArg := libkb.APIArg{
   582  		Endpoint: "test/fake_generic_tlf",
   583  		Args: libkb.HTTPArgs{
   584  			"tlf_info": libkb.S{Val: string(b)},
   585  		},
   586  		SessionType: libkb.APISessionTypeREQUIRED,
   587  	}
   588  	_, err = g.API.Post(mctx, apiArg)
   589  	if err != nil {
   590  		d.t.Fatalf("post error: %s", err)
   591  	}
   592  }