github.com/keybase/client/go@v0.0.0-20240309051027-028f7c731f8b/kex2/rpc_test.go (about) 1 // Copyright 2015 Keybase, Inc. All rights reserved. Use of 2 // this source code is governed by the included BSD license. 3 4 package kex2 5 6 import ( 7 "crypto/rand" 8 "encoding/hex" 9 "errors" 10 "io" 11 "strings" 12 "testing" 13 "time" 14 15 keybase1 "github.com/keybase/client/go/protocol/keybase1" 16 "github.com/keybase/go-framed-msgpack-rpc/rpc" 17 "golang.org/x/net/context" 18 ) 19 20 const ( 21 GoodProvisionee = 0 22 BadProvisioneeFailHello = 1 << iota 23 BadProvisioneeFailDidCounterSign = 1 << iota 24 BadProvisioneeSlowHello = 1 << iota 25 BadProvisioneeSlowDidCounterSign = 1 << iota 26 BadProvisioneeCancel = 1 << iota 27 ) 28 29 type mockProvisioner struct { 30 uid keybase1.UID 31 } 32 33 type mockProvisionee struct { 34 behavior int 35 } 36 37 func newMockProvisioner(t *testing.T) *mockProvisioner { 38 return &mockProvisioner{ 39 uid: genUID(t), 40 } 41 } 42 43 type nullLogOutput struct { 44 } 45 46 func (n *nullLogOutput) Error(s string, args ...interface{}) {} 47 func (n *nullLogOutput) Warning(s string, args ...interface{}) {} 48 func (n *nullLogOutput) Info(s string, args ...interface{}) {} 49 func (n *nullLogOutput) Debug(s string, args ...interface{}) {} 50 func (n *nullLogOutput) Profile(s string, args ...interface{}) {} 51 52 var _ rpc.LogOutput = (*nullLogOutput)(nil) 53 54 func makeLogFactory() rpc.LogFactory { 55 if testing.Verbose() { 56 return nil 57 } 58 return rpc.NewSimpleLogFactory(&nullLogOutput{}, nil) 59 } 60 61 func genUID(t *testing.T) keybase1.UID { 62 uid := make([]byte, 8) 63 if _, err := rand.Read(uid); err != nil { 64 t.Fatalf("rand failed: %v\n", err) 65 } 66 return keybase1.UID(hex.EncodeToString(uid)) 67 } 68 69 func genKeybase1DeviceID(t *testing.T) keybase1.DeviceID { 70 did := make([]byte, 16) 71 if _, err := rand.Read(did); err != nil { 72 t.Fatalf("rand failed: %v\n", err) 73 } 74 return keybase1.DeviceID(hex.EncodeToString(did)) 75 } 76 77 func newMockProvisionee(t *testing.T, behavior int) *mockProvisionee { 78 return &mockProvisionee{behavior} 79 } 80 81 func (mp *mockProvisioner) GetLogFactory() rpc.LogFactory { 82 return makeLogFactory() 83 } 84 85 func (mp *mockProvisioner) GetNetworkInstrumenter() rpc.NetworkInstrumenterStorage { 86 return &rpc.DummyInstrumentationStorage{} 87 } 88 89 func (mp *mockProvisioner) CounterSign(input keybase1.HelloRes) (output []byte, err error) { 90 output = []byte(string(input)) 91 return 92 } 93 94 func (mp *mockProvisioner) CounterSign2(input keybase1.Hello2Res) (output keybase1.DidCounterSign2Arg, err error) { 95 output.Sig, err = mp.CounterSign(input.SigPayload) 96 return 97 } 98 99 func (mp *mockProvisioner) GetHelloArg() (res keybase1.HelloArg, err error) { 100 res.Uid = mp.uid 101 return res, err 102 } 103 func (mp *mockProvisioner) GetHello2Arg() (res keybase1.Hello2Arg, err error) { 104 res.Uid = mp.uid 105 return res, err 106 } 107 108 func (mp *mockProvisionee) GetLogFactory() rpc.LogFactory { 109 return makeLogFactory() 110 } 111 112 func (mp *mockProvisionee) GetNetworkInstrumenter() rpc.NetworkInstrumenterStorage { 113 return &rpc.DummyInstrumentationStorage{} 114 } 115 116 var ErrHandleHello = errors.New("handle hello failure") 117 var ErrHandleDidCounterSign = errors.New("handle didCounterSign failure") 118 var testTimeout = time.Duration(500) * time.Millisecond 119 120 func (mp *mockProvisionee) HandleHello2(ctx context.Context, arg2 keybase1.Hello2Arg) (res keybase1.Hello2Res, err error) { 121 arg1 := keybase1.HelloArg{ 122 Uid: arg2.Uid, 123 SigBody: arg2.SigBody, 124 } 125 res.SigPayload, err = mp.HandleHello(ctx, arg1) 126 return res, err 127 } 128 129 func (mp *mockProvisionee) HandleHello(_ context.Context, arg keybase1.HelloArg) (res keybase1.HelloRes, err error) { 130 if (mp.behavior & BadProvisioneeSlowHello) != 0 { 131 time.Sleep(testTimeout * 8) 132 } 133 if (mp.behavior & BadProvisioneeFailHello) != 0 { 134 err = ErrHandleHello 135 return 136 } 137 res = keybase1.HelloRes(arg.SigBody) 138 return 139 } 140 141 func (mp *mockProvisionee) HandleDidCounterSign(_ context.Context, _ []byte) error { 142 if (mp.behavior & BadProvisioneeSlowDidCounterSign) != 0 { 143 time.Sleep(testTimeout * 8) 144 } 145 if (mp.behavior & BadProvisioneeFailDidCounterSign) != 0 { 146 return ErrHandleDidCounterSign 147 } 148 return nil 149 } 150 151 func (mp *mockProvisionee) HandleDidCounterSign2(ctx context.Context, arg keybase1.DidCounterSign2Arg) error { 152 return mp.HandleDidCounterSign(ctx, arg.Sig) 153 } 154 155 func testProtocolXWithBehavior(t *testing.T, provisioneeBehavior int) (results [2]error) { 156 157 timeout := testTimeout 158 router := newMockRouterWithBehaviorAndMaxPoll(GoodRouter, timeout) 159 160 s2 := genSecret(t) 161 162 ch := make(chan error, 3) 163 164 secretCh := make(chan Secret) 165 166 ctx, cancelFn := context.WithCancel(context.Background()) 167 168 testLogCtx, cleanup := newTestLogCtx(t) 169 defer cleanup() 170 171 // Run the provisioner 172 go func() { 173 err := RunProvisioner(ProvisionerArg{ 174 KexBaseArg: KexBaseArg{ 175 Ctx: ctx, 176 LogCtx: testLogCtx, 177 Mr: router, 178 Secret: genSecret(t), 179 DeviceID: genKeybase1DeviceID(t), 180 SecretChannel: secretCh, 181 Timeout: timeout, 182 }, 183 Provisioner: newMockProvisioner(t), 184 }) 185 ch <- err 186 }() 187 188 // Run the privisionee 189 go func() { 190 err := RunProvisionee(ProvisioneeArg{ 191 KexBaseArg: KexBaseArg{ 192 Ctx: context.Background(), 193 LogCtx: testLogCtx, 194 Mr: router, 195 Secret: s2, 196 DeviceID: genKeybase1DeviceID(t), 197 SecretChannel: make(chan Secret), 198 Timeout: timeout, 199 }, 200 Provisionee: newMockProvisionee(t, provisioneeBehavior), 201 }) 202 ch <- err 203 }() 204 205 if (provisioneeBehavior & BadProvisioneeCancel) != 0 { 206 go func() { 207 time.Sleep(testTimeout / 20) 208 cancelFn() 209 }() 210 } 211 212 secretCh <- s2 213 214 for i := 0; i < 2; i++ { 215 if e, eof := <-ch; !eof { 216 t.Fatalf("got unexpected channel close (try %d)", i) 217 } else if e != nil { 218 results[i] = e 219 } 220 } 221 222 return results 223 } 224 225 func TestFullProtocolXSuccess(t *testing.T) { 226 results := testProtocolXWithBehavior(t, GoodProvisionee) 227 for i, e := range results { 228 if e != nil { 229 t.Fatalf("Bad error %d: %v", i, e) 230 } 231 } 232 } 233 234 // Since errors are exported as strings, then we should just test that the 235 // right kind of error was specified 236 func eeq(e1, e2 error) bool { 237 return e1 != nil && e1.Error() == e2.Error() 238 } 239 240 // errHasSuffix makes sure that err's string has errSuffix's string as 241 // a suffix. This is necessary as go-codec prepends stuff to any 242 // errors it catches. 243 func errHasSuffix(err, errSuffix error) bool { 244 return err != nil && strings.HasSuffix(err.Error(), errSuffix.Error()) 245 } 246 247 func TestFullProtocolXProvisioneeFailHello(t *testing.T) { 248 results := testProtocolXWithBehavior(t, BadProvisioneeFailHello) 249 if !eeq(results[0], ErrHandleHello) { 250 t.Fatalf("Bad error 0: %v", results[0]) 251 } 252 if !eeq(results[1], ErrHandleHello) { 253 t.Fatalf("Bad error 1: %v", results[1]) 254 } 255 } 256 257 func TestFullProtocolXProvisioneeFailDidCounterSign(t *testing.T) { 258 results := testProtocolXWithBehavior(t, BadProvisioneeFailDidCounterSign) 259 if !eeq(results[0], ErrHandleDidCounterSign) { 260 t.Fatalf("Bad error 0: %v", results[0]) 261 } 262 if !eeq(results[1], ErrHandleDidCounterSign) { 263 t.Fatalf("Bad error 1: %v", results[1]) 264 } 265 } 266 267 func TestFullProtocolXProvisioneeSlowHello(t *testing.T) { 268 results := testProtocolXWithBehavior(t, BadProvisioneeSlowHello) 269 for i, e := range results { 270 if !errHasSuffix(e, ErrTimedOut) && !errHasSuffix(e, io.EOF) && !errHasSuffix(e, ErrHelloTimeout) { 271 t.Fatalf("Bad error %d: %v", i, e) 272 } 273 } 274 } 275 276 func TestFullProtocolXProvisioneeSlowHelloWithCancel(t *testing.T) { 277 results := testProtocolXWithBehavior(t, BadProvisioneeSlowHello|BadProvisioneeCancel) 278 for i, e := range results { 279 if !eeq(e, ErrCanceled) && !eeq(e, io.EOF) { 280 t.Fatalf("Bad error %d: %v", i, e) 281 } 282 } 283 } 284 285 func TestFullProtocolY(t *testing.T) { 286 287 timeout := time.Duration(60) * time.Second 288 router := newMockRouterWithBehaviorAndMaxPoll(GoodRouter, timeout) 289 290 s1 := genSecret(t) 291 292 ch := make(chan error, 3) 293 294 secretCh := make(chan Secret) 295 testLogCtx, cleanup := newTestLogCtx(t) 296 defer cleanup() 297 298 // Run the provisioner 299 go func() { 300 err := RunProvisioner(ProvisionerArg{ 301 KexBaseArg: KexBaseArg{ 302 Ctx: context.TODO(), 303 LogCtx: testLogCtx, 304 Mr: router, 305 Secret: s1, 306 DeviceID: genKeybase1DeviceID(t), 307 SecretChannel: make(chan Secret), 308 Timeout: timeout, 309 }, 310 Provisioner: newMockProvisioner(t), 311 }) 312 ch <- err 313 }() 314 315 // Run the provisionee 316 go func() { 317 err := RunProvisionee(ProvisioneeArg{ 318 KexBaseArg: KexBaseArg{ 319 Ctx: context.TODO(), 320 LogCtx: testLogCtx, 321 Mr: router, 322 Secret: genSecret(t), 323 DeviceID: genKeybase1DeviceID(t), 324 SecretChannel: secretCh, 325 Timeout: timeout, 326 }, 327 Provisionee: newMockProvisionee(t, GoodProvisionee), 328 }) 329 ch <- err 330 }() 331 332 secretCh <- s1 333 334 for i := 0; i < 2; i++ { 335 if e, eof := <-ch; !eof { 336 t.Fatalf("got unexpected channel close (try %d)", i) 337 } else if e != nil { 338 t.Fatalf("Unexpected error (receive %d): %v", i, e) 339 } 340 } 341 342 }