gopkg.in/dedis/onet.v2@v2.0.0-20181115163211-c8f3724038a7/websocket.go (about) 1 package onet 2 3 import ( 4 "crypto/tls" 5 "errors" 6 "fmt" 7 "net" 8 "net/http" 9 "net/url" 10 "reflect" 11 "strconv" 12 "strings" 13 "sync" 14 "time" 15 16 "github.com/dedis/protobuf" 17 "github.com/gorilla/websocket" 18 "gopkg.in/dedis/onet.v2/log" 19 "gopkg.in/dedis/onet.v2/network" 20 "gopkg.in/tylerb/graceful.v1" 21 ) 22 23 // WebSocket handles incoming client-requests using the websocket 24 // protocol. When making a new WebSocket, it will listen one port above the 25 // ServerIdentity-port-#. 26 // The websocket protocol has been chosen as smallest common denominator 27 // for languages including JavaScript. 28 type WebSocket struct { 29 services map[string]Service 30 server *graceful.Server 31 mux *http.ServeMux 32 startstop chan bool 33 started bool 34 TLSConfig *tls.Config // can only be modified before Start is called 35 sync.Mutex 36 } 37 38 // NewWebSocket opens a webservice-listener one port above the given 39 // ServerIdentity. 40 func NewWebSocket(si *network.ServerIdentity) *WebSocket { 41 w := &WebSocket{ 42 services: make(map[string]Service), 43 startstop: make(chan bool), 44 } 45 webHost, err := getWSHostPort(si, true) 46 log.ErrFatal(err) 47 w.mux = http.NewServeMux() 48 w.mux.HandleFunc("/ok", func(w http.ResponseWriter, r *http.Request) { 49 log.Lvl4("ok?", r.RemoteAddr) 50 ok := []byte("ok\n") 51 w.Write(ok) 52 }) 53 54 // Add a catch-all handler (longest paths take precedence, so "/" takes 55 // all non-registered paths) and correctly upgrade to a websocket and 56 // throw an error. 57 w.mux.HandleFunc("/", func(wr http.ResponseWriter, re *http.Request) { 58 log.Error("request from ", re.RemoteAddr, "for invalid path ", re.URL.Path) 59 60 u := websocket.Upgrader{ 61 EnableCompression: true, 62 // As the website will not be served from ourselves, we 63 // need to accept _all_ origins. Cross-site scripting is 64 // required. 65 CheckOrigin: func(*http.Request) bool { 66 return true 67 }, 68 } 69 ws, err := u.Upgrade(wr, re, http.Header{}) 70 if err != nil { 71 log.Error(err) 72 return 73 } 74 75 ws.WriteControl(websocket.CloseMessage, 76 websocket.FormatCloseMessage(4001, "This service doesn't exist"), 77 time.Now().Add(time.Millisecond*500)) 78 ws.Close() 79 }) 80 w.server = &graceful.Server{ 81 Timeout: 100 * time.Millisecond, 82 Server: &http.Server{ 83 Addr: webHost, 84 Handler: w.mux, 85 }, 86 NoSignalHandling: true, 87 } 88 return w 89 } 90 91 // Listening returns true if the server has been started and is 92 // listening on the ports for incoming connections. 93 func (w *WebSocket) Listening() bool { 94 w.Lock() 95 defer w.Unlock() 96 return w.started 97 } 98 99 // start listening on the port. 100 func (w *WebSocket) start() { 101 w.Lock() 102 w.started = true 103 w.server.Server.TLSConfig = w.TLSConfig 104 log.Lvl2("Starting to listen on", w.server.Server.Addr) 105 started := make(chan bool) 106 go func() { 107 // Check if server is configured for TLS 108 started <- true 109 if w.server.Server.TLSConfig != nil && len(w.server.Server.TLSConfig.Certificates) >= 1 { 110 w.server.ListenAndServeTLS("", "") 111 } else { 112 w.server.ListenAndServe() 113 } 114 }() 115 <-started 116 w.Unlock() 117 w.startstop <- true 118 } 119 120 // registerService stores a service to the given path. All requests to that 121 // path and it's sub-endpoints will be forwarded to ProcessClientRequest. 122 func (w *WebSocket) registerService(service string, s Service) error { 123 if service == "ok" { 124 return errors.New("service name \"ok\" is not allowed") 125 } 126 127 w.services[service] = s 128 h := &wsHandler{ 129 service: s, 130 serviceName: service, 131 } 132 w.mux.Handle(fmt.Sprintf("/%s/", service), h) 133 return nil 134 } 135 136 // stop the websocket and free the port. 137 func (w *WebSocket) stop() { 138 w.Lock() 139 defer w.Unlock() 140 if !w.started { 141 return 142 } 143 log.Lvl3("Stopping", w.server.Server.Addr) 144 w.server.Stop(100 * time.Millisecond) 145 <-w.startstop 146 w.started = false 147 } 148 149 // Pass the request to the websocket. 150 type wsHandler struct { 151 serviceName string 152 service Service 153 } 154 155 // Wrapper-function so that http.Requests get 'upgraded' to websockets 156 // and handled correctly. 157 func (t wsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 158 rx := 0 159 tx := 0 160 n := 0 161 162 defer func() { 163 log.Lvl2("ws close", r.RemoteAddr, "n", n, "rx", rx, "tx", tx) 164 }() 165 166 u := websocket.Upgrader{ 167 EnableCompression: true, 168 // As the website will not be served from ourselves, we 169 // need to accept _all_ origins. Cross-site scripting is 170 // required. 171 CheckOrigin: func(*http.Request) bool { 172 return true 173 }, 174 } 175 ws, err := u.Upgrade(w, r, http.Header{}) 176 if err != nil { 177 log.Error(err) 178 return 179 } 180 defer func() { 181 ws.Close() 182 }() 183 184 // Loop for each message 185 outerReadLoop: 186 for err == nil { 187 mt, buf, rerr := ws.ReadMessage() 188 if rerr != nil { 189 err = rerr 190 break 191 } 192 rx += len(buf) 193 n++ 194 195 s := t.service 196 var reply []byte 197 var tun *StreamingTunnel 198 path := strings.TrimPrefix(r.URL.Path, "/"+t.serviceName+"/") 199 log.Lvlf2("ws request from %s: %s/%s", r.RemoteAddr, t.serviceName, path) 200 reply, tun, err = s.ProcessClientRequest(r, path, buf) 201 if err == nil { 202 if tun == nil { 203 tx += len(reply) 204 if err := ws.SetWriteDeadline(time.Now().Add(5 * time.Minute)); err != nil { 205 log.Error(err) 206 break 207 } 208 if err := ws.WriteMessage(mt, reply); err != nil { 209 log.Error(err) 210 break 211 } 212 } else { 213 for { 214 select { 215 case reply, ok := <-tun.out: 216 if !ok { 217 err = errors.New("service finished streaming") 218 close(tun.close) 219 break outerReadLoop 220 } 221 tx += len(reply) 222 if err = ws.SetWriteDeadline(time.Now().Add(5 * time.Minute)); err != nil { 223 log.Error(err) 224 close(tun.close) 225 break outerReadLoop 226 } 227 if err = ws.WriteMessage(mt, reply); err != nil { 228 log.Error(err) 229 close(tun.close) 230 break outerReadLoop 231 } 232 } 233 } 234 } 235 } else { 236 log.Errorf("Got an error while executing %s/%s: %s", t.serviceName, path, err.Error()) 237 } 238 } 239 240 ws.WriteControl(websocket.CloseMessage, 241 websocket.FormatCloseMessage(4000, err.Error()), 242 time.Now().Add(time.Millisecond*500)) 243 return 244 } 245 246 type destination struct { 247 si *network.ServerIdentity 248 path string 249 } 250 251 // Client is a struct used to communicate with a remote Service running on a 252 // onet.Server. Using Send it can connect to multiple remote Servers. 253 type Client struct { 254 service string 255 connections map[destination]*websocket.Conn 256 suite network.Suite 257 // if not nil, use TLS 258 TLSClientConfig *tls.Config 259 // whether to keep the connection 260 keep bool 261 rx uint64 262 tx uint64 263 sync.Mutex 264 } 265 266 // NewClient returns a client using the service s. On the first Send, the 267 // connection will be started, until Close is called. 268 func NewClient(suite network.Suite, s string) *Client { 269 return &Client{ 270 service: s, 271 connections: make(map[destination]*websocket.Conn), 272 suite: suite, 273 } 274 } 275 276 // NewClientKeep returns a Client that doesn't close the connection between 277 // two messages if it's the same server. 278 func NewClientKeep(suite network.Suite, s string) *Client { 279 return &Client{ 280 service: s, 281 keep: true, 282 connections: make(map[destination]*websocket.Conn), 283 suite: suite, 284 } 285 } 286 287 // Suite returns the cryptographic suite in use on this connection. 288 func (c *Client) Suite() network.Suite { 289 return c.suite 290 } 291 292 func (c *Client) closeSingleUseConn(dst *network.ServerIdentity, path string) { 293 dest := destination{dst, path} 294 if !c.keep { 295 if err := c.closeConn(dest); err != nil { 296 log.Errorf("error while closing the connection to %v : %v\n", dest, err) 297 } 298 } 299 } 300 301 func (c *Client) newConnIfNotExist(dst *network.ServerIdentity, path string) (*websocket.Conn, error) { 302 var err error 303 304 // TODO we are opening a new connection for every new path? 305 // not possible to use an existing connection for the same service? 306 dest := destination{dst, path} 307 conn, ok := c.connections[dest] 308 309 if !ok { 310 d := &websocket.Dialer{} 311 d.TLSClientConfig = c.TLSClientConfig 312 313 var serverURL string 314 var header http.Header 315 316 // If the URL is in the dst, then use it. 317 if dst.URL != "" { 318 u, err := url.Parse(dst.URL) 319 if err != nil { 320 return nil, err 321 } 322 if u.Scheme == "https" { 323 u.Scheme = "wss" 324 } else { 325 u.Scheme = "ws" 326 } 327 u.Path += "/" + c.service + "/" + path 328 serverURL = u.String() 329 header = http.Header{"Origin": []string{dst.URL}} 330 } else { 331 // Open connection to service. 332 hp, err := getWSHostPort(dst, false) 333 if err != nil { 334 return nil, err 335 } 336 337 var wsProtocol string 338 var protocol string 339 340 // The old hacky way of deciding if this server has HTTPS or not: 341 // the client somehow magically knows and tells onet by setting 342 // c.TLSClientConfig to a non-nil value. 343 if c.TLSClientConfig != nil { 344 wsProtocol = "wss" 345 protocol = "https" 346 } else { 347 wsProtocol = "ws" 348 protocol = "http" 349 } 350 serverURL = fmt.Sprintf("%s://%s/%s/%s", wsProtocol, hp, c.service, path) 351 header = http.Header{"Origin": []string{protocol + "://" + hp}} 352 } 353 354 // Re-try to connect in case the websocket is just about to start 355 for a := 0; a < network.MaxRetryConnect; a++ { 356 conn, _, err = d.Dial(serverURL, header) 357 if err == nil { 358 break 359 } 360 time.Sleep(network.WaitRetry) 361 } 362 if err != nil { 363 return nil, err 364 } 365 c.connections[dest] = conn 366 } 367 return conn, nil 368 } 369 370 // Send will marshal the message into a ClientRequest message and send it. 371 func (c *Client) Send(dst *network.ServerIdentity, path string, buf []byte) ([]byte, error) { 372 c.Lock() 373 defer c.Unlock() 374 375 conn, err := c.newConnIfNotExist(dst, path) 376 if err != nil { 377 return nil, err 378 } 379 defer c.closeSingleUseConn(dst, path) 380 381 log.Lvlf4("Sending %x to %s/%s", buf, c.service, path) 382 if err := conn.WriteMessage(websocket.BinaryMessage, buf); err != nil { 383 return nil, err 384 } 385 c.tx += uint64(len(buf)) 386 387 if err := conn.SetReadDeadline(time.Now().Add(5 * time.Minute)); err != nil { 388 return nil, err 389 } 390 _, rcv, err := conn.ReadMessage() 391 if err != nil { 392 return nil, err 393 } 394 log.Lvlf4("Received %x", rcv) 395 c.rx += uint64(len(rcv)) 396 return rcv, nil 397 } 398 399 // SendProtobuf wraps protobuf.(En|De)code over the Client.Send-function. It 400 // takes the destination, a pointer to a msg-structure that will be 401 // protobuf-encoded and sent over the websocket. If ret is non-nil, it 402 // has to be a pointer to the struct that is sent back to the 403 // client. If there is no error, the ret-structure is filled with the 404 // data from the service. 405 func (c *Client) SendProtobuf(dst *network.ServerIdentity, msg interface{}, ret interface{}) error { 406 buf, err := protobuf.Encode(msg) 407 if err != nil { 408 return err 409 } 410 path := strings.Split(reflect.TypeOf(msg).String(), ".")[1] 411 reply, err := c.Send(dst, path, buf) 412 if err != nil { 413 return err 414 } 415 if ret != nil { 416 return protobuf.DecodeWithConstructors(reply, ret, 417 network.DefaultConstructors(c.suite)) 418 } 419 return nil 420 } 421 422 // StreamingConn allows clients to read from it without sending additional 423 // requests. 424 type StreamingConn struct { 425 conn *websocket.Conn 426 suite network.Suite 427 } 428 429 // ReadMessage read more data from the connection, it will block if there are 430 // no messages. 431 func (c *StreamingConn) ReadMessage(ret interface{}) error { 432 if err := c.conn.SetReadDeadline(time.Now().Add(5 * time.Minute)); err != nil { 433 return err 434 } 435 // No need to add bytes to counter here because this function is only 436 // called by the client. 437 _, buf, err := c.conn.ReadMessage() 438 if err != nil { 439 return err 440 } 441 return protobuf.DecodeWithConstructors(buf, ret, 442 network.DefaultConstructors(c.suite)) 443 } 444 445 // Stream will send a request to start streaming, it returns a connection where 446 // the client can continue to read values from it. 447 func (c *Client) Stream(dst *network.ServerIdentity, msg interface{}) (StreamingConn, error) { 448 buf, err := protobuf.Encode(msg) 449 if err != nil { 450 return StreamingConn{}, err 451 } 452 path := strings.Split(reflect.TypeOf(msg).String(), ".")[1] 453 454 c.Lock() 455 defer c.Unlock() 456 conn, err := c.newConnIfNotExist(dst, path) 457 if err != nil { 458 return StreamingConn{}, err 459 } 460 err = conn.WriteMessage(websocket.BinaryMessage, buf) 461 if err != nil { 462 return StreamingConn{}, err 463 } 464 c.tx += uint64(len(buf)) 465 return StreamingConn{conn, c.Suite()}, nil 466 } 467 468 // SendToAll sends a message to all ServerIdentities of the Roster and returns 469 // all errors encountered concatenated together as a string. 470 func (c *Client) SendToAll(dst *Roster, path string, buf []byte) ([][]byte, error) { 471 msgs := make([][]byte, len(dst.List)) 472 var errstrs []string 473 for i, e := range dst.List { 474 var err error 475 msgs[i], err = c.Send(e, path, buf) 476 if err != nil { 477 errstrs = append(errstrs, fmt.Sprint(e.String(), err.Error())) 478 } 479 } 480 var err error 481 if len(errstrs) > 0 { 482 err = errors.New(strings.Join(errstrs, "\n")) 483 } 484 return msgs, err 485 } 486 487 // Close sends a close-command to all open connections and returns nil if no 488 // errors occurred or all errors encountered concatenated together as a string. 489 func (c *Client) Close() error { 490 c.Lock() 491 defer c.Unlock() 492 var errstrs []string 493 for dest := range c.connections { 494 if err := c.closeConn(dest); err != nil { 495 errstrs = append(errstrs, err.Error()) 496 } 497 } 498 var err error 499 if len(errstrs) > 0 { 500 err = errors.New(strings.Join(errstrs, "\n")) 501 } 502 return err 503 } 504 505 // closeConn sends a close-command to the connection. 506 func (c *Client) closeConn(dst destination) error { 507 conn, ok := c.connections[dst] 508 if ok { 509 delete(c.connections, dst) 510 conn.WriteMessage(websocket.CloseMessage, nil) 511 return conn.Close() 512 } 513 return nil 514 } 515 516 // Tx returns the number of bytes transmitted by this Client. It implements 517 // the monitor.CounterIOMeasure interface. 518 func (c *Client) Tx() uint64 { 519 c.Lock() 520 defer c.Unlock() 521 return c.tx 522 } 523 524 // Rx returns the number of bytes read by this Client. It implements 525 // the monitor.CounterIOMeasure interface. 526 func (c *Client) Rx() uint64 { 527 c.Lock() 528 defer c.Unlock() 529 return c.rx 530 } 531 532 // getWSHostPort returns the host:port+1 of the serverIdentity. If 533 // global is true, the address is set to the unspecified 0.0.0.0-address. 534 func getWSHostPort(si *network.ServerIdentity, global bool) (string, error) { 535 p, err := strconv.Atoi(si.Address.Port()) 536 if err != nil { 537 return "", err 538 } 539 host := si.Address.Host() 540 if global { 541 host = "0.0.0.0" 542 } 543 return net.JoinHostPort(host, strconv.Itoa(p+1)), nil 544 }