github.com/e154/smart-home@v0.17.2-0.20240311175135-e530a6e5cd45/system/gate/server/wsp/server.go (about)

     1  // This file is part of the Smart Home
     2  // Program complex distribution https://github.com/e154/smart-home
     3  // Copyright (C) 2023, Filippov Alex
     4  //
     5  // This library is free software: you can redistribute it and/or
     6  // modify it under the terms of the GNU Lesser General Public
     7  // License as published by the Free Software Foundation; either
     8  // version 3 of the License, or (at your option) any later version.
     9  //
    10  // This library is distributed in the hope that it will be useful,
    11  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    12  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    13  // Library General Public License for more details.
    14  //
    15  // You should have received a copy of the GNU Lesser General Public
    16  // License along with this library.  If not, see
    17  // <https://www.gnu.org/licenses/>.
    18  
    19  package wsp
    20  
    21  import (
    22  	"net/http"
    23  	"net/url"
    24  	"time"
    25  
    26  	"github.com/gorilla/websocket"
    27  	"github.com/pkg/errors"
    28  
    29  	"github.com/e154/smart-home/common/logger"
    30  	"github.com/e154/smart-home/system/gate/common"
    31  )
    32  
    33  var (
    34  	log = logger.MustGetLogger("wsp server")
    35  )
    36  
    37  // Server is a Reverse HTTP Proxy over WebSocket
    38  // This is the Server part, Clients will offer websocket connections,
    39  // those will be pooled to transfer HTTP Request and response
    40  type Server struct {
    41  	config   *Config
    42  	upgrader websocket.Upgrader
    43  	pools    *Pools
    44  	done     chan struct{}
    45  	server   *http.Server
    46  }
    47  
    48  func NewServer(config *Config) (server *Server) {
    49  	server = &Server{
    50  		config:   config,
    51  		upgrader: websocket.Upgrader{},
    52  		done:     make(chan struct{}),
    53  		pools:    NewPools(config.Timeout, config.IdleTimeout),
    54  	}
    55  	return
    56  }
    57  
    58  // Start Server HTTP server
    59  func (s *Server) Start() {
    60  	go func() {
    61  		for {
    62  			select {
    63  			case <-s.done:
    64  				return
    65  			case <-time.After(5 * time.Second):
    66  				s.pools.Clean()
    67  			}
    68  		}
    69  	}()
    70  
    71  }
    72  
    73  func (s *Server) Ws(w http.ResponseWriter, r *http.Request) {
    74  	log.Infof("[%s] %s", r.Method, r.URL.String())
    75  
    76  	if s.pools.IsEmpty() {
    77  		common.ProxyErrorf(w, "No proxy available")
    78  		return
    79  	}
    80  
    81  	serverId, err := s.GetServerID(r)
    82  	if err != nil {
    83  		common.ProxyErrorf(w, err.Error())
    84  		return
    85  	}
    86  
    87  	pool, ok := s.pools.GetPool(PoolID(serverId))
    88  	if !ok {
    89  		common.ProxyErrorf(w, "Unable to get a pool")
    90  		return
    91  	}
    92  
    93  	connection := pool.GetIdleConnection(r.Context())
    94  	if connection == nil {
    95  		common.ProxyErrorf(w, "Unable to get a proxy connection")
    96  		return
    97  	}
    98  
    99  	if err := connection.proxyWs(w, r); err != nil {
   100  		// An error occurred throw the connection away
   101  		log.Error(err.Error())
   102  		connection.Close()
   103  
   104  		// Try to return an error to the client
   105  		// This might fail if response headers have already been sent
   106  		common.ProxyError(w, err)
   107  	}
   108  }
   109  
   110  func (s *Server) Request(w http.ResponseWriter, r *http.Request) {
   111  
   112  	if s.pools.IsEmpty() {
   113  		common.ProxyErrorf(w, "No proxy available")
   114  		return
   115  	}
   116  
   117  	serverId, err := s.GetServerID(r)
   118  	if err != nil {
   119  		common.ProxyErrorf(w, err.Error())
   120  		return
   121  	}
   122  
   123  	r.URL = &url.URL{
   124  		Path:        r.URL.Path,
   125  		RawQuery:    r.URL.RawQuery,
   126  		Fragment:    r.URL.Fragment,
   127  		RawFragment: r.URL.RawFragment,
   128  	}
   129  
   130  	log.Infof("[%s] %s", r.Method, r.URL.String())
   131  
   132  	pool, ok := s.pools.GetPool(PoolID(serverId))
   133  	if !ok {
   134  		common.ProxyErrorf(w, "Unable to get a proxy connection")
   135  		return
   136  	}
   137  
   138  	connection := pool.GetIdleConnection(r.Context())
   139  	if connection == nil {
   140  		common.ProxyErrorf(w, "Unable to get a proxy connection")
   141  		return
   142  	}
   143  
   144  	// [3]: Send the request to the peer through the WebSocket connection.
   145  	if err := connection.proxyRequest(w, r); err != nil {
   146  		// An error occurred throw the connection away
   147  		log.Error(err.Error())
   148  		connection.Close()
   149  
   150  		// Try to return an error to the client
   151  		// This might fail if response headers have already been sent
   152  		common.ProxyError(w, err)
   153  	}
   154  }
   155  
   156  // Request receives the WebSocket upgrade handshake request from wsp_client.
   157  func (s *Server) Register(w http.ResponseWriter, r *http.Request) {
   158  	// 1. Upgrade a received HTTP request to a WebSocket connection
   159  	secretKey := r.Header.Get("X-SECRET-KEY")
   160  	if s.config.SecretKey != "" && secretKey != s.config.SecretKey {
   161  		common.ProxyErrorf(w, "Invalid X-SECRET-KEY")
   162  		return
   163  	}
   164  
   165  	ws, err := s.upgrader.Upgrade(w, r, nil)
   166  	if err != nil {
   167  		common.ProxyErrorf(w, "HTTP upgrade error : %v", err)
   168  		return
   169  	}
   170  
   171  	if err = s.pools.RegisterConnection(ws); err != nil {
   172  		common.ProxyErrorf(w, err.Error())
   173  	}
   174  }
   175  
   176  // Shutdown stop the Server
   177  func (s *Server) Shutdown() {
   178  	close(s.done)
   179  	s.pools.Shutdown()
   180  }
   181  
   182  func (s *Server) GetServerID(r *http.Request) (serverID string, err error) {
   183  	serverID = r.Header.Get("X-SERVER-ID")
   184  	if serverID != "" {
   185  		return
   186  	}
   187  
   188  	query := r.URL.Query()
   189  	serverID = query.Get("serverId")
   190  	if serverID != "" {
   191  		return
   192  	}
   193  
   194  	serverID = query.Get("server_id")
   195  	if serverID != "" {
   196  		return
   197  	}
   198  
   199  	if serverID == "" {
   200  		err = errors.New("Unable to parse DESTINATION params")
   201  	}
   202  
   203  	return
   204  }