github.com/keybase/client/go@v0.0.0-20241007131713-f10651d043c8/kex2/provisionee.go (about)

     1  // Copyright 2015 Keybase, Inc. All rights reserved. Use of
     2  // this source code is governed by the included BSD license.
     3  
     4  package kex2
     5  
     6  import (
     7  	"io"
     8  	"time"
     9  
    10  	keybase1 "github.com/keybase/client/go/protocol/keybase1"
    11  	"github.com/keybase/go-framed-msgpack-rpc/rpc"
    12  	"golang.org/x/net/context"
    13  )
    14  
// provisionee holds the state of a single provisionee run. It is built by
// newProvisionee, driven by run, and doubles as the receiver for the RPC
// handlers (Hello, Hello2, DidCounterSign, DidCounterSign2) registered in
// startServer.
type provisionee struct {
	baseDevice
	arg                ProvisioneeArg // arguments handed to newProvisionee
	done               chan error     // unbuffered; handlers report their outcome to run here
	startedCounterSign chan struct{}  // unbuffered; signalled when a DidCounterSign* handler begins

	server       *rpc.Server     // RPC server created by the latest startServer call
	serverDoneCh <-chan struct{} // channel returned by server.Run()
}
    24  
// Provisionee is an interface that abstracts out the crypto and session
// management that a provisionee needs to do as part of the protocol.
//
// Within this file, GetLogFactory and GetNetworkInstrumenter supply the
// arguments for the RPC transport built in startServer, and each Handle*
// method is invoked by the RPC handler of the matching name on provisionee.
type Provisionee interface {
	// GetLogFactory returns the log factory passed to rpc.NewTransport.
	GetLogFactory() rpc.LogFactory
	// GetNetworkInstrumenter returns the instrumenter storage passed to
	// rpc.NewTransport.
	GetNetworkInstrumenter() rpc.NetworkInstrumenterStorage
	// HandleHello does the work for the Hello RPC.
	HandleHello(ctx context.Context, a keybase1.HelloArg) (keybase1.HelloRes, error)
	// HandleHello2 does the work for the Hello2 RPC.
	HandleHello2(ctx context.Context, a keybase1.Hello2Arg) (keybase1.Hello2Res, error)
	// HandleDidCounterSign does the work for the DidCounterSign RPC.
	HandleDidCounterSign(ctx context.Context, b []byte) error
	// HandleDidCounterSign2 does the work for the DidCounterSign2 RPC.
	HandleDidCounterSign2(ctx context.Context, a keybase1.DidCounterSign2Arg) error
}
    35  
// ProvisioneeArg provides the details that a provisionee needs in order
// to run its course. The embedded KexBaseArg carries the shared settings
// this file reads through it (Ctx, LogCtx, Mr, Secret, SecretChannel and
// Timeout).
type ProvisioneeArg struct {
	KexBaseArg
	// Provisionee receives the delegated RPC calls and supplies the
	// logging/instrumentation hooks for the RPC transport.
	Provisionee Provisionee
}
    42  
    43  func newProvisionee(arg ProvisioneeArg) *provisionee {
    44  	ret := &provisionee{
    45  		baseDevice: baseDevice{
    46  			start: make(chan struct{}),
    47  		},
    48  		arg:                arg,
    49  		done:               make(chan error),
    50  		startedCounterSign: make(chan struct{}),
    51  	}
    52  	return ret
    53  }
    54  
    55  // RunProvisionee runs a provisionee given the necessary arguments.
    56  func RunProvisionee(arg ProvisioneeArg) error {
    57  	p := newProvisionee(arg)
    58  	return p.run()
    59  }
    60  
    61  // Hello is called via the RPC server interface by the remote client.
    62  // It in turn delegates the work to the passed in Provisionee interface,
    63  // calling HandleHello()
    64  func (p *provisionee) Hello(ctx context.Context, arg keybase1.HelloArg) (res keybase1.HelloRes, err error) {
    65  	close(p.start)
    66  	res, err = p.arg.Provisionee.HandleHello(ctx, arg)
    67  	if err != nil {
    68  		p.done <- err
    69  	}
    70  	return res, err
    71  }
    72  
    73  // Hello2 is called via the RPC server interface by the remote client.
    74  // It in turn delegates the work to the passed in Provisionee interface,
    75  // calling HandleHello()
    76  func (p *provisionee) Hello2(ctx context.Context, arg keybase1.Hello2Arg) (res keybase1.Hello2Res, err error) {
    77  	close(p.start)
    78  	res, err = p.arg.Provisionee.HandleHello2(ctx, arg)
    79  	if err != nil {
    80  		p.done <- err
    81  	}
    82  	return res, err
    83  }
    84  
    85  // DidCounterSign is called via the RPC server interface by the remote client.
    86  // It in turn delegates the work to the passed in Provisionee interface,
    87  // calling HandleDidCounterSign()
    88  func (p *provisionee) DidCounterSign(ctx context.Context, sig []byte) (err error) {
    89  	p.startedCounterSign <- struct{}{}
    90  	err = p.arg.Provisionee.HandleDidCounterSign(ctx, sig)
    91  	p.done <- err
    92  	return err
    93  }
    94  
    95  // DidCounterSign2 is called via the RPC server interface by the remote client.
    96  // It in turn delegates the work to the passed in Provisionee interface,
    97  // calling HandleDidCounterSign()
    98  func (p *provisionee) DidCounterSign2(ctx context.Context, arg keybase1.DidCounterSign2Arg) (err error) {
    99  	p.startedCounterSign <- struct{}{}
   100  	err = p.arg.Provisionee.HandleDidCounterSign2(ctx, arg)
   101  	p.done <- err
   102  	return err
   103  }
   104  
   105  func (p *provisionee) run() (err error) {
   106  
   107  	if err = p.setDeviceID(); err != nil {
   108  		return err
   109  	}
   110  
   111  	if err = p.startServer(p.arg.Secret); err != nil {
   112  		return err
   113  	}
   114  
   115  	if err = p.pickFirstConnection(); err != nil {
   116  		return err
   117  	}
   118  
   119  	// If we hit a done or a server EOF before we started doing the
   120  	// countersign operation, then we have to bail out.
   121  	select {
   122  	case err := <-p.done:
   123  		return err
   124  	case <-p.serverDoneCh:
   125  		return p.server.Err()
   126  	case <-p.startedCounterSign:
   127  	}
   128  
   129  	// After we've started the counter sign operation, we don't care if the
   130  	// provisioner explodes. It makes sense to try to finish, however we can.
   131  	// Thus, we wait for EOF from the server in a Go routine.
   132  	go func() {
   133  		<-p.serverDoneCh
   134  		tmp := p.server.Err()
   135  		if tmp != nil && tmp != io.EOF {
   136  			p.debug("provisionee#run: RPC server died with an error: %s", tmp.Error())
   137  		}
   138  	}()
   139  
   140  	// Since we've already started the countersign operation, just wait around all
   141  	// day until it's done.
   142  	return <-p.done
   143  }
   144  
   145  func (p *provisionee) debug(fmtString string, args ...interface{}) {
   146  	if p.arg.LogCtx != nil {
   147  		p.arg.LogCtx.Debug(fmtString, args...)
   148  	}
   149  }
   150  
   151  func (p *provisionee) startServer(s Secret) (err error) {
   152  	if p.conn, err = NewConn(p.arg.Ctx, p.arg.LogCtx, p.arg.Mr, s, p.deviceID, p.arg.Timeout); err != nil {
   153  		return err
   154  	}
   155  	prots := []rpc.Protocol{
   156  		keybase1.Kex2ProvisioneeProtocol(p),
   157  	}
   158  	prots = append(prots, keybase1.Kex2Provisionee2Protocol(p))
   159  	p.xp = rpc.NewTransport(p.conn, p.arg.Provisionee.GetLogFactory(),
   160  		p.arg.Provisionee.GetNetworkInstrumenter(), nil, rpc.DefaultMaxFrameLength)
   161  	srv := rpc.NewServer(p.xp, nil)
   162  	for _, prot := range prots {
   163  		if err = srv.Register(prot); err != nil {
   164  			return err
   165  		}
   166  	}
   167  
   168  	p.server = srv
   169  	p.serverDoneCh = srv.Run()
   170  	return nil
   171  }
   172  
   173  func (p *provisionee) pickFirstConnection() (err error) {
   174  
   175  	select {
   176  	case <-p.start:
   177  	case sec := <-p.arg.SecretChannel:
   178  		if len(sec) != SecretLen {
   179  			return ErrBadSecret
   180  		}
   181  		p.conn.Close()
   182  		err = p.startServer(sec)
   183  		if err != nil {
   184  			return err
   185  		}
   186  		cli := keybase1.Kex2ProvisionerClient{
   187  			Cli: rpc.NewClient(p.xp, nil, nil)}
   188  		if err = cli.KexStart(p.arg.Ctx); err != nil {
   189  			return err
   190  		}
   191  	case <-p.arg.Ctx.Done():
   192  		err = ErrCanceled
   193  	case <-time.After(p.arg.Timeout):
   194  		err = ErrTimedOut
   195  	}
   196  	return
   197  }
   198  
   199  func (p *provisionee) setDeviceID() (err error) {
   200  	p.deviceID, err = p.arg.getDeviceID()
   201  	return err
   202  }