github.com/ethersphere/bee/v2@v2.2.0/pkg/api/pss_test.go (about) 1 // Copyright 2020 The Swarm Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package api_test 6 7 import ( 8 "bytes" 9 "context" 10 "crypto/ecdsa" 11 "encoding/hex" 12 "errors" 13 "fmt" 14 "math/big" 15 "net/http" 16 "net/url" 17 "strings" 18 "sync" 19 "testing" 20 "time" 21 22 "github.com/ethersphere/bee/v2/pkg/api" 23 "github.com/ethersphere/bee/v2/pkg/crypto" 24 "github.com/ethersphere/bee/v2/pkg/jsonhttp" 25 "github.com/ethersphere/bee/v2/pkg/jsonhttp/jsonhttptest" 26 "github.com/ethersphere/bee/v2/pkg/log" 27 "github.com/ethersphere/bee/v2/pkg/postage" 28 mockpost "github.com/ethersphere/bee/v2/pkg/postage/mock" 29 "github.com/ethersphere/bee/v2/pkg/pss" 30 "github.com/ethersphere/bee/v2/pkg/pushsync" 31 "github.com/ethersphere/bee/v2/pkg/spinlock" 32 mockstorer "github.com/ethersphere/bee/v2/pkg/storer/mock" 33 "github.com/ethersphere/bee/v2/pkg/swarm" 34 "github.com/ethersphere/bee/v2/pkg/util/testutil" 35 "github.com/gorilla/websocket" 36 ) 37 38 var ( 39 target = pss.Target([]byte{1}) 40 targets = pss.Targets([]pss.Target{target}) 41 payload = []byte("testdata") 42 topic = pss.NewTopic("testtopic") 43 mTimeout = 2 * time.Second 44 longTimeout = 30 * time.Second 45 ) 46 47 // creates a single websocket handler for an arbitrary topic, and receives a message 48 func TestPssWebsocketSingleHandler(t *testing.T) { 49 t.Parallel() 50 51 var ( 52 p, publicKey, cl, _ = newPssTest(t, opts{}) 53 respC = make(chan error, 1) 54 tc swarm.Chunk 55 ) 56 57 // the long timeout is needed so that we dont time out while still mining the message with Wrap() 58 // otherwise the test (and other tests below) flakes 59 err := cl.SetReadDeadline(time.Now().Add(longTimeout)) 60 if err != nil { 61 t.Fatal(err) 62 } 63 cl.SetReadLimit(swarm.ChunkSize) 64 65 tc, err = pss.Wrap(context.Background(), topic, payload, publicKey, targets) 66 if err != nil { 67 t.Fatal(err) 68 } 69 70 p.TryUnwrap(tc) 71 72 go expectMessage(t, cl, respC, payload) 73 if err := <-respC; err != nil { 74 t.Fatal(err) 75 } 76 } 77 78 func TestPssWebsocketSingleHandlerDeregister(t *testing.T) { 79 t.Parallel() 80 81 // create a new pss instance, register a handle through ws, call 82 // pss.TryUnwrap with a chunk designated for this handler and expect 83 // the handler to be notified 84 var ( 85 p, publicKey, cl, _ = newPssTest(t, opts{}) 86 respC = make(chan error, 1) 87 tc swarm.Chunk 88 ) 89 90 err := cl.SetReadDeadline(time.Now().Add(longTimeout)) 91 92 if err != nil { 93 t.Fatal(err) 94 } 95 cl.SetReadLimit(swarm.ChunkSize) 96 97 tc, err = pss.Wrap(context.Background(), topic, payload, publicKey, targets) 98 if err != nil { 99 t.Fatal(err) 100 } 101 102 // close the websocket before calling pss with the message 103 err = cl.WriteMessage(websocket.CloseMessage, []byte{}) 104 if err != nil { 105 t.Fatal(err) 106 } 107 108 p.TryUnwrap(tc) 109 110 go expectMessage(t, cl, respC, payload) 111 if err := <-respC; err != nil { 112 t.Fatal(err) 113 } 114 } 115 116 func TestPssWebsocketMultiHandler(t *testing.T) { 117 t.Parallel() 118 119 var ( 120 p, publicKey, cl, listener = newPssTest(t, opts{}) 121 122 u = url.URL{Scheme: "ws", Host: listener, Path: "/pss/subscribe/testtopic"} 123 cl2, _, err = websocket.DefaultDialer.Dial(u.String(), nil) 124 125 respC = make(chan error, 2) 126 tc swarm.Chunk 127 ) 128 if err != nil { 129 t.Fatalf("dial: %v. url %v", err, u.String()) 130 } 131 testutil.CleanupCloser(t, cl2) 132 133 err = cl.SetReadDeadline(time.Now().Add(longTimeout)) 134 if err != nil { 135 t.Fatal(err) 136 } 137 cl.SetReadLimit(swarm.ChunkSize) 138 139 tc, err = pss.Wrap(context.Background(), topic, payload, publicKey, targets) 140 if err != nil { 141 t.Fatal(err) 142 } 143 144 // close the websocket before calling pss with the message 145 err = cl.WriteMessage(websocket.CloseMessage, []byte{}) 146 if err != nil { 147 t.Fatal(err) 148 } 149 150 p.TryUnwrap(tc) 151 152 go expectMessage(t, cl, respC, payload) 153 go expectMessage(t, cl2, respC, payload) 154 if err := <-respC; err != nil { 155 t.Fatal(err) 156 } 157 if err := <-respC; err != nil { 158 t.Fatal(err) 159 } 160 } 161 162 // nolint:paralleltest 163 // TestPssSend tests that the pss message sending over http works correctly. 164 func TestPssSend(t *testing.T) { 165 var ( 166 mtx sync.Mutex 167 receivedTopic pss.Topic 168 receivedBytes []byte 169 receivedTargets pss.Targets 170 done bool 171 172 privk, _ = crypto.GenerateSecp256k1Key() 173 publicKeyBytes = crypto.EncodeSecp256k1PublicKey(&privk.PublicKey) 174 175 sendFn = func(ctx context.Context, targets pss.Targets, chunk swarm.Chunk) error { 176 mtx.Lock() 177 topic, msg, err := pss.Unwrap(ctx, privk, chunk, []pss.Topic{topic}) 178 receivedTopic = topic 179 receivedBytes = msg 180 receivedTargets = targets 181 done = true 182 mtx.Unlock() 183 return err 184 } 185 mp = mockpost.New(mockpost.WithIssuer(postage.NewStampIssuer("", "", batchOk, big.NewInt(3), 11, 10, 1000, true))) 186 p = newMockPss(sendFn) 187 client, _, _, _ = newTestServer(t, testServerOptions{ 188 Pss: p, 189 Storer: mockstorer.New(), 190 Post: mp, 191 }) 192 193 recipient = hex.EncodeToString(publicKeyBytes) 194 targets = fmt.Sprintf("[[%d]]", 0x12) 195 topic = "testtopic" 196 hasher = swarm.NewHasher() 197 _, err = hasher.Write([]byte(topic)) 198 topicHash = hasher.Sum(nil) 199 ) 200 if err != nil { 201 t.Fatal(err) 202 } 203 204 t.Run("err - bad batch", func(t *testing.T) { 205 hexbatch := "abcdefgg" 206 jsonhttptest.Request(t, client, http.MethodPost, "/pss/send/to/12", http.StatusBadRequest, 207 jsonhttptest.WithRequestHeader(api.SwarmPostageBatchIdHeader, hexbatch), 208 jsonhttptest.WithRequestBody(bytes.NewReader(payload)), 209 jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{ 210 Code: http.StatusBadRequest, 211 Message: "invalid header params", 212 Reasons: []jsonhttp.Reason{ 213 { 214 Field: api.SwarmPostageBatchIdHeader, 215 Error: api.HexInvalidByteError('g').Error(), 216 }, 217 }, 218 }), 219 ) 220 }) 221 222 t.Run("ok batch", func(t *testing.T) { 223 hexbatch := hex.EncodeToString(batchOk) 224 jsonhttptest.Request(t, client, http.MethodPost, "/pss/send/to/12", http.StatusCreated, 225 jsonhttptest.WithRequestHeader(api.SwarmPostageBatchIdHeader, hexbatch), 226 jsonhttptest.WithRequestBody(bytes.NewReader(payload)), 227 ) 228 }) 229 t.Run("bad request - batch empty", func(t *testing.T) { 230 hexbatch := hex.EncodeToString(batchEmpty) 231 jsonhttptest.Request(t, client, http.MethodPost, "/pss/send/to/12", http.StatusBadRequest, 232 jsonhttptest.WithRequestHeader(api.SwarmPostageBatchIdHeader, hexbatch), 233 jsonhttptest.WithRequestBody(bytes.NewReader(payload)), 234 ) 235 }) 236 237 t.Run("ok", func(t *testing.T) { 238 jsonhttptest.Request(t, client, http.MethodPost, "/pss/send/testtopic/12?recipient="+recipient, http.StatusCreated, 239 jsonhttptest.WithRequestHeader(api.SwarmPostageBatchIdHeader, batchOkStr), 240 jsonhttptest.WithRequestBody(bytes.NewReader(payload)), 241 jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{ 242 Message: "Created", 243 Code: http.StatusCreated, 244 }), 245 ) 246 waitDone(t, &mtx, &done) 247 if !bytes.Equal(receivedBytes, payload) { 248 t.Fatalf("payload mismatch. want %v got %v", payload, receivedBytes) 249 } 250 if targets != fmt.Sprint(receivedTargets) { 251 t.Fatalf("targets mismatch. want %v got %v", targets, receivedTargets) 252 } 253 if string(topicHash) != string(receivedTopic[:]) { 254 t.Fatalf("topic mismatch. want %v got %v", topic, string(receivedTopic[:])) 255 } 256 }) 257 258 t.Run("without recipient", func(t *testing.T) { 259 jsonhttptest.Request(t, client, http.MethodPost, "/pss/send/testtopic/12", http.StatusCreated, 260 jsonhttptest.WithRequestHeader(api.SwarmPostageBatchIdHeader, batchOkStr), 261 jsonhttptest.WithRequestBody(bytes.NewReader(payload)), 262 jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{ 263 Message: "Created", 264 Code: http.StatusCreated, 265 }), 266 ) 267 waitDone(t, &mtx, &done) 268 if !bytes.Equal(receivedBytes, payload) { 269 t.Fatalf("payload mismatch. want %v got %v", payload, receivedBytes) 270 } 271 if targets != fmt.Sprint(receivedTargets) { 272 t.Fatalf("targets mismatch. want %v got %v", targets, receivedTargets) 273 } 274 if string(topicHash) != string(receivedTopic[:]) { 275 t.Fatalf("topic mismatch. want %v got %v", topic, string(receivedTopic[:])) 276 } 277 }) 278 } 279 280 // TestPssPingPong tests that the websocket api adheres to the websocket standard 281 // and sends ping-pong messages to keep the connection alive. 282 // The test opens a websocket, keeps it alive for 500ms, then receives a pss message. 283 func TestPssPingPong(t *testing.T) { 284 t.Parallel() 285 286 var ( 287 p, publicKey, cl, _ = newPssTest(t, opts{pingPeriod: 90 * time.Millisecond}) 288 289 respC = make(chan error, 1) 290 tc swarm.Chunk 291 pongWait = 1 * time.Millisecond 292 ) 293 294 cl.SetReadLimit(swarm.ChunkSize) 295 err := cl.SetReadDeadline(time.Now().Add(pongWait)) 296 if err != nil { 297 t.Fatal(err) 298 } 299 300 tc, err = pss.Wrap(context.Background(), topic, payload, publicKey, targets) 301 if err != nil { 302 t.Fatal(err) 303 } 304 305 time.Sleep(500 * time.Millisecond) // wait to see that the websocket is kept alive 306 307 p.TryUnwrap(tc) 308 309 go expectMessage(t, cl, respC, nil) 310 if err := <-respC; err == nil || !strings.Contains(err.Error(), "i/o timeout") { 311 // note: error has *websocket.netError type so we need to check error by checking message 312 t.Fatal("want timeout error") 313 } 314 } 315 316 func expectMessage(t *testing.T, cl *websocket.Conn, respC chan error, expData []byte) { 317 t.Helper() 318 319 timeout := time.NewTimer(mTimeout) 320 defer timeout.Stop() 321 322 for { 323 select { 324 case <-timeout.C: 325 if expData == nil { 326 respC <- nil 327 } else { 328 respC <- errors.New("timed out waiting for message") 329 } 330 return 331 default: 332 msgType, message, err := cl.ReadMessage() 333 if err != nil { 334 respC <- err 335 return 336 } 337 if msgType == websocket.PongMessage { 338 // ignore pings 339 continue 340 } 341 if message == nil { 342 continue 343 } 344 345 if bytes.Equal(message, expData) { 346 respC <- nil 347 } else { 348 respC <- errors.New("unexpected message") 349 } 350 return 351 } 352 } 353 } 354 355 func waitDone(t *testing.T, mtx *sync.Mutex, done *bool) { 356 t.Helper() 357 358 err := spinlock.Wait(time.Second, func() bool { 359 mtx.Lock() 360 defer mtx.Unlock() 361 return *done 362 }) 363 if err != nil { 364 t.Fatal("timed out waiting for send") 365 } 366 } 367 368 type opts struct { 369 pingPeriod time.Duration 370 } 371 372 func newPssTest(t *testing.T, o opts) (pss.Interface, *ecdsa.PublicKey, *websocket.Conn, string) { 373 t.Helper() 374 375 privkey, err := crypto.GenerateSecp256k1Key() 376 if err != nil { 377 t.Fatal(err) 378 } 379 380 pss := pss.New(privkey, log.Noop) 381 testutil.CleanupCloser(t, pss) 382 383 if o.pingPeriod == 0 { 384 o.pingPeriod = 10 * time.Second 385 } 386 _, cl, listener, _ := newTestServer(t, testServerOptions{ 387 Pss: pss, 388 WsPath: "/pss/subscribe/testtopic", 389 Storer: mockstorer.New(), 390 Logger: log.Noop, 391 WsPingPeriod: o.pingPeriod, 392 }) 393 394 return pss, &privkey.PublicKey, cl, listener 395 } 396 397 func TestPssPostHandlerInvalidInputs(t *testing.T) { 398 t.Parallel() 399 400 client, _, _, _ := newTestServer(t, testServerOptions{}) 401 402 tests := []struct { 403 name string 404 topic string 405 targets string 406 want jsonhttp.StatusResponse 407 }{{ 408 name: "targets - odd length hex string", 409 topic: "test_topic", 410 targets: "1", 411 want: jsonhttp.StatusResponse{ 412 Code: http.StatusBadRequest, 413 Message: "invalid path params", 414 Reasons: []jsonhttp.Reason{ 415 { 416 Field: "target", 417 Error: api.ErrHexLength.Error(), 418 }, 419 }, 420 }, 421 }, { 422 name: "targets - odd length hex string", 423 topic: "test_topic", 424 targets: "1G", 425 want: jsonhttp.StatusResponse{ 426 Code: http.StatusBadRequest, 427 Message: "invalid path params", 428 Reasons: []jsonhttp.Reason{ 429 { 430 Field: "target", 431 Error: api.HexInvalidByteError('G').Error(), 432 }, 433 }, 434 }, 435 }} 436 437 for _, tc := range tests { 438 tc := tc 439 t.Run(tc.name, func(t *testing.T) { 440 t.Parallel() 441 442 jsonhttptest.Request(t, client, http.MethodPost, "/pss/send/"+tc.topic+"/"+tc.targets, tc.want.Code, 443 jsonhttptest.WithExpectedJSONResponse(tc.want), 444 ) 445 }) 446 } 447 } 448 449 type pssSendFn func(context.Context, pss.Targets, swarm.Chunk) error 450 type mpss struct { 451 f pssSendFn 452 } 453 454 func newMockPss(f pssSendFn) *mpss { 455 return &mpss{f} 456 } 457 458 // Send arbitrary byte slice with the given topic to Targets. 459 func (m *mpss) Send(ctx context.Context, topic pss.Topic, payload []byte, _ postage.Stamper, recipient *ecdsa.PublicKey, targets pss.Targets) error { 460 chunk, err := pss.Wrap(ctx, topic, payload, recipient, targets) 461 if err != nil { 462 return err 463 } 464 return m.f(ctx, targets, chunk) 465 } 466 467 // Register a Handler for a given Topic. 468 func (m *mpss) Register(_ pss.Topic, _ pss.Handler) func() { 469 panic("not implemented") // TODO: Implement 470 } 471 472 // TryUnwrap tries to unwrap a wrapped trojan message. 473 func (m *mpss) TryUnwrap(_ swarm.Chunk) { 474 panic("not implemented") // TODO: Implement 475 } 476 477 func (m *mpss) SetPushSyncer(pushSyncer pushsync.PushSyncer) { 478 panic("not implemented") // TODO: Implement 479 } 480 481 func (m *mpss) Close() error { 482 panic("not implemented") // TODO: Implement 483 }