github.com/lastbackend/toolkit@v0.0.0-20241020043710-cafa37b95aad/pkg/server/http/websockets/manager.go (about)

     1  /*
     2  Copyright [2014] - [2023] The Last.Backend authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package websockets
    18  
    19  import (
    20  	"context"
    21  	"github.com/lastbackend/toolkit/pkg/runtime/logger"
    22  	"net/http"
    23  	"reflect"
    24  	"sync"
    25  	"unsafe"
    26  
    27  	"github.com/gorilla/websocket"
    28  	"github.com/lastbackend/toolkit/pkg/server/http/errors"
    29  	"github.com/lastbackend/toolkit/pkg/util/converter"
    30  )
    31  
    32  var (
    33  	// upgrader is used to upgrade incoming HTTP requests
    34  	// into a persistent websocket connection
    35  	upgrader = websocket.Upgrader{
    36  		ReadBufferSize:  1024,
    37  		WriteBufferSize: 1024,
    38  		CheckOrigin: func(r *http.Request) bool {
    39  			return true
    40  		},
    41  	}
    42  )
    43  
    44  var (
    45  	ErrEventNotSupported = errors.New("this event type is not supported")
    46  )
    47  
    48  const (
    49  	RequestHeaders = "x-http-req-headers"
    50  )
    51  
    52  // Manager is used to hold references to all Clients Registered, and Broadcasting etc
    53  type Manager struct {
    54  	sync.RWMutex
    55  
    56  	log      logger.Logger
    57  	clients  ClientList
    58  	handlers map[string]EventHandler
    59  }
    60  
    61  // NewManager is used to initialize all the values inside the manager
    62  func NewManager(log logger.Logger) *Manager {
    63  	m := &Manager{
    64  		log:      log,
    65  		clients:  make(ClientList),
    66  		handlers: make(map[string]EventHandler),
    67  	}
    68  	return m
    69  }
    70  
    71  // AddEventHandler configures and adds all handlers
    72  func (m *Manager) AddEventHandler(event string, handler EventHandler) {
    73  	m.Lock()
    74  	defer m.Unlock()
    75  	m.handlers[event] = handler
    76  }
    77  
    78  // routeEvent is used to make sure the correct event goes into the correct handler
    79  func (m *Manager) routeEvent(event Event, c *Client) error {
    80  	if handler, ok := m.handlers[event.Type]; ok {
    81  		if err := handler(c.ctx, event, c); err != nil {
    82  			return err
    83  		}
    84  		return nil
    85  	} else {
    86  		return ErrEventNotSupported
    87  	}
    88  }
    89  
    90  // ServeWS is an HTTP Handler that the has the Manager that allows connections
    91  func (m *Manager) ServeWS(w http.ResponseWriter, r *http.Request) {
    92  
    93  	if r.Method != http.MethodGet {
    94  		http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
    95  		return
    96  	}
    97  
    98  	// Begin by upgrading the HTTP request
    99  	conn, err := upgrader.Upgrade(w, r, nil)
   100  	if err != nil {
   101  		m.log.Errorf("upgrading the HTTP request failed %v", err)
   102  		return
   103  	}
   104  
   105  	// Copy request context data
   106  	ctx := context.Background()
   107  	contextValues := reflect.ValueOf(r.Context()).Elem()
   108  	contextKeys := reflect.TypeOf(r.Context()).Elem()
   109  
   110  	if contextKeys.Kind() == reflect.Struct {
   111  		for i := 0; i < contextValues.NumField(); i++ {
   112  			reflectValue := contextValues.Field(i)
   113  			reflectValue = reflect.NewAt(reflectValue.Type(), unsafe.Pointer(reflectValue.UnsafeAddr())).Elem()
   114  
   115  			reflectField := contextKeys.Field(i)
   116  			if reflectField.Name == "key" {
   117  				ctx = context.WithValue(ctx, reflectValue.Interface(), r.Context().Value(reflectValue.Interface()))
   118  			}
   119  		}
   120  	}
   121  
   122  	headers, err := converter.PrepareHeaderFromRequest(r)
   123  	if err != nil {
   124  		errors.HTTP.InternalServerError(w)
   125  		return
   126  	}
   127  
   128  	ctx = context.WithValue(ctx, RequestHeaders, headers)
   129  
   130  	client := NewClient(ctx, m.log, conn, m)
   131  
   132  	m.addClient(client)
   133  
   134  	go client.readMessages()
   135  	go client.writeMessages()
   136  }
   137  
   138  // addClient will add clients to our clients list
   139  func (m *Manager) addClient(client *Client) {
   140  	m.Lock()
   141  	defer m.Unlock()
   142  
   143  	m.clients[client] = true
   144  }
   145  
   146  // removeClient will remove the client and clean up
   147  func (m *Manager) removeClient(client *Client) {
   148  	m.Lock()
   149  	defer m.Unlock()
   150  
   151  	if _, ok := m.clients[client]; ok {
   152  		if err := client.connection.Close(); err != nil {
   153  			m.log.Errorf("close client failed %v", err)
   154  		}
   155  		delete(m.clients, client)
   156  	}
   157  }