go.dedis.ch/onet/v3@v3.2.11-0.20210930124529-e36530bca7ef/websocket.go (about) 1 package onet 2 3 import ( 4 "context" 5 "crypto/tls" 6 "crypto/x509" 7 "fmt" 8 "net/http" 9 "strings" 10 "sync" 11 "time" 12 13 "github.com/gorilla/websocket" 14 "go.dedis.ch/onet/v3/log" 15 "go.dedis.ch/onet/v3/network" 16 "golang.org/x/xerrors" 17 ) 18 19 const certificateReloaderLeeway = 1 * time.Hour 20 21 // CertificateReloader takes care of reloading a TLS certificate when 22 // requested. 23 type CertificateReloader struct { 24 sync.RWMutex 25 cert *tls.Certificate 26 certPath string 27 keyPath string 28 } 29 30 // NewCertificateReloader takes two file paths as parameter that contain 31 // the certificate and the key data to create an automatic reloader. It will 32 // try to read again the files when the certificate is almost expired. 33 func NewCertificateReloader(certPath, keyPath string) (*CertificateReloader, error) { 34 loader := &CertificateReloader{ 35 certPath: certPath, 36 keyPath: keyPath, 37 } 38 39 err := loader.reload() 40 if err != nil { 41 return nil, xerrors.Errorf("reloading certificate: %v", err) 42 } 43 44 return loader, nil 45 } 46 47 func (cr *CertificateReloader) reload() error { 48 newCert, err := tls.LoadX509KeyPair(cr.certPath, cr.keyPath) 49 if err != nil { 50 return xerrors.Errorf("load x509: %v", err) 51 } 52 53 cr.Lock() 54 cr.cert = &newCert 55 // Successful parse means at least one certificate. 56 cr.cert.Leaf, err = x509.ParseCertificate(newCert.Certificate[0]) 57 cr.Unlock() 58 59 if err != nil { 60 return xerrors.Errorf("parse x509: %v", err) 61 } 62 return nil 63 } 64 65 // GetCertificateFunc makes a function that can be passed to the TLSConfig 66 // so that it resolves the most up-to-date one. 67 func (cr *CertificateReloader) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) { 68 return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { 69 cr.RLock() 70 71 exp := time.Now().Add(certificateReloaderLeeway) 72 73 // Here we know the leaf has been parsed successfully as an error 74 // would have been thrown otherwise. 75 if cr.cert == nil || exp.After(cr.cert.Leaf.NotAfter) { 76 // Certificate has expired so we try to load the new one. 77 78 // Free the read lock to be able to reload. 79 cr.RUnlock() 80 err := cr.reload() 81 if err != nil { 82 return nil, xerrors.Errorf("reload certificate: %v", err) 83 } 84 85 cr.RLock() 86 } 87 88 defer cr.RUnlock() 89 return cr.cert, nil 90 } 91 } 92 93 // WebSocket handles incoming client-requests using the websocket 94 // protocol. When making a new WebSocket, it will listen one port above the 95 // ServerIdentity-port-#. 96 // The websocket protocol has been chosen as smallest common denominator 97 // for languages including JavaScript. 98 type WebSocket struct { 99 services map[string]Service 100 server *http.Server 101 mux *http.ServeMux 102 startstop chan bool 103 started bool 104 TLSConfig *tls.Config // can only be modified before Start is called 105 sync.Mutex 106 } 107 108 // NewWebSocket opens a webservice-listener at the given si.URL. 109 func NewWebSocket(si *network.ServerIdentity) *WebSocket { 110 w := &WebSocket{ 111 services: make(map[string]Service), 112 startstop: make(chan bool), 113 } 114 webHost, err := getWSHostPort(si, true) 115 log.ErrFatal(err) 116 w.mux = http.NewServeMux() 117 w.mux.HandleFunc("/ok", func(w http.ResponseWriter, r *http.Request) { 118 log.Lvl4("ok?", r.RemoteAddr) 119 ok := []byte("ok\n") 120 w.Write(ok) 121 }) 122 123 if allowPprof() { 124 log.Warn("HTTP pprof profiling is enabled") 125 initPprof(w.mux) 126 } 127 128 // Add a catch-all handler (longest paths take precedence, so "/" takes 129 // all non-registered paths) and correctly upgrade to a websocket and 130 // throw an error. 131 w.mux.HandleFunc("/", func(wr http.ResponseWriter, re *http.Request) { 132 log.Error("request from ", re.RemoteAddr, "for invalid path ", re.URL.Path) 133 134 u := websocket.Upgrader{ 135 // The mobile app on iOS doesn't support compression well... 136 EnableCompression: false, 137 // As the website will not be served from ourselves, we 138 // need to accept _all_ origins. Cross-site scripting is 139 // required. 140 CheckOrigin: func(*http.Request) bool { 141 return true 142 }, 143 } 144 ws, err := u.Upgrade(wr, re, http.Header{}) 145 if err != nil { 146 log.Error(err) 147 return 148 } 149 150 ws.WriteControl(websocket.CloseMessage, 151 websocket.FormatCloseMessage(4001, "This service doesn't exist"), 152 time.Now().Add(time.Millisecond*500)) 153 ws.Close() 154 }) 155 w.server = &http.Server{ 156 Addr: webHost, 157 Handler: w.mux, 158 } 159 return w 160 } 161 162 // Listening returns true if the server has been started and is 163 // listening on the ports for incoming connections. 164 func (w *WebSocket) Listening() bool { 165 w.Lock() 166 defer w.Unlock() 167 return w.started 168 } 169 170 // start listening on the port. 171 func (w *WebSocket) start() { 172 w.Lock() 173 w.started = true 174 w.server.TLSConfig = w.TLSConfig 175 log.Lvl2("Starting to listen on", w.server.Addr) 176 started := make(chan bool) 177 go func() { 178 // Check if server is configured for TLS 179 started <- true 180 if w.server.TLSConfig != nil && (w.server.TLSConfig.GetCertificate != nil || len(w.server.TLSConfig.Certificates) >= 1) { 181 w.server.ListenAndServeTLS("", "") 182 } else { 183 w.server.ListenAndServe() 184 } 185 }() 186 <-started 187 w.Unlock() 188 w.startstop <- true 189 } 190 191 // registerService stores a service to the given path. All requests to that 192 // path and it's sub-endpoints will be forwarded to ProcessClientRequest. 193 func (w *WebSocket) registerService(service string, s Service) error { 194 if service == "ok" { 195 return xerrors.New("service name \"ok\" is not allowed") 196 } 197 198 w.services[service] = s 199 h := &wsHandler{ 200 service: s, 201 serviceName: service, 202 } 203 w.mux.Handle(fmt.Sprintf("/%s/", service), h) 204 return nil 205 } 206 207 // stop the websocket and free the port. 208 func (w *WebSocket) stop() { 209 w.Lock() 210 defer w.Unlock() 211 if !w.started { 212 return 213 } 214 log.Lvl3("Stopping", w.server.Addr) 215 216 d := time.Now().Add(100 * time.Millisecond) 217 ctx, cancel := context.WithDeadline(context.Background(), d) 218 w.server.Shutdown(ctx) 219 cancel() 220 221 <-w.startstop 222 w.started = false 223 } 224 225 // Pass the request to the websocket. 226 type wsHandler struct { 227 serviceName string 228 service Service 229 } 230 231 // Wrapper-function so that http.Requests get 'upgraded' to websockets 232 // and handled correctly. 233 func (t wsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 234 rx := 0 235 tx := 0 236 n := 0 237 238 defer func() { 239 log.Lvl2("ws close", r.RemoteAddr, "n", n, "rx", rx, "tx", tx) 240 }() 241 242 u := websocket.Upgrader{ 243 // The mobile app on iOS doesn't support compression well... 244 EnableCompression: false, 245 // As the website will not be served from ourselves, we 246 // need to accept _all_ origins. Cross-site scripting is 247 // required. 248 CheckOrigin: func(*http.Request) bool { 249 return true 250 }, 251 } 252 ws, err := u.Upgrade(w, r, http.Header{}) 253 if err != nil { 254 log.Error(err) 255 return 256 } 257 defer ws.Close() 258 259 // Loop for each message 260 outerReadLoop: 261 for err == nil { 262 mt, buf, rerr := ws.ReadMessage() 263 if rerr != nil { 264 err = rerr 265 break 266 } 267 rx += len(buf) 268 n++ 269 270 s := t.service 271 var reply []byte 272 var outChan chan []byte 273 path := strings.TrimPrefix(r.URL.Path, "/"+t.serviceName+"/") 274 log.Lvlf2("ws request from %s: %s/%s", r.RemoteAddr, t.serviceName, path) 275 276 isStreaming := false 277 bidirectionalStreamer, ok := s.(BidirectionalStreamer) 278 if ok { 279 isStreaming, err = bidirectionalStreamer.IsStreaming(path) 280 if err != nil { 281 log.Errorf("failed to check if it is a streaming "+ 282 "request %s/%s: %+v", t.serviceName, path, err) 283 continue 284 } 285 } 286 287 if !isStreaming { 288 reply, _, err = s.ProcessClientRequest(r, path, buf) 289 if err != nil { 290 log.Errorf("Got an error while executing %s/%s: %+v", 291 t.serviceName, path, err) 292 continue 293 } 294 295 tx += len(reply) 296 err = ws.SetWriteDeadline(time.Now().Add(5 * time.Minute)) 297 if err != nil { 298 log.Error(xerrors.Errorf("failed to set the write deadline "+ 299 "with request request %s/%s: %v", t.serviceName, path, err)) 300 break 301 } 302 303 err = ws.WriteMessage(mt, reply) 304 if err != nil { 305 log.Error(xerrors.Errorf("failed to write message with "+ 306 "request %s/%s: %v", t.serviceName, path, err)) 307 break 308 } 309 310 continue 311 } 312 313 clientInputs := make(chan []byte, 10) 314 clientInputs <- buf 315 outChan, err = bidirectionalStreamer.ProcessClientStreamRequest(r, 316 path, clientInputs) 317 if err != nil { 318 log.Errorf("got an error while processing streaming "+ 319 "request %s/%s: %+v", t.serviceName, path, err) 320 continue 321 } 322 323 closing := make(chan bool) 324 go func() { 325 for { 326 // Listen for incoming messages to know if the client wants to 327 // close the stream. If this is an error, we assume the client 328 // wants to close the stream, otherwise we forward the message 329 // to the service. 330 _, buf, err := ws.ReadMessage() 331 if err != nil { 332 close(closing) 333 return 334 } 335 clientInputs <- buf 336 } 337 }() 338 339 for { 340 select { 341 case <-closing: 342 close(clientInputs) 343 break outerReadLoop 344 case reply, ok := <-outChan: 345 if !ok { 346 ws.WriteControl(websocket.CloseMessage, 347 websocket.FormatCloseMessage(websocket.CloseNormalClosure, "service finished streaming"), 348 time.Now().Add(time.Millisecond*500)) 349 close(clientInputs) 350 return 351 } 352 tx += len(reply) 353 354 err = ws.SetWriteDeadline(time.Now().Add(5 * time.Minute)) 355 if err != nil { 356 log.Error(xerrors.Errorf("failed to set the write "+ 357 "deadline in the streaming loop: %v", err)) 358 close(clientInputs) 359 break outerReadLoop 360 } 361 362 err = ws.WriteMessage(mt, reply) 363 if err != nil { 364 log.Error(xerrors.Errorf("failed to write next message "+ 365 "in the streaming loop: %v", err)) 366 close(clientInputs) 367 break outerReadLoop 368 } 369 } 370 } 371 372 } 373 374 errMessage := "unexpected error: " 375 if err != nil { 376 errMessage += err.Error() 377 } 378 379 ws.WriteControl(websocket.CloseMessage, 380 websocket.FormatCloseMessage(websocket.CloseProtocolError, errMessage), 381 time.Now().Add(time.Millisecond*500)) 382 return 383 } 384 385 type destination struct { 386 si *network.ServerIdentity 387 path string 388 }