go.dedis.ch/onet/v3@v3.2.11-0.20210930124529-e36530bca7ef/websocket.go (about)

     1  package onet
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"crypto/x509"
     7  	"fmt"
     8  	"net/http"
     9  	"strings"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/gorilla/websocket"
    14  	"go.dedis.ch/onet/v3/log"
    15  	"go.dedis.ch/onet/v3/network"
    16  	"golang.org/x/xerrors"
    17  )
    18  
// certificateReloaderLeeway is how long before the NotAfter time of the
// current certificate the reloader starts reading the files again.
const certificateReloaderLeeway = 1 * time.Hour
    20  
// CertificateReloader takes care of reloading a TLS certificate when
// requested.
//
// cert is the certificate produced by the last reload, certPath and keyPath
// are the files it is read from. Accesses to cert made by the methods of this
// type are guarded by the embedded RWMutex.
type CertificateReloader struct {
	sync.RWMutex
	cert     *tls.Certificate
	certPath string
	keyPath  string
}
    29  
    30  // NewCertificateReloader takes two file paths as parameter that contain
    31  // the certificate and the key data to create an automatic reloader. It will
    32  // try to read again the files when the certificate is almost expired.
    33  func NewCertificateReloader(certPath, keyPath string) (*CertificateReloader, error) {
    34  	loader := &CertificateReloader{
    35  		certPath: certPath,
    36  		keyPath:  keyPath,
    37  	}
    38  
    39  	err := loader.reload()
    40  	if err != nil {
    41  		return nil, xerrors.Errorf("reloading certificate: %v", err)
    42  	}
    43  
    44  	return loader, nil
    45  }
    46  
    47  func (cr *CertificateReloader) reload() error {
    48  	newCert, err := tls.LoadX509KeyPair(cr.certPath, cr.keyPath)
    49  	if err != nil {
    50  		return xerrors.Errorf("load x509: %v", err)
    51  	}
    52  
    53  	cr.Lock()
    54  	cr.cert = &newCert
    55  	// Successful parse means at least one certificate.
    56  	cr.cert.Leaf, err = x509.ParseCertificate(newCert.Certificate[0])
    57  	cr.Unlock()
    58  
    59  	if err != nil {
    60  		return xerrors.Errorf("parse x509: %v", err)
    61  	}
    62  	return nil
    63  }
    64  
    65  // GetCertificateFunc makes a function that can be passed to the TLSConfig
    66  // so that it resolves the most up-to-date one.
    67  func (cr *CertificateReloader) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
    68  	return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
    69  		cr.RLock()
    70  
    71  		exp := time.Now().Add(certificateReloaderLeeway)
    72  
    73  		// Here we know the leaf has been parsed successfully as an error
    74  		// would have been thrown otherwise.
    75  		if cr.cert == nil || exp.After(cr.cert.Leaf.NotAfter) {
    76  			// Certificate has expired so we try to load the new one.
    77  
    78  			// Free the read lock to be able to reload.
    79  			cr.RUnlock()
    80  			err := cr.reload()
    81  			if err != nil {
    82  				return nil, xerrors.Errorf("reload certificate: %v", err)
    83  			}
    84  
    85  			cr.RLock()
    86  		}
    87  
    88  		defer cr.RUnlock()
    89  		return cr.cert, nil
    90  	}
    91  }
    92  
// WebSocket handles incoming client-requests using the websocket
// protocol. When making a new WebSocket, it will listen one port above the
// ServerIdentity-port-#.
// The websocket protocol has been chosen as smallest common denominator
// for languages including JavaScript.
//
// services holds the registered services by name, mux routes the requests
// and server is the HTTP server that uses mux as its handler. started is
// read and written while holding the embedded mutex. startstop is the
// channel start sends on and stop receives from. TLSConfig, when set, is
// handed to the HTTP server by start.
type WebSocket struct {
	services  map[string]Service
	server    *http.Server
	mux       *http.ServeMux
	startstop chan bool
	started   bool
	TLSConfig *tls.Config // can only be modified before Start is called
	sync.Mutex
}
   107  
   108  // NewWebSocket opens a webservice-listener at the given si.URL.
   109  func NewWebSocket(si *network.ServerIdentity) *WebSocket {
   110  	w := &WebSocket{
   111  		services:  make(map[string]Service),
   112  		startstop: make(chan bool),
   113  	}
   114  	webHost, err := getWSHostPort(si, true)
   115  	log.ErrFatal(err)
   116  	w.mux = http.NewServeMux()
   117  	w.mux.HandleFunc("/ok", func(w http.ResponseWriter, r *http.Request) {
   118  		log.Lvl4("ok?", r.RemoteAddr)
   119  		ok := []byte("ok\n")
   120  		w.Write(ok)
   121  	})
   122  
   123  	if allowPprof() {
   124  		log.Warn("HTTP pprof profiling is enabled")
   125  		initPprof(w.mux)
   126  	}
   127  
   128  	// Add a catch-all handler (longest paths take precedence, so "/" takes
   129  	// all non-registered paths) and correctly upgrade to a websocket and
   130  	// throw an error.
   131  	w.mux.HandleFunc("/", func(wr http.ResponseWriter, re *http.Request) {
   132  		log.Error("request from ", re.RemoteAddr, "for invalid path ", re.URL.Path)
   133  
   134  		u := websocket.Upgrader{
   135  			// The mobile app on iOS doesn't support compression well...
   136  			EnableCompression: false,
   137  			// As the website will not be served from ourselves, we
   138  			// need to accept _all_ origins. Cross-site scripting is
   139  			// required.
   140  			CheckOrigin: func(*http.Request) bool {
   141  				return true
   142  			},
   143  		}
   144  		ws, err := u.Upgrade(wr, re, http.Header{})
   145  		if err != nil {
   146  			log.Error(err)
   147  			return
   148  		}
   149  
   150  		ws.WriteControl(websocket.CloseMessage,
   151  			websocket.FormatCloseMessage(4001, "This service doesn't exist"),
   152  			time.Now().Add(time.Millisecond*500))
   153  		ws.Close()
   154  	})
   155  	w.server = &http.Server{
   156  		Addr:    webHost,
   157  		Handler: w.mux,
   158  	}
   159  	return w
   160  }
   161  
   162  // Listening returns true if the server has been started and is
   163  // listening on the ports for incoming connections.
   164  func (w *WebSocket) Listening() bool {
   165  	w.Lock()
   166  	defer w.Unlock()
   167  	return w.started
   168  }
   169  
   170  // start listening on the port.
   171  func (w *WebSocket) start() {
   172  	w.Lock()
   173  	w.started = true
   174  	w.server.TLSConfig = w.TLSConfig
   175  	log.Lvl2("Starting to listen on", w.server.Addr)
   176  	started := make(chan bool)
   177  	go func() {
   178  		// Check if server is configured for TLS
   179  		started <- true
   180  		if w.server.TLSConfig != nil && (w.server.TLSConfig.GetCertificate != nil || len(w.server.TLSConfig.Certificates) >= 1) {
   181  			w.server.ListenAndServeTLS("", "")
   182  		} else {
   183  			w.server.ListenAndServe()
   184  		}
   185  	}()
   186  	<-started
   187  	w.Unlock()
   188  	w.startstop <- true
   189  }
   190  
   191  // registerService stores a service to the given path. All requests to that
   192  // path and it's sub-endpoints will be forwarded to ProcessClientRequest.
   193  func (w *WebSocket) registerService(service string, s Service) error {
   194  	if service == "ok" {
   195  		return xerrors.New("service name \"ok\" is not allowed")
   196  	}
   197  
   198  	w.services[service] = s
   199  	h := &wsHandler{
   200  		service:     s,
   201  		serviceName: service,
   202  	}
   203  	w.mux.Handle(fmt.Sprintf("/%s/", service), h)
   204  	return nil
   205  }
   206  
   207  // stop the websocket and free the port.
   208  func (w *WebSocket) stop() {
   209  	w.Lock()
   210  	defer w.Unlock()
   211  	if !w.started {
   212  		return
   213  	}
   214  	log.Lvl3("Stopping", w.server.Addr)
   215  
   216  	d := time.Now().Add(100 * time.Millisecond)
   217  	ctx, cancel := context.WithDeadline(context.Background(), d)
   218  	w.server.Shutdown(ctx)
   219  	cancel()
   220  
   221  	<-w.startstop
   222  	w.started = false
   223  }
   224  
// wsHandler passes the requests of one registered service to that service.
// serviceName is the name the service is registered under; it is used to
// strip the "/<serviceName>/" prefix from the request path and in log
// messages. service is the service that processes the requests.
type wsHandler struct {
	serviceName string
	service     Service
}
   230  
   231  // Wrapper-function so that http.Requests get 'upgraded' to websockets
   232  // and handled correctly.
   233  func (t wsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   234  	rx := 0
   235  	tx := 0
   236  	n := 0
   237  
   238  	defer func() {
   239  		log.Lvl2("ws close", r.RemoteAddr, "n", n, "rx", rx, "tx", tx)
   240  	}()
   241  
   242  	u := websocket.Upgrader{
   243  		// The mobile app on iOS doesn't support compression well...
   244  		EnableCompression: false,
   245  		// As the website will not be served from ourselves, we
   246  		// need to accept _all_ origins. Cross-site scripting is
   247  		// required.
   248  		CheckOrigin: func(*http.Request) bool {
   249  			return true
   250  		},
   251  	}
   252  	ws, err := u.Upgrade(w, r, http.Header{})
   253  	if err != nil {
   254  		log.Error(err)
   255  		return
   256  	}
   257  	defer ws.Close()
   258  
   259  	// Loop for each message
   260  outerReadLoop:
   261  	for err == nil {
   262  		mt, buf, rerr := ws.ReadMessage()
   263  		if rerr != nil {
   264  			err = rerr
   265  			break
   266  		}
   267  		rx += len(buf)
   268  		n++
   269  
   270  		s := t.service
   271  		var reply []byte
   272  		var outChan chan []byte
   273  		path := strings.TrimPrefix(r.URL.Path, "/"+t.serviceName+"/")
   274  		log.Lvlf2("ws request from %s: %s/%s", r.RemoteAddr, t.serviceName, path)
   275  
   276  		isStreaming := false
   277  		bidirectionalStreamer, ok := s.(BidirectionalStreamer)
   278  		if ok {
   279  			isStreaming, err = bidirectionalStreamer.IsStreaming(path)
   280  			if err != nil {
   281  				log.Errorf("failed to check if it is a streaming "+
   282  					"request %s/%s: %+v", t.serviceName, path, err)
   283  				continue
   284  			}
   285  		}
   286  
   287  		if !isStreaming {
   288  			reply, _, err = s.ProcessClientRequest(r, path, buf)
   289  			if err != nil {
   290  				log.Errorf("Got an error while executing %s/%s: %+v",
   291  					t.serviceName, path, err)
   292  				continue
   293  			}
   294  
   295  			tx += len(reply)
   296  			err = ws.SetWriteDeadline(time.Now().Add(5 * time.Minute))
   297  			if err != nil {
   298  				log.Error(xerrors.Errorf("failed to set the write deadline "+
   299  					"with request request %s/%s: %v", t.serviceName, path, err))
   300  				break
   301  			}
   302  
   303  			err = ws.WriteMessage(mt, reply)
   304  			if err != nil {
   305  				log.Error(xerrors.Errorf("failed to write message with "+
   306  					"request %s/%s: %v", t.serviceName, path, err))
   307  				break
   308  			}
   309  
   310  			continue
   311  		}
   312  
   313  		clientInputs := make(chan []byte, 10)
   314  		clientInputs <- buf
   315  		outChan, err = bidirectionalStreamer.ProcessClientStreamRequest(r,
   316  			path, clientInputs)
   317  		if err != nil {
   318  			log.Errorf("got an error while processing streaming "+
   319  				"request %s/%s: %+v", t.serviceName, path, err)
   320  			continue
   321  		}
   322  
   323  		closing := make(chan bool)
   324  		go func() {
   325  			for {
   326  				// Listen for incoming messages to know if the client wants to
   327  				// close the stream. If this is an error, we assume the client
   328  				// wants to close the stream, otherwise we forward the message
   329  				// to the service.
   330  				_, buf, err := ws.ReadMessage()
   331  				if err != nil {
   332  					close(closing)
   333  					return
   334  				}
   335  				clientInputs <- buf
   336  			}
   337  		}()
   338  
   339  		for {
   340  			select {
   341  			case <-closing:
   342  				close(clientInputs)
   343  				break outerReadLoop
   344  			case reply, ok := <-outChan:
   345  				if !ok {
   346  					ws.WriteControl(websocket.CloseMessage,
   347  						websocket.FormatCloseMessage(websocket.CloseNormalClosure, "service finished streaming"),
   348  						time.Now().Add(time.Millisecond*500))
   349  					close(clientInputs)
   350  					return
   351  				}
   352  				tx += len(reply)
   353  
   354  				err = ws.SetWriteDeadline(time.Now().Add(5 * time.Minute))
   355  				if err != nil {
   356  					log.Error(xerrors.Errorf("failed to set the write "+
   357  						"deadline in the streaming loop: %v", err))
   358  					close(clientInputs)
   359  					break outerReadLoop
   360  				}
   361  
   362  				err = ws.WriteMessage(mt, reply)
   363  				if err != nil {
   364  					log.Error(xerrors.Errorf("failed to write next message "+
   365  						"in the streaming loop: %v", err))
   366  					close(clientInputs)
   367  					break outerReadLoop
   368  				}
   369  			}
   370  		}
   371  
   372  	}
   373  
   374  	errMessage := "unexpected error: "
   375  	if err != nil {
   376  		errMessage += err.Error()
   377  	}
   378  
   379  	ws.WriteControl(websocket.CloseMessage,
   380  		websocket.FormatCloseMessage(websocket.CloseProtocolError, errMessage),
   381  		time.Now().Add(time.Millisecond*500))
   382  	return
   383  }
   384  
// destination groups a server identity with a path.
// NOTE(review): it is not used in this part of the file; presumably it
// identifies the endpoint of a client request - confirm against its users.
type destination struct {
	si   *network.ServerIdentity
	path string
}