github.com/keybase/client/go@v0.0.0-20241007131713-f10651d043c8/kex2/provisioner.go (about)

     1  // Copyright 2015 Keybase, Inc. All rights reserved. Use of
     2  // this source code is governed by the included BSD license.
     3  
     4  package kex2
     5  
     6  import (
     7  	"net"
     8  	"strings"
     9  	"time"
    10  
    11  	"golang.org/x/net/context"
    12  
    13  	keybase1 "github.com/keybase/client/go/protocol/keybase1"
    14  	"github.com/keybase/go-framed-msgpack-rpc/rpc"
    15  )
    16  
// provisioner holds the state for running the provisioner side of the kex2
// protocol. It is built by newProvisioner and driven by RunProvisioner.
//
// Fields:
//   - baseDevice: embedded shared device state (start channel, deviceID,
//     conn, xp, canceled are all accessed through it in this file).
//   - arg: the caller-supplied configuration.
//   - helloReceived: set once the peer has answered Hello/Hello2;
//     runProtocolWithCancel consults it to turn context.Canceled into
//     ErrHelloTimeout.
type provisioner struct {
	baseDevice
	arg           ProvisionerArg
	helloReceived bool
}
    22  
    23  // Provisioner is an interface that abstracts out the crypto and session
    24  // management that a provisioner needs to do as part of the protocol.
    25  type Provisioner interface {
    26  	GetHelloArg() (keybase1.HelloArg, error)
    27  	GetHello2Arg() (keybase1.Hello2Arg, error)
    28  	CounterSign(keybase1.HelloRes) ([]byte, error)
    29  	CounterSign2(keybase1.Hello2Res) (keybase1.DidCounterSign2Arg, error)
    30  	GetLogFactory() rpc.LogFactory
    31  	GetNetworkInstrumenter() rpc.NetworkInstrumenterStorage
    32  }
    33  
// ProvisionerArg provides the details that a provisioner needs in order
// to run its course.
//
// Provisioner supplies the crypto/session callbacks. A zero HelloTimeout is
// replaced by the Timeout from the embedded KexBaseArg (see newProvisioner).
// NOTE(review): nothing else in this file reads HelloTimeout; confirm where
// it is actually applied.
type ProvisionerArg struct {
	KexBaseArg
	Provisioner  Provisioner
	HelloTimeout time.Duration
}
    41  
    42  func newProvisioner(arg ProvisionerArg) *provisioner {
    43  	if arg.HelloTimeout == 0 {
    44  		arg.HelloTimeout = arg.Timeout
    45  	}
    46  	ret := &provisioner{
    47  		baseDevice: baseDevice{
    48  			start: make(chan struct{}),
    49  		},
    50  		arg: arg,
    51  	}
    52  	return ret
    53  }
    54  
    55  func (p *provisioner) debug(fmtString string, args ...interface{}) {
    56  	if p.arg.LogCtx != nil {
    57  		p.arg.LogCtx.Debug(fmtString, args...)
    58  	}
    59  }
    60  
    61  // RunProvisioner runs a provisioner given the necessary arguments.
    62  func RunProvisioner(arg ProvisionerArg) error {
    63  	p := newProvisioner(arg)
    64  	err := p.run()
    65  	p.close() // ignore any errors in closing the channel
    66  	return err
    67  }
    68  
    69  func (p *provisioner) close() (err error) {
    70  	if p.conn != nil {
    71  		err = p.conn.Close()
    72  	}
    73  	return err
    74  }
    75  
    76  func (p *provisioner) KexStart(_ context.Context) error {
    77  	close(p.start)
    78  	return nil
    79  }
    80  
    81  func (p *provisioner) run() (err error) {
    82  	if err = p.setDeviceID(); err != nil {
    83  		return err
    84  	}
    85  	if err = p.pickFirstConnection(); err != nil {
    86  		return err
    87  	}
    88  	return p.runProtocolWithCancel()
    89  }
    90  
    91  func (k KexBaseArg) getDeviceID() (ret DeviceID, err error) {
    92  	err = k.DeviceID.ToBytes(ret[:])
    93  	return ret, err
    94  }
    95  
    96  func (p *provisioner) setDeviceID() (err error) {
    97  	p.deviceID, err = p.arg.getDeviceID()
    98  	return err
    99  }
   100  
// pickFirstConnection settles on the connection (p.conn) and RPC transport
// (p.xp) that the rest of the protocol runs over.
//
// When p.arg.Secret is non-empty, it first opens a connection derived from
// that secret and serves the Kex2Provisioner protocol on it, so the peer
// can call KexStart. It then waits for the first of:
//   - p.start being closed (by KexStart): adopt the secret-derived
//     connection and transport;
//   - a secret arriving on p.arg.SecretChannel: reject it with ErrBadSecret
//     unless it is SecretLen bytes long, otherwise open a new connection and
//     transport with it (any secret-derived connection is closed on return);
//   - p.arg.Ctx being done: return ErrCanceled;
//   - p.arg.Timeout elapsing: return ErrTimedOut.
func (p *provisioner) pickFirstConnection() (err error) {

	// This connection is auto-closed at the end of this function, so if
	// you don't want it to close, then set it to nil.  See the first
	// case in the select below.
	var conn net.Conn
	var xp rpc.Transporter

	defer func() {
		if conn != nil {
			conn.Close()
		}
	}()

	// Only make a channel if we were provided a secret to start it with.
	// If not, we'll just have to wait for a message on p.arg.SecretChannel
	// and use the provisionee's channel.
	if len(p.arg.Secret) != 0 {
		if conn, err = NewConn(p.arg.Ctx, p.arg.LogCtx, p.arg.Mr, p.arg.Secret, p.deviceID, p.arg.Timeout); err != nil {
			return err
		}
		// Serve the Kex2Provisioner protocol (including KexStart) to the
		// peer on this connection.
		prot := keybase1.Kex2ProvisionerProtocol(p)
		xp = rpc.NewTransport(conn, p.arg.Provisioner.GetLogFactory(),
			p.arg.Provisioner.GetNetworkInstrumenter(), nil, rpc.DefaultMaxFrameLength)
		srv := rpc.NewServer(xp, nil)
		if err = srv.Register(prot); err != nil {
			return err
		}
		serverDoneCh := srv.Run()
		// TODO: Do something with serverDoneCh.
		// NOTE(review): the server's completion (and anything it reports on
		// this channel) is not observed anywhere in this function.
		_ = serverDoneCh
	}

	select {
	case <-p.start:
		// Hand ownership of the secret-derived connection to p.
		// NOTE(review): if p.arg.Secret was empty, conn and xp are still nil
		// here; confirm p.start cannot be closed in that situation.
		p.conn = conn
		conn = nil // so it's not closed in the defer()'ed close
		p.xp = xp
	case sec := <-p.arg.SecretChannel:
		if len(sec) != SecretLen {
			return ErrBadSecret
		}
		if p.conn, err = NewConn(p.arg.Ctx, p.arg.LogCtx, p.arg.Mr, sec, p.deviceID, p.arg.Timeout); err != nil {
			return err
		}
		// No server is registered on this transport; in this file it is only
		// used as a client by runProtocolV2/runProtocolV1.
		p.xp = rpc.NewTransport(p.conn, p.arg.Provisioner.GetLogFactory(),
			p.arg.Provisioner.GetNetworkInstrumenter(), nil, rpc.DefaultMaxFrameLength)
	case <-p.arg.Ctx.Done():
		err = ErrCanceled
	case <-time.After(p.arg.Timeout):
		// With a non-positive Timeout this timer fires without waiting.
		err = ErrTimedOut
	}
	return
}
   155  
   156  func (p *provisioner) runProtocolWithCancel() (err error) {
   157  	ch := make(chan error)
   158  	go func() {
   159  		ch <- p.runProtocol()
   160  	}()
   161  	select {
   162  	case <-p.arg.Ctx.Done():
   163  		p.canceled = true
   164  		return ErrCanceled
   165  	case err = <-ch:
   166  		if err == context.Canceled && !p.helloReceived {
   167  			return ErrHelloTimeout
   168  		}
   169  		return err
   170  	}
   171  }
   172  
   173  func (p *provisioner) runProtocol() (err error) {
   174  	var fallback bool
   175  	p.debug("+ provisioner#runProtocol: try V2")
   176  	fallback, err = p.runProtocolV2()
   177  	p.debug("- provisioner#runProtocol -> %v, %v", fallback, err)
   178  	if fallback {
   179  		p.debug("+ provisioner#runProtocol: fallback to V1")
   180  		err = p.runProtocolV1()
   181  		p.debug("- provisioner#runProtocol V1 -> %v", err)
   182  	}
   183  	return err
   184  }
   185  
   186  func (p *provisioner) runProtocolV2() (fallback bool, err error) {
   187  	cli := keybase1.Kex2Provisionee2Client{Cli: rpc.NewClient(p.xp, nil, nil)}
   188  	var helloArg keybase1.Hello2Arg
   189  	helloArg, err = p.arg.Provisioner.GetHello2Arg()
   190  	if err != nil {
   191  		return false, err
   192  	}
   193  	var res keybase1.Hello2Res
   194  	if res, err = cli.Hello2(context.TODO(), helloArg); err != nil {
   195  		if strings.Contains(err.Error(), "protocol not found: keybase.1.Kex2Provisionee2") {
   196  			return true, nil
   197  		}
   198  		return false, err
   199  	}
   200  	if p.canceled {
   201  		return false, ErrCanceled
   202  	}
   203  	p.helloReceived = true
   204  	var counterSign2Arg keybase1.DidCounterSign2Arg
   205  	if counterSign2Arg, err = p.arg.Provisioner.CounterSign2(res); err != nil {
   206  		return false, err
   207  	}
   208  	if err = cli.DidCounterSign2(context.TODO(), counterSign2Arg); err != nil {
   209  		return false, err
   210  	}
   211  	return false, nil
   212  }
   213  
   214  func (p *provisioner) runProtocolV1() (err error) {
   215  	cli := keybase1.Kex2ProvisioneeClient{Cli: rpc.NewClient(p.xp, nil, nil)}
   216  	var helloArg keybase1.HelloArg
   217  	helloArg, err = p.arg.Provisioner.GetHelloArg()
   218  	if err != nil {
   219  		return
   220  	}
   221  	var res keybase1.HelloRes
   222  	if res, err = cli.Hello(context.TODO(), helloArg); err != nil {
   223  		return
   224  	}
   225  	if p.canceled {
   226  		return ErrCanceled
   227  	}
   228  	p.helloReceived = true
   229  	var counterSigned []byte
   230  	if counterSigned, err = p.arg.Provisioner.CounterSign(res); err != nil {
   231  		return err
   232  	}
   233  	return cli.DidCounterSign(context.TODO(), counterSigned)
   234  }