github.com/mysteriumnetwork/node@v0.0.0-20240516044423-365054f76801/tequilapi/endpoints/sse_handler.go (about)

     1  /*
     2   * Copyright (C) 2020 The "MysteriumNetwork/node" Authors.
     3   *
     4   * This program is free software: you can redistribute it and/or modify
     5   * it under the terms of the GNU General Public License as published by
     6   * the Free Software Foundation, either version 3 of the License, or
     7   * (at your option) any later version.
     8   *
     9   * This program is distributed in the hope that it will be useful,
    10   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    11   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    12   * GNU General Public License for more details.
    13   *
    14   * You should have received a copy of the GNU General Public License
    15   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    16   */
    17  
    18  package endpoints
    19  
import (
	"encoding/json"
	"fmt"
	"math/big"
	"net/http"
	"sort"
	"sync"

	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
	"github.com/rs/zerolog/log"

	"github.com/mysteriumnetwork/node/consumer/session"
	"github.com/mysteriumnetwork/node/core/connection/connectionstate"
	nodeEvent "github.com/mysteriumnetwork/node/core/node/event"
	"github.com/mysteriumnetwork/node/core/state/event"
	stateEvent "github.com/mysteriumnetwork/node/core/state/event"
	"github.com/mysteriumnetwork/node/eventbus"
	"github.com/mysteriumnetwork/node/session/pingpong"
	"github.com/mysteriumnetwork/node/tequilapi/contract"
)
    40  
    41  // EventType represents all the event types we're subscribing to
    42  type EventType string
    43  
    44  // Event represents an event we're gonna send
    45  type Event struct {
    46  	Payload interface{} `json:"payload"`
    47  	Type    EventType   `json:"type"`
    48  }
    49  
    50  const (
    51  	// NATEvent represents the nat event type
    52  	NATEvent EventType = "nat"
    53  	// ServiceStatusEvent represents the service status event type
    54  	ServiceStatusEvent EventType = "service-status"
    55  	// StateChangeEvent represents the state change
    56  	StateChangeEvent EventType = "state-change"
    57  )
    58  
    59  // Handler represents an sse handler
    60  type Handler struct {
    61  	clients       map[chan string]struct{}
    62  	newClients    chan chan string
    63  	deadClients   chan chan string
    64  	messages      chan string
    65  	stopOnce      sync.Once
    66  	stopChan      chan struct{}
    67  	stateProvider stateProvider
    68  }
    69  
    70  type stateProvider interface {
    71  	GetState() stateEvent.State
    72  	GetConnection(string) stateEvent.Connection
    73  }
    74  
    75  // NewSSEHandler returns a new instance of handler
    76  func NewSSEHandler(stateProvider stateProvider) *Handler {
    77  	return &Handler{
    78  		clients:       make(map[chan string]struct{}),
    79  		newClients:    make(chan (chan string)),
    80  		deadClients:   make(chan (chan string)),
    81  		messages:      make(chan string, 20),
    82  		stopChan:      make(chan struct{}),
    83  		stateProvider: stateProvider,
    84  	}
    85  }
    86  
    87  // Subscribe subscribes to the event bus.
    88  func (h *Handler) Subscribe(bus eventbus.Subscriber) error {
    89  	err := bus.Subscribe(nodeEvent.AppTopicNode, h.ConsumeNodeEvent)
    90  	if err != nil {
    91  		return err
    92  	}
    93  	err = bus.Subscribe(stateEvent.AppTopicState, h.ConsumeStateEvent)
    94  	return err
    95  }
    96  
    97  // Sub subscribes a user to sse
    98  func (h *Handler) Sub(c *gin.Context) {
    99  	resp := c.Writer
   100  	req := c.Request
   101  
   102  	f, ok := resp.(http.Flusher)
   103  	if !ok {
   104  		resp.WriteHeader(http.StatusBadRequest)
   105  		resp.Header().Set("Content-type", "application/json; charset=utf-8")
   106  		writeErr := json.NewEncoder(resp).Encode(errors.New("not a flusher - cannot continue"))
   107  		if writeErr != nil {
   108  			http.Error(resp, "Http response write error", http.StatusInternalServerError)
   109  		}
   110  		return
   111  	}
   112  
   113  	resp.Header().Set("Content-Type", "text/event-stream")
   114  	resp.Header().Set("Cache-Control", "no-cache,no-transform")
   115  	resp.Header().Set("Connection", "keep-alive")
   116  
   117  	messageChan := make(chan string, 1)
   118  	err := h.sendInitialState(messageChan)
   119  	if err != nil {
   120  		resp.WriteHeader(http.StatusBadRequest)
   121  		resp.Header().Set("Content-type", "application/json; charset=utf-8")
   122  		writeErr := json.NewEncoder(resp).Encode(err)
   123  		if writeErr != nil {
   124  			http.Error(resp, "Http response write error", http.StatusInternalServerError)
   125  		}
   126  	}
   127  
   128  	h.newClients <- messageChan
   129  
   130  	defer func() {
   131  		h.deadClients <- messageChan
   132  	}()
   133  
   134  	for {
   135  		select {
   136  		case <-req.Context().Done():
   137  			return
   138  		case msg, open := <-messageChan:
   139  			if !open {
   140  				return
   141  			}
   142  
   143  			_, err := fmt.Fprintf(resp, "data: %s\n\n", msg)
   144  			if err != nil {
   145  				log.Error().Err(err).Msg("failed to print data in response")
   146  				return
   147  			}
   148  
   149  			f.Flush()
   150  		case <-h.stopChan:
   151  			return
   152  		}
   153  	}
   154  }
   155  
   156  func (h *Handler) sendInitialState(messageChan chan string) error {
   157  	res, err := json.Marshal(Event{
   158  		Type:    StateChangeEvent,
   159  		Payload: mapState(h.stateProvider.GetState()),
   160  	})
   161  	if err != nil {
   162  		return err
   163  	}
   164  
   165  	messageChan <- string(res)
   166  	return nil
   167  }
   168  
   169  func (h *Handler) serve() {
   170  	defer func() {
   171  		for k := range h.clients {
   172  			close(k)
   173  		}
   174  	}()
   175  
   176  	for {
   177  		select {
   178  		case <-h.stopChan:
   179  			return
   180  		case s := <-h.newClients:
   181  			h.clients[s] = struct{}{}
   182  		case s := <-h.deadClients:
   183  			delete(h.clients, s)
   184  			close(s)
   185  		case msg := <-h.messages:
   186  			for s := range h.clients {
   187  				// non-locking send to each client
   188  				select {
   189  				case s <- msg:
   190  				default:
   191  				}
   192  			}
   193  		}
   194  	}
   195  }
   196  
   197  func (h *Handler) stop() {
   198  	h.stopOnce.Do(func() { close(h.stopChan) })
   199  }
   200  
   201  func (h *Handler) send(e Event) {
   202  	marshaled, err := json.Marshal(e)
   203  	if err != nil {
   204  		log.Error().Err(err).Msg("Could not marshal SSE message")
   205  		return
   206  	}
   207  	h.messages <- string(marshaled)
   208  }
   209  
   210  // ConsumeNodeEvent consumes the node state event
   211  func (h *Handler) ConsumeNodeEvent(e nodeEvent.Payload) {
   212  	if e.Status == nodeEvent.StatusStarted {
   213  		go h.serve()
   214  		return
   215  	}
   216  	if e.Status == nodeEvent.StatusStopped {
   217  		h.stop()
   218  		return
   219  	}
   220  }
   221  
   222  type stateRes struct {
   223  	Services      []contract.ServiceInfoDTO    `json:"service_info"`
   224  	Sessions      []contract.SessionDTO        `json:"sessions"`
   225  	SessionsStats contract.SessionStatsDTO     `json:"sessions_stats"`
   226  	Consumer      consumerStateRes             `json:"consumer"`
   227  	Identities    []contract.IdentityDTO       `json:"identities"`
   228  	Channels      []contract.PaymentChannelDTO `json:"channels"`
   229  }
   230  
   231  type consumerStateRes struct {
   232  	Connection contract.ConnectionDTO `json:"connection"`
   233  }
   234  
   235  func mapState(state stateEvent.State) stateRes {
   236  	identitiesRes := make([]contract.IdentityDTO, len(state.Identities))
   237  	for idx, identity := range state.Identities {
   238  		stake := new(big.Int)
   239  
   240  		if channel := identityChannel(identity.Address, state.ProviderChannels); channel != nil {
   241  			stake = channel.Channel.Stake
   242  		}
   243  
   244  		identitiesRes[idx] = contract.IdentityDTO{
   245  			Address:             identity.Address,
   246  			RegistrationStatus:  identity.RegistrationStatus.String(),
   247  			ChannelAddress:      identity.ChannelAddress.Hex(),
   248  			Balance:             identity.Balance,
   249  			BalanceTokens:       contract.NewTokens(identity.Balance),
   250  			Earnings:            identity.Earnings,
   251  			EarningsTokens:      contract.NewTokens(identity.Earnings),
   252  			EarningsTotal:       identity.EarningsTotal,
   253  			EarningsTotalTokens: contract.NewTokens(identity.EarningsTotal),
   254  			Stake:               stake,
   255  			HermesID:            identity.HermesID.Hex(),
   256  			EarningsPerHermes:   contract.NewEarningsPerHermesDTO(identity.EarningsPerHermes),
   257  		}
   258  	}
   259  
   260  	channelsRes := make([]contract.PaymentChannelDTO, len(state.ProviderChannels))
   261  	for idx, channel := range state.ProviderChannels {
   262  		channelsRes[idx] = contract.NewPaymentChannelDTO(channel)
   263  	}
   264  
   265  	sessionsRes := make([]contract.SessionDTO, len(state.Sessions))
   266  	sessionsStats := session.NewStats()
   267  	for idx, se := range state.Sessions {
   268  		sessionsRes[idx] = contract.NewSessionDTO(se)
   269  		sessionsStats.Add(se)
   270  	}
   271  
   272  	conn := event.Connection{Session: connectionstate.Status{State: connectionstate.NotConnected}}
   273  
   274  	for k, c := range state.Connections {
   275  		if c.Session.State == "" {
   276  			c.Session.State = connectionstate.NotConnected
   277  		}
   278  		conn = c
   279  		if len(k) > 0 {
   280  			break
   281  		}
   282  	}
   283  
   284  	res := stateRes{
   285  		Services:      state.Services,
   286  		Sessions:      sessionsRes,
   287  		SessionsStats: contract.NewSessionStatsDTO(sessionsStats),
   288  		Consumer: consumerStateRes{
   289  			Connection: contract.NewConnectionDTO(conn.Session, conn.Statistics, conn.Throughput, conn.Invoice),
   290  		},
   291  		Identities: identitiesRes,
   292  		Channels:   channelsRes,
   293  	}
   294  	return res
   295  }
   296  
   297  func identityChannel(address string, channels []pingpong.HermesChannel) *pingpong.HermesChannel {
   298  	for idx := range channels {
   299  		if channels[idx].Identity.Address == address {
   300  			return &channels[idx]
   301  		}
   302  	}
   303  
   304  	return nil
   305  }
   306  
   307  // ConsumeStateEvent consumes the state change event
   308  func (h *Handler) ConsumeStateEvent(event stateEvent.State) {
   309  	h.send(Event{
   310  		Type:    StateChangeEvent,
   311  		Payload: mapState(event),
   312  	})
   313  }