github.com/stampzilla/stampzilla-go@v2.0.0-rc9+incompatible/nodes/stampzilla-server/webserver/webserver.go (about)

     1  package webserver
     2  
     3  import (
     4  	"crypto/tls"
     5  	"io"
     6  	"net/http"
     7  	"os"
     8  	"path"
     9  	"time"
    10  
    11  	"github.com/gin-contrib/cors"
    12  	"github.com/gin-contrib/gzip"
    13  	"github.com/gin-gonic/gin"
    14  	"github.com/google/uuid"
    15  	"github.com/jonaz/ginlogrus"
    16  	"github.com/jonaz/gograce"
    17  	"github.com/olahol/melody"
    18  	"github.com/rakyll/statik/fs"
    19  	"github.com/sirupsen/logrus"
    20  	"github.com/stampzilla/stampzilla-go/nodes/stampzilla-server/handlers"
    21  	"github.com/stampzilla/stampzilla-go/nodes/stampzilla-server/models"
    22  
    23  	"github.com/stampzilla/stampzilla-go/nodes/stampzilla-server/store"
    24  	"github.com/stampzilla/stampzilla-go/nodes/stampzilla-server/websocket"
    25  )
    26  
// Webserver bundles the dependencies of the stampzilla HTTP server together
// with the router built by Init. Create it with New.
type Webserver struct {
	Store            *store.Store              // used to add, look up and remove websocket connections
	Melody           *melody.Melody            // websocket session manager behind the /ws endpoint
	Config           *models.Config            // server configuration; not read anywhere in this file
	WebsocketHandler handlers.WebsocketHandler // receives websocket connect/message/disconnect events
	router           http.Handler              // set by Init; ServeHTTP delegates to it
}
    34  
    35  func New(s *store.Store, conf *models.Config, wsh handlers.WebsocketHandler, m *melody.Melody) *Webserver {
    36  
    37  	return &Webserver{
    38  		Store:            s,
    39  		Config:           conf,
    40  		WebsocketHandler: wsh,
    41  		Melody:           m,
    42  	}
    43  }
    44  
    45  func (ws *Webserver) ServeHTTP(w http.ResponseWriter, req *http.Request) {
    46  	ws.router.ServeHTTP(w, req)
    47  }
    48  
    49  func (ws *Webserver) Init() *gin.Engine {
    50  	gin.SetMode(gin.TestMode)
    51  
    52  	r := gin.New()
    53  	r.Use(gzip.Gzip(gzip.DefaultCompression))
    54  
    55  	ws.initMelody()
    56  
    57  	r.Use(ginlogrus.New(logrus.StandardLogger()))
    58  	r.Use(cors.Default())
    59  
    60  	statikFS, err := fs.New()
    61  	if err == nil { // we only service GUI if statik files can be found
    62  		r.GET("/service-worker.js", gin.WrapH(http.FileServer(statikFS)))
    63  		r.GET("/assets/*all", gin.WrapH(http.FileServer(statikFS)))
    64  		r.NoRoute(func(c *gin.Context) {
    65  			cspMiddleware()(c)
    66  			c.Request.URL.Path = "/" // force us to always return index.html and not the requested page to be compatible with HTML5 routing
    67  			http.FileServer(statikFS).ServeHTTP(c.Writer, c.Request)
    68  		})
    69  	}
    70  
    71  	// Setup gin
    72  	r.GET("/ca.crt", ws.handleDownloadCA())
    73  	r.GET("/ws", ws.handleWs(ws.Melody))
    74  
    75  	ws.router = r
    76  	return r
    77  }
    78  func (ws *Webserver) Start(addr string, tlsConfig *tls.Config) chan struct{} {
    79  
    80  	server, done := gograce.NewServerWithTimeout(10 * time.Second)
    81  
    82  	server.Handler = ws.Init()
    83  	server.Addr = addr
    84  
    85  	go func() {
    86  		if tlsConfig != nil {
    87  			server.TLSConfig = tlsConfig
    88  			logrus.Infof("Starting secure webserver at %s", addr)
    89  			logrus.Error(server.ListenAndServeTLS("", ""))
    90  		} else {
    91  			logrus.Infof("Starting webserver at %s", addr)
    92  			logrus.Error(server.ListenAndServe())
    93  		}
    94  	}()
    95  	return done
    96  }
    97  
    98  func (ws *Webserver) initMelody() {
    99  	// Setup melody
   100  	ws.Melody.Upgrader.CheckOrigin = func(r *http.Request) bool { return true }
   101  	ws.Melody.HandleConnect(ws.handleConnect())
   102  	ws.Melody.HandleMessage(ws.handleMessage())
   103  	ws.Melody.HandleDisconnect(ws.handleDisconnect())
   104  }
   105  
   106  func cspMiddleware() gin.HandlerFunc {
   107  	return func(c *gin.Context) {
   108  		c.Writer.Header().Set("Content-Security-Policy", "worker-src 'self';")
   109  		c.Next()
   110  	}
   111  }
   112  
   113  func (ws *Webserver) handleConnect() func(s *melody.Session) {
   114  	return func(s *melody.Session) {
   115  		proto, _ := s.Get(websocket.KeyProtocol.String())
   116  		id, _ := s.Get(websocket.KeyID.String())
   117  
   118  		ws.Store.AddOrUpdateConnection(id.(string), &models.Connection{
   119  			Type:       proto.(string),
   120  			RemoteAddr: s.Request.RemoteAddr,
   121  			Attributes: s.Keys,
   122  		})
   123  
   124  		err := ws.WebsocketHandler.Connect(s, s.Request, s.Keys)
   125  		if err != nil {
   126  			logrus.Error(err)
   127  			return
   128  		}
   129  
   130  	}
   131  }
   132  
   133  func (ws *Webserver) handleMessage() func(s *melody.Session, msg []byte) {
   134  	return func(s *melody.Session, msg []byte) {
   135  		data, err := models.ParseMessage(msg)
   136  		if err != nil {
   137  			logrus.Error("cannot parse incoming websocket message: ", err)
   138  			return
   139  		}
   140  
   141  		id, _ := s.Get(websocket.KeyID.String())
   142  		data.FromUUID = id.(string)
   143  		err = ws.WebsocketHandler.Message(s, data)
   144  		if err != nil {
   145  			logrus.Error(err)
   146  			return
   147  		}
   148  
   149  	}
   150  }
   151  
   152  func (ws *Webserver) handleDisconnect() func(s *melody.Session) {
   153  	return func(s *melody.Session) {
   154  		id, _ := s.Get(websocket.KeyID.String())
   155  		ws.Store.RemoveConnection(id.(string))
   156  		err := ws.WebsocketHandler.Disconnect(s)
   157  		if err != nil {
   158  			logrus.Error(err)
   159  			return
   160  		}
   161  	}
   162  }
   163  
   164  func (ws *Webserver) handleDownloadCA() func(c *gin.Context) {
   165  	return func(c *gin.Context) {
   166  		header := c.Writer.Header()
   167  		header["Content-Type"] = []string{"application/x-x509-ca-cert"}
   168  
   169  		file, err := os.Open(path.Join("certificates", "ca.crt"))
   170  		if err != nil {
   171  			c.String(http.StatusOK, "%v", err)
   172  			return
   173  		}
   174  		defer file.Close()
   175  
   176  		io.Copy(c.Writer, file)
   177  	}
   178  }
   179  
   180  func (ws *Webserver) handleWs(m *melody.Melody) func(c *gin.Context) {
   181  	return func(c *gin.Context) {
   182  		keys := make(map[string]interface{})
   183  		keys[websocket.KeyID.String()] = uuid.New().String()
   184  
   185  		if c.Request.TLS != nil {
   186  			keys["secure"] = true
   187  
   188  			certs := c.Request.TLS.PeerCertificates
   189  			if len(certs) > 0 {
   190  				keys["identity"] = certs[0].Subject.CommonName
   191  			}
   192  		}
   193  
   194  		// Accept the requested protocol if known
   195  		knownProtocols := []string{
   196  			"node",
   197  			"gui",
   198  			"metrics",
   199  		}
   200  		proto := c.Request.Header.Get("Sec-WebSocket-Protocol")
   201  
   202  		allowed := false
   203  		for _, v := range knownProtocols {
   204  			if proto == v {
   205  				allowed = true
   206  				break
   207  			}
   208  		}
   209  
   210  		if !allowed {
   211  			logrus.Errorf("webserver: protocol \"%s\" not allowed", proto)
   212  			c.AbortWithStatus(http.StatusForbidden)
   213  			return
   214  		}
   215  
   216  		if proto != "" {
   217  			c.Writer.Header().Set("Sec-WebSocket-Protocol", proto)
   218  			keys[websocket.KeyProtocol.String()] = proto
   219  		}
   220  
   221  		if c.Request.Header.Get("X-UUID") != "" {
   222  			keys[websocket.KeyID.String()] = c.Request.Header.Get("X-UUID")
   223  		}
   224  		keys["type"] = c.Request.Header.Get("X-TYPE")
   225  
   226  		if ws.Store.Connection(keys[websocket.KeyID.String()].(string)) != nil {
   227  			logrus.Error("Connection with same UUID already exists")
   228  			c.AbortWithStatus(http.StatusForbidden)
   229  			return
   230  		}
   231  
   232  		err := m.HandleRequestWithKeys(c.Writer, c.Request, keys)
   233  		if err != nil {
   234  			logrus.Errorf("webserver: %s", err.Error())
   235  			return
   236  		}
   237  
   238  	}
   239  }