github.com/ethersphere/bee/v2@v2.2.0/pkg/api/pss.go (about)

     1  // Copyright 2020 The Swarm Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package api
     6  
     7  import (
     8  	"context"
     9  	"crypto/ecdsa"
    10  	"encoding/hex"
    11  	"errors"
    12  	"io"
    13  	"net/http"
    14  	"strings"
    15  	"time"
    16  
    17  	"github.com/ethersphere/bee/v2/pkg/crypto"
    18  	"github.com/ethersphere/bee/v2/pkg/jsonhttp"
    19  	"github.com/ethersphere/bee/v2/pkg/postage"
    20  	"github.com/ethersphere/bee/v2/pkg/pss"
    21  	"github.com/ethersphere/bee/v2/pkg/swarm"
    22  	"github.com/gorilla/mux"
    23  	"github.com/gorilla/websocket"
    24  )
    25  
    26  const (
    27  	writeDeadline   = 4 * time.Second // write deadline. should be smaller than the shutdown timeout on api close
    28  	targetMaxLength = 3               // max target length in bytes, in order to prevent grieving by excess computation
    29  )
    30  
    31  func (s *Service) pssPostHandler(w http.ResponseWriter, r *http.Request) {
    32  	logger := s.logger.WithName("post_pss_send").Build()
    33  
    34  	paths := struct {
    35  		Topic   string `map:"topic" validate:"required"`
    36  		Targets string `map:"targets" validate:"required"`
    37  	}{}
    38  	if response := s.mapStructure(mux.Vars(r), &paths); response != nil {
    39  		response("invalid path params", logger, w)
    40  		return
    41  	}
    42  	topic := pss.NewTopic(paths.Topic)
    43  
    44  	var targets pss.Targets
    45  	for _, v := range strings.Split(paths.Targets, ",") {
    46  		target := struct {
    47  			Val []byte `map:"target" validate:"required,max=3"`
    48  		}{}
    49  		if response := s.mapStructure(map[string]string{"target": v}, &target); response != nil {
    50  			response("invalid path params", logger, w)
    51  			return
    52  		}
    53  		targets = append(targets, target.Val)
    54  	}
    55  
    56  	queries := struct {
    57  		Recipient *ecdsa.PublicKey `map:"recipient,omitempty"`
    58  	}{}
    59  	if response := s.mapStructure(r.URL.Query(), &queries); response != nil {
    60  		response("invalid query params", logger, w)
    61  		return
    62  	}
    63  	if queries.Recipient == nil {
    64  		queries.Recipient = &(crypto.Secp256k1PrivateKeyFromBytes(topic[:])).PublicKey
    65  	}
    66  
    67  	headers := struct {
    68  		BatchID []byte `map:"Swarm-Postage-Batch-Id" validate:"required"`
    69  	}{}
    70  	if response := s.mapStructure(r.Header, &headers); response != nil {
    71  		response("invalid header params", logger, w)
    72  		return
    73  	}
    74  
    75  	payload, err := io.ReadAll(r.Body)
    76  	if err != nil {
    77  		logger.Debug("read body failed", "error", err)
    78  		logger.Error(nil, "read body failed")
    79  		jsonhttp.InternalServerError(w, "pss send failed")
    80  		return
    81  	}
    82  	i, save, err := s.post.GetStampIssuer(headers.BatchID)
    83  	if err != nil {
    84  		logger.Debug("get postage batch issuer failed", "batch_id", hex.EncodeToString(headers.BatchID), "error", err)
    85  		logger.Error(nil, "get postage batch issuer failed")
    86  		switch {
    87  		case errors.Is(err, postage.ErrNotFound):
    88  			jsonhttp.BadRequest(w, "batch not found")
    89  		case errors.Is(err, postage.ErrNotUsable):
    90  			jsonhttp.BadRequest(w, "batch not usable yet")
    91  		default:
    92  			jsonhttp.BadRequest(w, "postage stamp issuer")
    93  		}
    94  		return
    95  	}
    96  
    97  	stamper := postage.NewStamper(s.stamperStore, i, s.signer)
    98  
    99  	err = s.pss.Send(r.Context(), topic, payload, stamper, queries.Recipient, targets)
   100  	if err != nil {
   101  		logger.Debug("send payload failed", "topic", paths.Topic, "error", err)
   102  		logger.Error(nil, "send payload failed")
   103  		switch {
   104  		case errors.Is(err, postage.ErrBucketFull):
   105  			jsonhttp.PaymentRequired(w, "batch is overissued")
   106  		default:
   107  			jsonhttp.InternalServerError(w, "pss send failed")
   108  		}
   109  		return
   110  	}
   111  
   112  	if err = save(); err != nil {
   113  		logger.Debug("save stamp failed", "error", err)
   114  		logger.Error(nil, "save stamp failed")
   115  		jsonhttp.InternalServerError(w, "pss send failed")
   116  		return
   117  	}
   118  
   119  	jsonhttp.Created(w, nil)
   120  }
   121  
   122  func (s *Service) pssWsHandler(w http.ResponseWriter, r *http.Request) {
   123  	logger := s.logger.WithName("pss_subscribe").Build()
   124  
   125  	paths := struct {
   126  		Topic string `map:"topic" validate:"required"`
   127  	}{}
   128  	if response := s.mapStructure(mux.Vars(r), &paths); response != nil {
   129  		response("invalid path params", logger, w)
   130  		return
   131  	}
   132  
   133  	upgrader := websocket.Upgrader{
   134  		ReadBufferSize:  swarm.ChunkSize,
   135  		WriteBufferSize: swarm.ChunkSize,
   136  		CheckOrigin:     s.checkOrigin,
   137  	}
   138  
   139  	conn, err := upgrader.Upgrade(w, r, nil)
   140  	if err != nil {
   141  		logger.Debug("upgrade failed", "error", err)
   142  		logger.Error(nil, "upgrade failed")
   143  		jsonhttp.InternalServerError(w, "upgrade failed")
   144  		return
   145  	}
   146  
   147  	s.wsWg.Add(1)
   148  	go s.pumpWs(conn, paths.Topic)
   149  }
   150  
   151  func (s *Service) pumpWs(conn *websocket.Conn, t string) {
   152  	defer s.wsWg.Done()
   153  
   154  	var (
   155  		dataC  = make(chan []byte)
   156  		gone   = make(chan struct{})
   157  		topic  = pss.NewTopic(t)
   158  		ticker = time.NewTicker(s.WsPingPeriod)
   159  		err    error
   160  	)
   161  	defer func() {
   162  		ticker.Stop()
   163  		_ = conn.Close()
   164  	}()
   165  	cleanup := s.pss.Register(topic, func(ctx context.Context, m []byte) {
   166  		select {
   167  		case dataC <- m:
   168  		case <-ctx.Done():
   169  			return
   170  		case <-gone:
   171  			return
   172  		case <-s.quit:
   173  			return
   174  		}
   175  	})
   176  
   177  	defer cleanup()
   178  
   179  	conn.SetCloseHandler(func(code int, text string) error {
   180  		s.logger.Debug("pss ws: client gone", "code", code, "message", text)
   181  		close(gone)
   182  		return nil
   183  	})
   184  
   185  	for {
   186  		select {
   187  		case b := <-dataC:
   188  			err = conn.SetWriteDeadline(time.Now().Add(writeDeadline))
   189  			if err != nil {
   190  				s.logger.Debug("pss ws: set write deadline failed", "error", err)
   191  				return
   192  			}
   193  
   194  			err = conn.WriteMessage(websocket.BinaryMessage, b)
   195  			if err != nil {
   196  				s.logger.Debug("pss ws: write message failed", "error", err)
   197  				return
   198  			}
   199  
   200  		case <-s.quit:
   201  			// shutdown
   202  			err = conn.SetWriteDeadline(time.Now().Add(writeDeadline))
   203  			if err != nil {
   204  				s.logger.Debug("pss ws: set write deadline failed", "error", err)
   205  				return
   206  			}
   207  			err = conn.WriteMessage(websocket.CloseMessage, []byte{})
   208  			if err != nil {
   209  				s.logger.Debug("pss ws: write close message failed", "error", err)
   210  			}
   211  			return
   212  		case <-gone:
   213  			// client gone
   214  			return
   215  		case <-ticker.C:
   216  			err = conn.SetWriteDeadline(time.Now().Add(writeDeadline))
   217  			if err != nil {
   218  				s.logger.Debug("pss ws: set write deadline failed", "error", err)
   219  				return
   220  			}
   221  			if err = conn.WriteMessage(websocket.PingMessage, nil); err != nil {
   222  				// error encountered while pinging client. client probably gone
   223  				return
   224  			}
   225  		}
   226  	}
   227  }