github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/kex2/provisioner.go (about) 1 // Copyright 2015 Keybase, Inc. All rights reserved. Use of 2 // this source code is governed by the included BSD license. 3 4 package kex2 5 6 import ( 7 "net" 8 "strings" 9 "time" 10 11 "golang.org/x/net/context" 12 13 keybase1 "github.com/keybase/client/go/protocol/keybase1" 14 "github.com/keybase/go-framed-msgpack-rpc/rpc" 15 ) 16 17 type provisioner struct { 18 baseDevice 19 arg ProvisionerArg 20 helloReceived bool 21 } 22 23 // Provisioner is an interface that abstracts out the crypto and session 24 // management that a provisioner needs to do as part of the protocol. 25 type Provisioner interface { 26 GetHelloArg() (keybase1.HelloArg, error) 27 GetHello2Arg() (keybase1.Hello2Arg, error) 28 CounterSign(keybase1.HelloRes) ([]byte, error) 29 CounterSign2(keybase1.Hello2Res) (keybase1.DidCounterSign2Arg, error) 30 GetLogFactory() rpc.LogFactory 31 GetNetworkInstrumenter() rpc.NetworkInstrumenterStorage 32 } 33 34 // ProvisionerArg provides the details that a provisioner needs in order 35 // to run its course 36 type ProvisionerArg struct { 37 KexBaseArg 38 Provisioner Provisioner 39 HelloTimeout time.Duration 40 } 41 42 func newProvisioner(arg ProvisionerArg) *provisioner { 43 if arg.HelloTimeout == 0 { 44 arg.HelloTimeout = arg.Timeout 45 } 46 ret := &provisioner{ 47 baseDevice: baseDevice{ 48 start: make(chan struct{}), 49 }, 50 arg: arg, 51 } 52 return ret 53 } 54 55 func (p *provisioner) debug(fmtString string, args ...interface{}) { 56 if p.arg.LogCtx != nil { 57 p.arg.LogCtx.Debug(fmtString, args...) 58 } 59 } 60 61 // RunProvisioner runs a provisioner given the necessary arguments. 
// Whatever connection it settles on is closed before returning; an error
// from that close is ignored and the protocol's result is returned instead.
func RunProvisioner(arg ProvisionerArg) error {
	p := newProvisioner(arg)
	err := p.run()
	p.close() // ignore any errors in closing the channel
	return err
}

// close closes the provisioner's connection, if one was established, and
// returns the error from doing so.
func (p *provisioner) close() (err error) {
	if p.conn != nil {
		err = p.conn.Close()
	}
	return err
}

// KexStart is served over the Kex2Provisioner RPC protocol registered in
// pickFirstConnection. Closing p.start lets pickFirstConnection commit to
// the connection built from arg.Secret.
//
// A repeated call is ignored rather than panicking on a second close of
// p.start. NOTE(review): two truly concurrent calls could still both pass
// the check; a sync.Once stored next to p.start would close that window.
func (p *provisioner) KexStart(_ context.Context) error {
	select {
	case <-p.start:
		// Already started; closing again would panic.
	default:
		close(p.start)
	}
	return nil
}

// run resolves the local device ID, settles on a connection, and then runs
// the provisioning protocol over it, honoring cancellation of arg.Ctx.
func (p *provisioner) run() (err error) {
	if err = p.setDeviceID(); err != nil {
		return err
	}
	if err = p.pickFirstConnection(); err != nil {
		return err
	}
	return p.runProtocolWithCancel()
}

// getDeviceID decodes k.DeviceID into a DeviceID value via ToBytes.
func (k KexBaseArg) getDeviceID() (ret DeviceID, err error) {
	err = k.DeviceID.ToBytes(ret[:])
	return ret, err
}

// setDeviceID stores the decoded device ID from p.arg on the provisioner.
func (p *provisioner) setDeviceID() (err error) {
	p.deviceID, err = p.arg.getDeviceID()
	return err
}

// pickFirstConnection decides which connection the protocol will run over.
//
// If arg.Secret is non-empty, it opens a connection derived from that secret
// and serves the Kex2Provisioner protocol on it. That connection is kept if
// the peer calls KexStart. If a secret arrives on arg.SecretChannel first,
// a new connection is opened from it and the secret-derived one (if any) is
// closed. ErrBadSecret is returned for a secret of the wrong length, and
// ErrCanceled or ErrTimedOut if arg.Ctx is done or arg.Timeout elapses first.
func (p *provisioner) pickFirstConnection() (err error) {
	// This connection is auto-closed at the end of this function, so if
	// you don't want it to close, then set it to nil. See the first
	// case in the select below.
	var conn net.Conn
	var xp rpc.Transporter

	defer func() {
		if conn != nil {
			conn.Close()
		}
	}()

	// Only make a channel if we were provided a secret to start it with.
	// If not, we'll just have to wait for a message on p.arg.SecretChannel
	// and use the provisionee's channel.
	if len(p.arg.Secret) != 0 {
		if conn, err = NewConn(p.arg.Ctx, p.arg.LogCtx, p.arg.Mr, p.arg.Secret, p.deviceID, p.arg.Timeout); err != nil {
			return err
		}
		prot := keybase1.Kex2ProvisionerProtocol(p)
		xp = rpc.NewTransport(conn, p.arg.Provisioner.GetLogFactory(),
			p.arg.Provisioner.GetNetworkInstrumenter(), nil, rpc.DefaultMaxFrameLength)
		srv := rpc.NewServer(xp, nil)
		if err = srv.Register(prot); err != nil {
			return err
		}
		serverDoneCh := srv.Run()
		// TODO: Do something with serverDoneCh.
		_ = serverDoneCh
	}

	select {
	case <-p.start:
		p.conn = conn
		conn = nil // so it's not closed in the defer()'ed close
		p.xp = xp
	case sec := <-p.arg.SecretChannel:
		if len(sec) != SecretLen {
			return ErrBadSecret
		}
		if p.conn, err = NewConn(p.arg.Ctx, p.arg.LogCtx, p.arg.Mr, sec, p.deviceID, p.arg.Timeout); err != nil {
			return err
		}
		p.xp = rpc.NewTransport(p.conn, p.arg.Provisioner.GetLogFactory(),
			p.arg.Provisioner.GetNetworkInstrumenter(), nil, rpc.DefaultMaxFrameLength)
	case <-p.arg.Ctx.Done():
		err = ErrCanceled
	case <-time.After(p.arg.Timeout):
		err = ErrTimedOut
	}
	return err
}

// runProtocolWithCancel runs runProtocol in a goroutine and returns its
// result, or ErrCanceled as soon as arg.Ctx is done. A context.Canceled
// result that arrives before any Hello reply is reported as ErrHelloTimeout.
func (p *provisioner) runProtocolWithCancel() (err error) {
	// Buffered so the protocol goroutine can always deliver its result and
	// exit, even after we have returned early on cancellation; with an
	// unbuffered channel it would block forever on the send.
	ch := make(chan error, 1)
	go func() {
		ch <- p.runProtocol()
	}()
	select {
	case <-p.arg.Ctx.Done():
		// NOTE(review): p.canceled is read concurrently by the protocol
		// goroutine without synchronization; consider an atomic flag.
		p.canceled = true
		return ErrCanceled
	case err = <-ch:
		if err == context.Canceled && !p.helloReceived {
			return ErrHelloTimeout
		}
		return err
	}
}

// runProtocol tries the V2 protocol first and falls back to V1 when the
// peer does not support V2.
func (p *provisioner) runProtocol() (err error) {
	var fallback bool
	p.debug("+ provisioner#runProtocol: try V2")
	fallback, err = p.runProtocolV2()
	p.debug("- provisioner#runProtocol -> %v, %v", fallback, err)
	if fallback {
		p.debug("+ provisioner#runProtocol: fallback to V1")
		err = p.runProtocolV1()
		p.debug("- provisioner#runProtocol V1 -> %v", err)
	}
	return err
}

// runProtocolV2 performs the Hello2 / DidCounterSign2 exchange. It reports
// fallback=true with a nil error when the peer's Hello2 error says the
// Kex2Provisionee2 protocol was not found, so the caller can retry with V1.
func (p *provisioner) runProtocolV2() (fallback bool, err error) {
	cli := keybase1.Kex2Provisionee2Client{Cli: rpc.NewClient(p.xp, nil, nil)}
	var helloArg keybase1.Hello2Arg
	helloArg, err = p.arg.Provisioner.GetHello2Arg()
	if err != nil {
		return false, err
	}
	var res keybase1.Hello2Res
	if res, err = cli.Hello2(context.TODO(), helloArg); err != nil {
		// Detected by message text, since that is all the RPC error exposes here.
		if strings.Contains(err.Error(), "protocol not found: keybase.1.Kex2Provisionee2") {
			return true, nil
		}
		return false, err
	}
	if p.canceled {
		return false, ErrCanceled
	}
	p.helloReceived = true
	var counterSign2Arg keybase1.DidCounterSign2Arg
	if counterSign2Arg, err = p.arg.Provisioner.CounterSign2(res); err != nil {
		return false, err
	}
	if err = cli.DidCounterSign2(context.TODO(), counterSign2Arg); err != nil {
		return false, err
	}
	return false, nil
}

// runProtocolV1 performs the legacy Hello / DidCounterSign exchange.
func (p *provisioner) runProtocolV1() (err error) {
	cli := keybase1.Kex2ProvisioneeClient{Cli: rpc.NewClient(p.xp, nil, nil)}
	var helloArg keybase1.HelloArg
	if helloArg, err = p.arg.Provisioner.GetHelloArg(); err != nil {
		return err
	}
	var res keybase1.HelloRes
	if res, err = cli.Hello(context.TODO(), helloArg); err != nil {
		return err
	}
	if p.canceled {
		return ErrCanceled
	}
	p.helloReceived = true
	var counterSigned []byte
	if counterSigned, err = p.arg.Provisioner.CounterSign(res); err != nil {
		return err
	}
	return cli.DidCounterSign(context.TODO(), counterSigned)
}