github.com/ethersphere/bee/v2@v2.2.0/pkg/api/pss_test.go (about)

     1  // Copyright 2020 The Swarm Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package api_test
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"crypto/ecdsa"
    11  	"encoding/hex"
    12  	"errors"
    13  	"fmt"
    14  	"math/big"
    15  	"net/http"
    16  	"net/url"
    17  	"strings"
    18  	"sync"
    19  	"testing"
    20  	"time"
    21  
    22  	"github.com/ethersphere/bee/v2/pkg/api"
    23  	"github.com/ethersphere/bee/v2/pkg/crypto"
    24  	"github.com/ethersphere/bee/v2/pkg/jsonhttp"
    25  	"github.com/ethersphere/bee/v2/pkg/jsonhttp/jsonhttptest"
    26  	"github.com/ethersphere/bee/v2/pkg/log"
    27  	"github.com/ethersphere/bee/v2/pkg/postage"
    28  	mockpost "github.com/ethersphere/bee/v2/pkg/postage/mock"
    29  	"github.com/ethersphere/bee/v2/pkg/pss"
    30  	"github.com/ethersphere/bee/v2/pkg/pushsync"
    31  	"github.com/ethersphere/bee/v2/pkg/spinlock"
    32  	mockstorer "github.com/ethersphere/bee/v2/pkg/storer/mock"
    33  	"github.com/ethersphere/bee/v2/pkg/swarm"
    34  	"github.com/ethersphere/bee/v2/pkg/util/testutil"
    35  	"github.com/gorilla/websocket"
    36  )
    37  
    38  var (
    39  	target      = pss.Target([]byte{1})
    40  	targets     = pss.Targets([]pss.Target{target})
    41  	payload     = []byte("testdata")
    42  	topic       = pss.NewTopic("testtopic")
    43  	mTimeout    = 2 * time.Second
    44  	longTimeout = 30 * time.Second
    45  )
    46  
    47  // creates a single websocket handler for an arbitrary topic, and receives a message
    48  func TestPssWebsocketSingleHandler(t *testing.T) {
    49  	t.Parallel()
    50  
    51  	var (
    52  		p, publicKey, cl, _ = newPssTest(t, opts{})
    53  		respC               = make(chan error, 1)
    54  		tc                  swarm.Chunk
    55  	)
    56  
    57  	// the long timeout is needed so that we dont time out while still mining the message with Wrap()
    58  	// otherwise the test (and other tests below) flakes
    59  	err := cl.SetReadDeadline(time.Now().Add(longTimeout))
    60  	if err != nil {
    61  		t.Fatal(err)
    62  	}
    63  	cl.SetReadLimit(swarm.ChunkSize)
    64  
    65  	tc, err = pss.Wrap(context.Background(), topic, payload, publicKey, targets)
    66  	if err != nil {
    67  		t.Fatal(err)
    68  	}
    69  
    70  	p.TryUnwrap(tc)
    71  
    72  	go expectMessage(t, cl, respC, payload)
    73  	if err := <-respC; err != nil {
    74  		t.Fatal(err)
    75  	}
    76  }
    77  
    78  func TestPssWebsocketSingleHandlerDeregister(t *testing.T) {
    79  	t.Parallel()
    80  
    81  	// create a new pss instance, register a handle through ws, call
    82  	// pss.TryUnwrap with a chunk designated for this handler and expect
    83  	// the handler to be notified
    84  	var (
    85  		p, publicKey, cl, _ = newPssTest(t, opts{})
    86  		respC               = make(chan error, 1)
    87  		tc                  swarm.Chunk
    88  	)
    89  
    90  	err := cl.SetReadDeadline(time.Now().Add(longTimeout))
    91  
    92  	if err != nil {
    93  		t.Fatal(err)
    94  	}
    95  	cl.SetReadLimit(swarm.ChunkSize)
    96  
    97  	tc, err = pss.Wrap(context.Background(), topic, payload, publicKey, targets)
    98  	if err != nil {
    99  		t.Fatal(err)
   100  	}
   101  
   102  	// close the websocket before calling pss with the message
   103  	err = cl.WriteMessage(websocket.CloseMessage, []byte{})
   104  	if err != nil {
   105  		t.Fatal(err)
   106  	}
   107  
   108  	p.TryUnwrap(tc)
   109  
   110  	go expectMessage(t, cl, respC, payload)
   111  	if err := <-respC; err != nil {
   112  		t.Fatal(err)
   113  	}
   114  }
   115  
   116  func TestPssWebsocketMultiHandler(t *testing.T) {
   117  	t.Parallel()
   118  
   119  	var (
   120  		p, publicKey, cl, listener = newPssTest(t, opts{})
   121  
   122  		u           = url.URL{Scheme: "ws", Host: listener, Path: "/pss/subscribe/testtopic"}
   123  		cl2, _, err = websocket.DefaultDialer.Dial(u.String(), nil)
   124  
   125  		respC = make(chan error, 2)
   126  		tc    swarm.Chunk
   127  	)
   128  	if err != nil {
   129  		t.Fatalf("dial: %v. url %v", err, u.String())
   130  	}
   131  	testutil.CleanupCloser(t, cl2)
   132  
   133  	err = cl.SetReadDeadline(time.Now().Add(longTimeout))
   134  	if err != nil {
   135  		t.Fatal(err)
   136  	}
   137  	cl.SetReadLimit(swarm.ChunkSize)
   138  
   139  	tc, err = pss.Wrap(context.Background(), topic, payload, publicKey, targets)
   140  	if err != nil {
   141  		t.Fatal(err)
   142  	}
   143  
   144  	// close the websocket before calling pss with the message
   145  	err = cl.WriteMessage(websocket.CloseMessage, []byte{})
   146  	if err != nil {
   147  		t.Fatal(err)
   148  	}
   149  
   150  	p.TryUnwrap(tc)
   151  
   152  	go expectMessage(t, cl, respC, payload)
   153  	go expectMessage(t, cl2, respC, payload)
   154  	if err := <-respC; err != nil {
   155  		t.Fatal(err)
   156  	}
   157  	if err := <-respC; err != nil {
   158  		t.Fatal(err)
   159  	}
   160  }
   161  
   162  // nolint:paralleltest
   163  // TestPssSend tests that the pss message sending over http works correctly.
   164  func TestPssSend(t *testing.T) {
   165  	var (
   166  		mtx             sync.Mutex
   167  		receivedTopic   pss.Topic
   168  		receivedBytes   []byte
   169  		receivedTargets pss.Targets
   170  		done            bool
   171  
   172  		privk, _       = crypto.GenerateSecp256k1Key()
   173  		publicKeyBytes = crypto.EncodeSecp256k1PublicKey(&privk.PublicKey)
   174  
   175  		sendFn = func(ctx context.Context, targets pss.Targets, chunk swarm.Chunk) error {
   176  			mtx.Lock()
   177  			topic, msg, err := pss.Unwrap(ctx, privk, chunk, []pss.Topic{topic})
   178  			receivedTopic = topic
   179  			receivedBytes = msg
   180  			receivedTargets = targets
   181  			done = true
   182  			mtx.Unlock()
   183  			return err
   184  		}
   185  		mp              = mockpost.New(mockpost.WithIssuer(postage.NewStampIssuer("", "", batchOk, big.NewInt(3), 11, 10, 1000, true)))
   186  		p               = newMockPss(sendFn)
   187  		client, _, _, _ = newTestServer(t, testServerOptions{
   188  			Pss:    p,
   189  			Storer: mockstorer.New(),
   190  			Post:   mp,
   191  		})
   192  
   193  		recipient = hex.EncodeToString(publicKeyBytes)
   194  		targets   = fmt.Sprintf("[[%d]]", 0x12)
   195  		topic     = "testtopic"
   196  		hasher    = swarm.NewHasher()
   197  		_, err    = hasher.Write([]byte(topic))
   198  		topicHash = hasher.Sum(nil)
   199  	)
   200  	if err != nil {
   201  		t.Fatal(err)
   202  	}
   203  
   204  	t.Run("err - bad batch", func(t *testing.T) {
   205  		hexbatch := "abcdefgg"
   206  		jsonhttptest.Request(t, client, http.MethodPost, "/pss/send/to/12", http.StatusBadRequest,
   207  			jsonhttptest.WithRequestHeader(api.SwarmPostageBatchIdHeader, hexbatch),
   208  			jsonhttptest.WithRequestBody(bytes.NewReader(payload)),
   209  			jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{
   210  				Code:    http.StatusBadRequest,
   211  				Message: "invalid header params",
   212  				Reasons: []jsonhttp.Reason{
   213  					{
   214  						Field: api.SwarmPostageBatchIdHeader,
   215  						Error: api.HexInvalidByteError('g').Error(),
   216  					},
   217  				},
   218  			}),
   219  		)
   220  	})
   221  
   222  	t.Run("ok batch", func(t *testing.T) {
   223  		hexbatch := hex.EncodeToString(batchOk)
   224  		jsonhttptest.Request(t, client, http.MethodPost, "/pss/send/to/12", http.StatusCreated,
   225  			jsonhttptest.WithRequestHeader(api.SwarmPostageBatchIdHeader, hexbatch),
   226  			jsonhttptest.WithRequestBody(bytes.NewReader(payload)),
   227  		)
   228  	})
   229  	t.Run("bad request - batch empty", func(t *testing.T) {
   230  		hexbatch := hex.EncodeToString(batchEmpty)
   231  		jsonhttptest.Request(t, client, http.MethodPost, "/pss/send/to/12", http.StatusBadRequest,
   232  			jsonhttptest.WithRequestHeader(api.SwarmPostageBatchIdHeader, hexbatch),
   233  			jsonhttptest.WithRequestBody(bytes.NewReader(payload)),
   234  		)
   235  	})
   236  
   237  	t.Run("ok", func(t *testing.T) {
   238  		jsonhttptest.Request(t, client, http.MethodPost, "/pss/send/testtopic/12?recipient="+recipient, http.StatusCreated,
   239  			jsonhttptest.WithRequestHeader(api.SwarmPostageBatchIdHeader, batchOkStr),
   240  			jsonhttptest.WithRequestBody(bytes.NewReader(payload)),
   241  			jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{
   242  				Message: "Created",
   243  				Code:    http.StatusCreated,
   244  			}),
   245  		)
   246  		waitDone(t, &mtx, &done)
   247  		if !bytes.Equal(receivedBytes, payload) {
   248  			t.Fatalf("payload mismatch. want %v got %v", payload, receivedBytes)
   249  		}
   250  		if targets != fmt.Sprint(receivedTargets) {
   251  			t.Fatalf("targets mismatch. want %v got %v", targets, receivedTargets)
   252  		}
   253  		if string(topicHash) != string(receivedTopic[:]) {
   254  			t.Fatalf("topic mismatch. want %v got %v", topic, string(receivedTopic[:]))
   255  		}
   256  	})
   257  
   258  	t.Run("without recipient", func(t *testing.T) {
   259  		jsonhttptest.Request(t, client, http.MethodPost, "/pss/send/testtopic/12", http.StatusCreated,
   260  			jsonhttptest.WithRequestHeader(api.SwarmPostageBatchIdHeader, batchOkStr),
   261  			jsonhttptest.WithRequestBody(bytes.NewReader(payload)),
   262  			jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{
   263  				Message: "Created",
   264  				Code:    http.StatusCreated,
   265  			}),
   266  		)
   267  		waitDone(t, &mtx, &done)
   268  		if !bytes.Equal(receivedBytes, payload) {
   269  			t.Fatalf("payload mismatch. want %v got %v", payload, receivedBytes)
   270  		}
   271  		if targets != fmt.Sprint(receivedTargets) {
   272  			t.Fatalf("targets mismatch. want %v got %v", targets, receivedTargets)
   273  		}
   274  		if string(topicHash) != string(receivedTopic[:]) {
   275  			t.Fatalf("topic mismatch. want %v got %v", topic, string(receivedTopic[:]))
   276  		}
   277  	})
   278  }
   279  
   280  // TestPssPingPong tests that the websocket api adheres to the websocket standard
   281  // and sends ping-pong messages to keep the connection alive.
   282  // The test opens a websocket, keeps it alive for 500ms, then receives a pss message.
   283  func TestPssPingPong(t *testing.T) {
   284  	t.Parallel()
   285  
   286  	var (
   287  		p, publicKey, cl, _ = newPssTest(t, opts{pingPeriod: 90 * time.Millisecond})
   288  
   289  		respC    = make(chan error, 1)
   290  		tc       swarm.Chunk
   291  		pongWait = 1 * time.Millisecond
   292  	)
   293  
   294  	cl.SetReadLimit(swarm.ChunkSize)
   295  	err := cl.SetReadDeadline(time.Now().Add(pongWait))
   296  	if err != nil {
   297  		t.Fatal(err)
   298  	}
   299  
   300  	tc, err = pss.Wrap(context.Background(), topic, payload, publicKey, targets)
   301  	if err != nil {
   302  		t.Fatal(err)
   303  	}
   304  
   305  	time.Sleep(500 * time.Millisecond) // wait to see that the websocket is kept alive
   306  
   307  	p.TryUnwrap(tc)
   308  
   309  	go expectMessage(t, cl, respC, nil)
   310  	if err := <-respC; err == nil || !strings.Contains(err.Error(), "i/o timeout") {
   311  		// note: error has *websocket.netError type so we need to check error by checking message
   312  		t.Fatal("want timeout error")
   313  	}
   314  }
   315  
   316  func expectMessage(t *testing.T, cl *websocket.Conn, respC chan error, expData []byte) {
   317  	t.Helper()
   318  
   319  	timeout := time.NewTimer(mTimeout)
   320  	defer timeout.Stop()
   321  
   322  	for {
   323  		select {
   324  		case <-timeout.C:
   325  			if expData == nil {
   326  				respC <- nil
   327  			} else {
   328  				respC <- errors.New("timed out waiting for message")
   329  			}
   330  			return
   331  		default:
   332  			msgType, message, err := cl.ReadMessage()
   333  			if err != nil {
   334  				respC <- err
   335  				return
   336  			}
   337  			if msgType == websocket.PongMessage {
   338  				// ignore pings
   339  				continue
   340  			}
   341  			if message == nil {
   342  				continue
   343  			}
   344  
   345  			if bytes.Equal(message, expData) {
   346  				respC <- nil
   347  			} else {
   348  				respC <- errors.New("unexpected message")
   349  			}
   350  			return
   351  		}
   352  	}
   353  }
   354  
   355  func waitDone(t *testing.T, mtx *sync.Mutex, done *bool) {
   356  	t.Helper()
   357  
   358  	err := spinlock.Wait(time.Second, func() bool {
   359  		mtx.Lock()
   360  		defer mtx.Unlock()
   361  		return *done
   362  	})
   363  	if err != nil {
   364  		t.Fatal("timed out waiting for send")
   365  	}
   366  }
   367  
   368  type opts struct {
   369  	pingPeriod time.Duration
   370  }
   371  
   372  func newPssTest(t *testing.T, o opts) (pss.Interface, *ecdsa.PublicKey, *websocket.Conn, string) {
   373  	t.Helper()
   374  
   375  	privkey, err := crypto.GenerateSecp256k1Key()
   376  	if err != nil {
   377  		t.Fatal(err)
   378  	}
   379  
   380  	pss := pss.New(privkey, log.Noop)
   381  	testutil.CleanupCloser(t, pss)
   382  
   383  	if o.pingPeriod == 0 {
   384  		o.pingPeriod = 10 * time.Second
   385  	}
   386  	_, cl, listener, _ := newTestServer(t, testServerOptions{
   387  		Pss:          pss,
   388  		WsPath:       "/pss/subscribe/testtopic",
   389  		Storer:       mockstorer.New(),
   390  		Logger:       log.Noop,
   391  		WsPingPeriod: o.pingPeriod,
   392  	})
   393  
   394  	return pss, &privkey.PublicKey, cl, listener
   395  }
   396  
   397  func TestPssPostHandlerInvalidInputs(t *testing.T) {
   398  	t.Parallel()
   399  
   400  	client, _, _, _ := newTestServer(t, testServerOptions{})
   401  
   402  	tests := []struct {
   403  		name    string
   404  		topic   string
   405  		targets string
   406  		want    jsonhttp.StatusResponse
   407  	}{{
   408  		name:    "targets - odd length hex string",
   409  		topic:   "test_topic",
   410  		targets: "1",
   411  		want: jsonhttp.StatusResponse{
   412  			Code:    http.StatusBadRequest,
   413  			Message: "invalid path params",
   414  			Reasons: []jsonhttp.Reason{
   415  				{
   416  					Field: "target",
   417  					Error: api.ErrHexLength.Error(),
   418  				},
   419  			},
   420  		},
   421  	}, {
   422  		name:    "targets - odd length hex string",
   423  		topic:   "test_topic",
   424  		targets: "1G",
   425  		want: jsonhttp.StatusResponse{
   426  			Code:    http.StatusBadRequest,
   427  			Message: "invalid path params",
   428  			Reasons: []jsonhttp.Reason{
   429  				{
   430  					Field: "target",
   431  					Error: api.HexInvalidByteError('G').Error(),
   432  				},
   433  			},
   434  		},
   435  	}}
   436  
   437  	for _, tc := range tests {
   438  		tc := tc
   439  		t.Run(tc.name, func(t *testing.T) {
   440  			t.Parallel()
   441  
   442  			jsonhttptest.Request(t, client, http.MethodPost, "/pss/send/"+tc.topic+"/"+tc.targets, tc.want.Code,
   443  				jsonhttptest.WithExpectedJSONResponse(tc.want),
   444  			)
   445  		})
   446  	}
   447  }
   448  
   449  type pssSendFn func(context.Context, pss.Targets, swarm.Chunk) error
   450  type mpss struct {
   451  	f pssSendFn
   452  }
   453  
   454  func newMockPss(f pssSendFn) *mpss {
   455  	return &mpss{f}
   456  }
   457  
   458  // Send arbitrary byte slice with the given topic to Targets.
   459  func (m *mpss) Send(ctx context.Context, topic pss.Topic, payload []byte, _ postage.Stamper, recipient *ecdsa.PublicKey, targets pss.Targets) error {
   460  	chunk, err := pss.Wrap(ctx, topic, payload, recipient, targets)
   461  	if err != nil {
   462  		return err
   463  	}
   464  	return m.f(ctx, targets, chunk)
   465  }
   466  
   467  // Register a Handler for a given Topic.
   468  func (m *mpss) Register(_ pss.Topic, _ pss.Handler) func() {
   469  	panic("not implemented") // TODO: Implement
   470  }
   471  
   472  // TryUnwrap tries to unwrap a wrapped trojan message.
   473  func (m *mpss) TryUnwrap(_ swarm.Chunk) {
   474  	panic("not implemented") // TODO: Implement
   475  }
   476  
   477  func (m *mpss) SetPushSyncer(pushSyncer pushsync.PushSyncer) {
   478  	panic("not implemented") // TODO: Implement
   479  }
   480  
   481  func (m *mpss) Close() error {
   482  	panic("not implemented") // TODO: Implement
   483  }