gopkg.in/dedis/onet.v2@v2.0.0-20181115163211-c8f3724038a7/websocket.go (about)

     1  package onet
     2  
     3  import (
     4  	"crypto/tls"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  	"net/http"
     9  	"net/url"
    10  	"reflect"
    11  	"strconv"
    12  	"strings"
    13  	"sync"
    14  	"time"
    15  
    16  	"github.com/dedis/protobuf"
    17  	"github.com/gorilla/websocket"
    18  	"gopkg.in/dedis/onet.v2/log"
    19  	"gopkg.in/dedis/onet.v2/network"
    20  	"gopkg.in/tylerb/graceful.v1"
    21  )
    22  
    23  // WebSocket handles incoming client-requests using the websocket
    24  // protocol. When making a new WebSocket, it will listen one port above the
    25  // ServerIdentity-port-#.
    26  // The websocket protocol has been chosen as smallest common denominator
    27  // for languages including JavaScript.
    28  type WebSocket struct {
    29  	services  map[string]Service
    30  	server    *graceful.Server
    31  	mux       *http.ServeMux
    32  	startstop chan bool
    33  	started   bool
    34  	TLSConfig *tls.Config // can only be modified before Start is called
    35  	sync.Mutex
    36  }
    37  
    38  // NewWebSocket opens a webservice-listener one port above the given
    39  // ServerIdentity.
    40  func NewWebSocket(si *network.ServerIdentity) *WebSocket {
    41  	w := &WebSocket{
    42  		services:  make(map[string]Service),
    43  		startstop: make(chan bool),
    44  	}
    45  	webHost, err := getWSHostPort(si, true)
    46  	log.ErrFatal(err)
    47  	w.mux = http.NewServeMux()
    48  	w.mux.HandleFunc("/ok", func(w http.ResponseWriter, r *http.Request) {
    49  		log.Lvl4("ok?", r.RemoteAddr)
    50  		ok := []byte("ok\n")
    51  		w.Write(ok)
    52  	})
    53  
    54  	// Add a catch-all handler (longest paths take precedence, so "/" takes
    55  	// all non-registered paths) and correctly upgrade to a websocket and
    56  	// throw an error.
    57  	w.mux.HandleFunc("/", func(wr http.ResponseWriter, re *http.Request) {
    58  		log.Error("request from ", re.RemoteAddr, "for invalid path ", re.URL.Path)
    59  
    60  		u := websocket.Upgrader{
    61  			EnableCompression: true,
    62  			// As the website will not be served from ourselves, we
    63  			// need to accept _all_ origins. Cross-site scripting is
    64  			// required.
    65  			CheckOrigin: func(*http.Request) bool {
    66  				return true
    67  			},
    68  		}
    69  		ws, err := u.Upgrade(wr, re, http.Header{})
    70  		if err != nil {
    71  			log.Error(err)
    72  			return
    73  		}
    74  
    75  		ws.WriteControl(websocket.CloseMessage,
    76  			websocket.FormatCloseMessage(4001, "This service doesn't exist"),
    77  			time.Now().Add(time.Millisecond*500))
    78  		ws.Close()
    79  	})
    80  	w.server = &graceful.Server{
    81  		Timeout: 100 * time.Millisecond,
    82  		Server: &http.Server{
    83  			Addr:    webHost,
    84  			Handler: w.mux,
    85  		},
    86  		NoSignalHandling: true,
    87  	}
    88  	return w
    89  }
    90  
    91  // Listening returns true if the server has been started and is
    92  // listening on the ports for incoming connections.
    93  func (w *WebSocket) Listening() bool {
    94  	w.Lock()
    95  	defer w.Unlock()
    96  	return w.started
    97  }
    98  
    99  // start listening on the port.
   100  func (w *WebSocket) start() {
   101  	w.Lock()
   102  	w.started = true
   103  	w.server.Server.TLSConfig = w.TLSConfig
   104  	log.Lvl2("Starting to listen on", w.server.Server.Addr)
   105  	started := make(chan bool)
   106  	go func() {
   107  		// Check if server is configured for TLS
   108  		started <- true
   109  		if w.server.Server.TLSConfig != nil && len(w.server.Server.TLSConfig.Certificates) >= 1 {
   110  			w.server.ListenAndServeTLS("", "")
   111  		} else {
   112  			w.server.ListenAndServe()
   113  		}
   114  	}()
   115  	<-started
   116  	w.Unlock()
   117  	w.startstop <- true
   118  }
   119  
   120  // registerService stores a service to the given path. All requests to that
   121  // path and it's sub-endpoints will be forwarded to ProcessClientRequest.
   122  func (w *WebSocket) registerService(service string, s Service) error {
   123  	if service == "ok" {
   124  		return errors.New("service name \"ok\" is not allowed")
   125  	}
   126  
   127  	w.services[service] = s
   128  	h := &wsHandler{
   129  		service:     s,
   130  		serviceName: service,
   131  	}
   132  	w.mux.Handle(fmt.Sprintf("/%s/", service), h)
   133  	return nil
   134  }
   135  
   136  // stop the websocket and free the port.
   137  func (w *WebSocket) stop() {
   138  	w.Lock()
   139  	defer w.Unlock()
   140  	if !w.started {
   141  		return
   142  	}
   143  	log.Lvl3("Stopping", w.server.Server.Addr)
   144  	w.server.Stop(100 * time.Millisecond)
   145  	<-w.startstop
   146  	w.started = false
   147  }
   148  
   149  // Pass the request to the websocket.
   150  type wsHandler struct {
   151  	serviceName string
   152  	service     Service
   153  }
   154  
   155  // Wrapper-function so that http.Requests get 'upgraded' to websockets
   156  // and handled correctly.
   157  func (t wsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   158  	rx := 0
   159  	tx := 0
   160  	n := 0
   161  
   162  	defer func() {
   163  		log.Lvl2("ws close", r.RemoteAddr, "n", n, "rx", rx, "tx", tx)
   164  	}()
   165  
   166  	u := websocket.Upgrader{
   167  		EnableCompression: true,
   168  		// As the website will not be served from ourselves, we
   169  		// need to accept _all_ origins. Cross-site scripting is
   170  		// required.
   171  		CheckOrigin: func(*http.Request) bool {
   172  			return true
   173  		},
   174  	}
   175  	ws, err := u.Upgrade(w, r, http.Header{})
   176  	if err != nil {
   177  		log.Error(err)
   178  		return
   179  	}
   180  	defer func() {
   181  		ws.Close()
   182  	}()
   183  
   184  	// Loop for each message
   185  outerReadLoop:
   186  	for err == nil {
   187  		mt, buf, rerr := ws.ReadMessage()
   188  		if rerr != nil {
   189  			err = rerr
   190  			break
   191  		}
   192  		rx += len(buf)
   193  		n++
   194  
   195  		s := t.service
   196  		var reply []byte
   197  		var tun *StreamingTunnel
   198  		path := strings.TrimPrefix(r.URL.Path, "/"+t.serviceName+"/")
   199  		log.Lvlf2("ws request from %s: %s/%s", r.RemoteAddr, t.serviceName, path)
   200  		reply, tun, err = s.ProcessClientRequest(r, path, buf)
   201  		if err == nil {
   202  			if tun == nil {
   203  				tx += len(reply)
   204  				if err := ws.SetWriteDeadline(time.Now().Add(5 * time.Minute)); err != nil {
   205  					log.Error(err)
   206  					break
   207  				}
   208  				if err := ws.WriteMessage(mt, reply); err != nil {
   209  					log.Error(err)
   210  					break
   211  				}
   212  			} else {
   213  				for {
   214  					select {
   215  					case reply, ok := <-tun.out:
   216  						if !ok {
   217  							err = errors.New("service finished streaming")
   218  							close(tun.close)
   219  							break outerReadLoop
   220  						}
   221  						tx += len(reply)
   222  						if err = ws.SetWriteDeadline(time.Now().Add(5 * time.Minute)); err != nil {
   223  							log.Error(err)
   224  							close(tun.close)
   225  							break outerReadLoop
   226  						}
   227  						if err = ws.WriteMessage(mt, reply); err != nil {
   228  							log.Error(err)
   229  							close(tun.close)
   230  							break outerReadLoop
   231  						}
   232  					}
   233  				}
   234  			}
   235  		} else {
   236  			log.Errorf("Got an error while executing %s/%s: %s", t.serviceName, path, err.Error())
   237  		}
   238  	}
   239  
   240  	ws.WriteControl(websocket.CloseMessage,
   241  		websocket.FormatCloseMessage(4000, err.Error()),
   242  		time.Now().Add(time.Millisecond*500))
   243  	return
   244  }
   245  
   246  type destination struct {
   247  	si   *network.ServerIdentity
   248  	path string
   249  }
   250  
   251  // Client is a struct used to communicate with a remote Service running on a
   252  // onet.Server. Using Send it can connect to multiple remote Servers.
   253  type Client struct {
   254  	service     string
   255  	connections map[destination]*websocket.Conn
   256  	suite       network.Suite
   257  	// if not nil, use TLS
   258  	TLSClientConfig *tls.Config
   259  	// whether to keep the connection
   260  	keep bool
   261  	rx   uint64
   262  	tx   uint64
   263  	sync.Mutex
   264  }
   265  
   266  // NewClient returns a client using the service s. On the first Send, the
   267  // connection will be started, until Close is called.
   268  func NewClient(suite network.Suite, s string) *Client {
   269  	return &Client{
   270  		service:     s,
   271  		connections: make(map[destination]*websocket.Conn),
   272  		suite:       suite,
   273  	}
   274  }
   275  
   276  // NewClientKeep returns a Client that doesn't close the connection between
   277  // two messages if it's the same server.
   278  func NewClientKeep(suite network.Suite, s string) *Client {
   279  	return &Client{
   280  		service:     s,
   281  		keep:        true,
   282  		connections: make(map[destination]*websocket.Conn),
   283  		suite:       suite,
   284  	}
   285  }
   286  
   287  // Suite returns the cryptographic suite in use on this connection.
   288  func (c *Client) Suite() network.Suite {
   289  	return c.suite
   290  }
   291  
   292  func (c *Client) closeSingleUseConn(dst *network.ServerIdentity, path string) {
   293  	dest := destination{dst, path}
   294  	if !c.keep {
   295  		if err := c.closeConn(dest); err != nil {
   296  			log.Errorf("error while closing the connection to %v : %v\n", dest, err)
   297  		}
   298  	}
   299  }
   300  
   301  func (c *Client) newConnIfNotExist(dst *network.ServerIdentity, path string) (*websocket.Conn, error) {
   302  	var err error
   303  
   304  	// TODO we are opening a new connection for every new path?
   305  	// not possible to use an existing connection for the same service?
   306  	dest := destination{dst, path}
   307  	conn, ok := c.connections[dest]
   308  
   309  	if !ok {
   310  		d := &websocket.Dialer{}
   311  		d.TLSClientConfig = c.TLSClientConfig
   312  
   313  		var serverURL string
   314  		var header http.Header
   315  
   316  		// If the URL is in the dst, then use it.
   317  		if dst.URL != "" {
   318  			u, err := url.Parse(dst.URL)
   319  			if err != nil {
   320  				return nil, err
   321  			}
   322  			if u.Scheme == "https" {
   323  				u.Scheme = "wss"
   324  			} else {
   325  				u.Scheme = "ws"
   326  			}
   327  			u.Path += "/" + c.service + "/" + path
   328  			serverURL = u.String()
   329  			header = http.Header{"Origin": []string{dst.URL}}
   330  		} else {
   331  			// Open connection to service.
   332  			hp, err := getWSHostPort(dst, false)
   333  			if err != nil {
   334  				return nil, err
   335  			}
   336  
   337  			var wsProtocol string
   338  			var protocol string
   339  
   340  			// The old hacky way of deciding if this server has HTTPS or not:
   341  			// the client somehow magically knows and tells onet by setting
   342  			// c.TLSClientConfig to a non-nil value.
   343  			if c.TLSClientConfig != nil {
   344  				wsProtocol = "wss"
   345  				protocol = "https"
   346  			} else {
   347  				wsProtocol = "ws"
   348  				protocol = "http"
   349  			}
   350  			serverURL = fmt.Sprintf("%s://%s/%s/%s", wsProtocol, hp, c.service, path)
   351  			header = http.Header{"Origin": []string{protocol + "://" + hp}}
   352  		}
   353  
   354  		// Re-try to connect in case the websocket is just about to start
   355  		for a := 0; a < network.MaxRetryConnect; a++ {
   356  			conn, _, err = d.Dial(serverURL, header)
   357  			if err == nil {
   358  				break
   359  			}
   360  			time.Sleep(network.WaitRetry)
   361  		}
   362  		if err != nil {
   363  			return nil, err
   364  		}
   365  		c.connections[dest] = conn
   366  	}
   367  	return conn, nil
   368  }
   369  
   370  // Send will marshal the message into a ClientRequest message and send it.
   371  func (c *Client) Send(dst *network.ServerIdentity, path string, buf []byte) ([]byte, error) {
   372  	c.Lock()
   373  	defer c.Unlock()
   374  
   375  	conn, err := c.newConnIfNotExist(dst, path)
   376  	if err != nil {
   377  		return nil, err
   378  	}
   379  	defer c.closeSingleUseConn(dst, path)
   380  
   381  	log.Lvlf4("Sending %x to %s/%s", buf, c.service, path)
   382  	if err := conn.WriteMessage(websocket.BinaryMessage, buf); err != nil {
   383  		return nil, err
   384  	}
   385  	c.tx += uint64(len(buf))
   386  
   387  	if err := conn.SetReadDeadline(time.Now().Add(5 * time.Minute)); err != nil {
   388  		return nil, err
   389  	}
   390  	_, rcv, err := conn.ReadMessage()
   391  	if err != nil {
   392  		return nil, err
   393  	}
   394  	log.Lvlf4("Received %x", rcv)
   395  	c.rx += uint64(len(rcv))
   396  	return rcv, nil
   397  }
   398  
   399  // SendProtobuf wraps protobuf.(En|De)code over the Client.Send-function. It
   400  // takes the destination, a pointer to a msg-structure that will be
   401  // protobuf-encoded and sent over the websocket. If ret is non-nil, it
   402  // has to be a pointer to the struct that is sent back to the
   403  // client. If there is no error, the ret-structure is filled with the
   404  // data from the service.
   405  func (c *Client) SendProtobuf(dst *network.ServerIdentity, msg interface{}, ret interface{}) error {
   406  	buf, err := protobuf.Encode(msg)
   407  	if err != nil {
   408  		return err
   409  	}
   410  	path := strings.Split(reflect.TypeOf(msg).String(), ".")[1]
   411  	reply, err := c.Send(dst, path, buf)
   412  	if err != nil {
   413  		return err
   414  	}
   415  	if ret != nil {
   416  		return protobuf.DecodeWithConstructors(reply, ret,
   417  			network.DefaultConstructors(c.suite))
   418  	}
   419  	return nil
   420  }
   421  
   422  // StreamingConn allows clients to read from it without sending additional
   423  // requests.
   424  type StreamingConn struct {
   425  	conn  *websocket.Conn
   426  	suite network.Suite
   427  }
   428  
   429  // ReadMessage read more data from the connection, it will block if there are
   430  // no messages.
   431  func (c *StreamingConn) ReadMessage(ret interface{}) error {
   432  	if err := c.conn.SetReadDeadline(time.Now().Add(5 * time.Minute)); err != nil {
   433  		return err
   434  	}
   435  	// No need to add bytes to counter here because this function is only
   436  	// called by the client.
   437  	_, buf, err := c.conn.ReadMessage()
   438  	if err != nil {
   439  		return err
   440  	}
   441  	return protobuf.DecodeWithConstructors(buf, ret,
   442  		network.DefaultConstructors(c.suite))
   443  }
   444  
   445  // Stream will send a request to start streaming, it returns a connection where
   446  // the client can continue to read values from it.
   447  func (c *Client) Stream(dst *network.ServerIdentity, msg interface{}) (StreamingConn, error) {
   448  	buf, err := protobuf.Encode(msg)
   449  	if err != nil {
   450  		return StreamingConn{}, err
   451  	}
   452  	path := strings.Split(reflect.TypeOf(msg).String(), ".")[1]
   453  
   454  	c.Lock()
   455  	defer c.Unlock()
   456  	conn, err := c.newConnIfNotExist(dst, path)
   457  	if err != nil {
   458  		return StreamingConn{}, err
   459  	}
   460  	err = conn.WriteMessage(websocket.BinaryMessage, buf)
   461  	if err != nil {
   462  		return StreamingConn{}, err
   463  	}
   464  	c.tx += uint64(len(buf))
   465  	return StreamingConn{conn, c.Suite()}, nil
   466  }
   467  
   468  // SendToAll sends a message to all ServerIdentities of the Roster and returns
   469  // all errors encountered concatenated together as a string.
   470  func (c *Client) SendToAll(dst *Roster, path string, buf []byte) ([][]byte, error) {
   471  	msgs := make([][]byte, len(dst.List))
   472  	var errstrs []string
   473  	for i, e := range dst.List {
   474  		var err error
   475  		msgs[i], err = c.Send(e, path, buf)
   476  		if err != nil {
   477  			errstrs = append(errstrs, fmt.Sprint(e.String(), err.Error()))
   478  		}
   479  	}
   480  	var err error
   481  	if len(errstrs) > 0 {
   482  		err = errors.New(strings.Join(errstrs, "\n"))
   483  	}
   484  	return msgs, err
   485  }
   486  
   487  // Close sends a close-command to all open connections and returns nil if no
   488  // errors occurred or all errors encountered concatenated together as a string.
   489  func (c *Client) Close() error {
   490  	c.Lock()
   491  	defer c.Unlock()
   492  	var errstrs []string
   493  	for dest := range c.connections {
   494  		if err := c.closeConn(dest); err != nil {
   495  			errstrs = append(errstrs, err.Error())
   496  		}
   497  	}
   498  	var err error
   499  	if len(errstrs) > 0 {
   500  		err = errors.New(strings.Join(errstrs, "\n"))
   501  	}
   502  	return err
   503  }
   504  
   505  // closeConn sends a close-command to the connection.
   506  func (c *Client) closeConn(dst destination) error {
   507  	conn, ok := c.connections[dst]
   508  	if ok {
   509  		delete(c.connections, dst)
   510  		conn.WriteMessage(websocket.CloseMessage, nil)
   511  		return conn.Close()
   512  	}
   513  	return nil
   514  }
   515  
   516  // Tx returns the number of bytes transmitted by this Client. It implements
   517  // the monitor.CounterIOMeasure interface.
   518  func (c *Client) Tx() uint64 {
   519  	c.Lock()
   520  	defer c.Unlock()
   521  	return c.tx
   522  }
   523  
   524  // Rx returns the number of bytes read by this Client. It implements
   525  // the monitor.CounterIOMeasure interface.
   526  func (c *Client) Rx() uint64 {
   527  	c.Lock()
   528  	defer c.Unlock()
   529  	return c.rx
   530  }
   531  
   532  // getWSHostPort returns the host:port+1 of the serverIdentity. If
   533  // global is true, the address is set to the unspecified 0.0.0.0-address.
   534  func getWSHostPort(si *network.ServerIdentity, global bool) (string, error) {
   535  	p, err := strconv.Atoi(si.Address.Port())
   536  	if err != nil {
   537  		return "", err
   538  	}
   539  	host := si.Address.Host()
   540  	if global {
   541  		host = "0.0.0.0"
   542  	}
   543  	return net.JoinHostPort(host, strconv.Itoa(p+1)), nil
   544  }