github.com/keybase/client/go@v0.0.0-20241007131713-f10651d043c8/kex2/provisionee.go (about) 1 // Copyright 2015 Keybase, Inc. All rights reserved. Use of 2 // this source code is governed by the included BSD license. 3 4 package kex2 5 6 import ( 7 "io" 8 "time" 9 10 keybase1 "github.com/keybase/client/go/protocol/keybase1" 11 "github.com/keybase/go-framed-msgpack-rpc/rpc" 12 "golang.org/x/net/context" 13 ) 14 15 type provisionee struct { 16 baseDevice 17 arg ProvisioneeArg 18 done chan error 19 startedCounterSign chan struct{} 20 21 server *rpc.Server 22 serverDoneCh <-chan struct{} 23 } 24 25 // Provisionee is an interface that abstracts out the crypto and session 26 // management that a provisionee needs to do as part of the protocol. 27 type Provisionee interface { 28 GetLogFactory() rpc.LogFactory 29 GetNetworkInstrumenter() rpc.NetworkInstrumenterStorage 30 HandleHello(ctx context.Context, a keybase1.HelloArg) (keybase1.HelloRes, error) 31 HandleHello2(ctx context.Context, a keybase1.Hello2Arg) (keybase1.Hello2Res, error) 32 HandleDidCounterSign(ctx context.Context, b []byte) error 33 HandleDidCounterSign2(ctx context.Context, a keybase1.DidCounterSign2Arg) error 34 } 35 36 // ProvisioneeArg provides the details that a provisionee needs in order 37 // to run its course 38 type ProvisioneeArg struct { 39 KexBaseArg 40 Provisionee Provisionee 41 } 42 43 func newProvisionee(arg ProvisioneeArg) *provisionee { 44 ret := &provisionee{ 45 baseDevice: baseDevice{ 46 start: make(chan struct{}), 47 }, 48 arg: arg, 49 done: make(chan error), 50 startedCounterSign: make(chan struct{}), 51 } 52 return ret 53 } 54 55 // RunProvisionee runs a provisionee given the necessary arguments. 56 func RunProvisionee(arg ProvisioneeArg) error { 57 p := newProvisionee(arg) 58 return p.run() 59 } 60 61 // Hello is called via the RPC server interface by the remote client. 62 // It in turn delegates the work to the passed in Provisionee interface, 63 // calling HandleHello() 64 func (p *provisionee) Hello(ctx context.Context, arg keybase1.HelloArg) (res keybase1.HelloRes, err error) { 65 close(p.start) 66 res, err = p.arg.Provisionee.HandleHello(ctx, arg) 67 if err != nil { 68 p.done <- err 69 } 70 return res, err 71 } 72 73 // Hello2 is called via the RPC server interface by the remote client. 74 // It in turn delegates the work to the passed in Provisionee interface, 75 // calling HandleHello() 76 func (p *provisionee) Hello2(ctx context.Context, arg keybase1.Hello2Arg) (res keybase1.Hello2Res, err error) { 77 close(p.start) 78 res, err = p.arg.Provisionee.HandleHello2(ctx, arg) 79 if err != nil { 80 p.done <- err 81 } 82 return res, err 83 } 84 85 // DidCounterSign is called via the RPC server interface by the remote client. 86 // It in turn delegates the work to the passed in Provisionee interface, 87 // calling HandleDidCounterSign() 88 func (p *provisionee) DidCounterSign(ctx context.Context, sig []byte) (err error) { 89 p.startedCounterSign <- struct{}{} 90 err = p.arg.Provisionee.HandleDidCounterSign(ctx, sig) 91 p.done <- err 92 return err 93 } 94 95 // DidCounterSign2 is called via the RPC server interface by the remote client. 96 // It in turn delegates the work to the passed in Provisionee interface, 97 // calling HandleDidCounterSign() 98 func (p *provisionee) DidCounterSign2(ctx context.Context, arg keybase1.DidCounterSign2Arg) (err error) { 99 p.startedCounterSign <- struct{}{} 100 err = p.arg.Provisionee.HandleDidCounterSign2(ctx, arg) 101 p.done <- err 102 return err 103 } 104 105 func (p *provisionee) run() (err error) { 106 107 if err = p.setDeviceID(); err != nil { 108 return err 109 } 110 111 if err = p.startServer(p.arg.Secret); err != nil { 112 return err 113 } 114 115 if err = p.pickFirstConnection(); err != nil { 116 return err 117 } 118 119 // If we hit a done or a server EOF before we started doing the 120 // countersign operation, then we have to bail out. 121 select { 122 case err := <-p.done: 123 return err 124 case <-p.serverDoneCh: 125 return p.server.Err() 126 case <-p.startedCounterSign: 127 } 128 129 // After we've started the counter sign operation, we don't care if the 130 // provisioner explodes. It makes sense to try to finish, however we can. 131 // Thus, we wait for EOF from the server in a Go routine. 132 go func() { 133 <-p.serverDoneCh 134 tmp := p.server.Err() 135 if tmp != nil && tmp != io.EOF { 136 p.debug("provisionee#run: RPC server died with an error: %s", tmp.Error()) 137 } 138 }() 139 140 // Since we've already started the countersign operation, just wait around all 141 // day until it's done. 142 return <-p.done 143 } 144 145 func (p *provisionee) debug(fmtString string, args ...interface{}) { 146 if p.arg.LogCtx != nil { 147 p.arg.LogCtx.Debug(fmtString, args...) 148 } 149 } 150 151 func (p *provisionee) startServer(s Secret) (err error) { 152 if p.conn, err = NewConn(p.arg.Ctx, p.arg.LogCtx, p.arg.Mr, s, p.deviceID, p.arg.Timeout); err != nil { 153 return err 154 } 155 prots := []rpc.Protocol{ 156 keybase1.Kex2ProvisioneeProtocol(p), 157 } 158 prots = append(prots, keybase1.Kex2Provisionee2Protocol(p)) 159 p.xp = rpc.NewTransport(p.conn, p.arg.Provisionee.GetLogFactory(), 160 p.arg.Provisionee.GetNetworkInstrumenter(), nil, rpc.DefaultMaxFrameLength) 161 srv := rpc.NewServer(p.xp, nil) 162 for _, prot := range prots { 163 if err = srv.Register(prot); err != nil { 164 return err 165 } 166 } 167 168 p.server = srv 169 p.serverDoneCh = srv.Run() 170 return nil 171 } 172 173 func (p *provisionee) pickFirstConnection() (err error) { 174 175 select { 176 case <-p.start: 177 case sec := <-p.arg.SecretChannel: 178 if len(sec) != SecretLen { 179 return ErrBadSecret 180 } 181 p.conn.Close() 182 err = p.startServer(sec) 183 if err != nil { 184 return err 185 } 186 cli := keybase1.Kex2ProvisionerClient{ 187 Cli: rpc.NewClient(p.xp, nil, nil)} 188 if err = cli.KexStart(p.arg.Ctx); err != nil { 189 return err 190 } 191 case <-p.arg.Ctx.Done(): 192 err = ErrCanceled 193 case <-time.After(p.arg.Timeout): 194 err = ErrTimedOut 195 } 196 return 197 } 198 199 func (p *provisionee) setDeviceID() (err error) { 200 p.deviceID, err = p.arg.getDeviceID() 201 return err 202 }