github.com/stampzilla/stampzilla-go@v2.0.0-rc9+incompatible/pkg/node/node.go (about)

     1  package node
     2  
     3  import (
     4  	"context"
     5  	"crypto/rand"
     6  	"crypto/rsa"
     7  	"crypto/tls"
     8  	"crypto/x509"
     9  	"crypto/x509/pkix"
    10  	"encoding/json"
    11  	"encoding/pem"
    12  	"fmt"
    13  	"io/ioutil"
    14  	"log"
    15  	"net/http"
    16  	"os"
    17  	"os/signal"
    18  	"sync"
    19  	"syscall"
    20  	"time"
    21  
    22  	"github.com/google/uuid"
    23  	"github.com/sirupsen/logrus"
    24  	"github.com/stampzilla/stampzilla-go/nodes/stampzilla-server/models"
    25  	"github.com/stampzilla/stampzilla-go/nodes/stampzilla-server/models/devices"
    26  	"github.com/stampzilla/stampzilla-go/pkg/build"
    27  	"github.com/stampzilla/stampzilla-go/pkg/websocket"
    28  )
    29  
// OnFunc is the signature of every message callback registered on a Node.
// It receives the raw JSON body of the matching message; a non-nil error
// returned from the callback is logged by the reader loop.
type OnFunc func(json.RawMessage) error
    32  
// Node is the main struct. It holds the websocket client, the TLS identity,
// the registered message callbacks and the local device list of one node.
type Node struct {
	// UUID identifies the node. It is taken from the CommonName of the loaded
	// certificate (or of a freshly generated CSR) and sent as the X-UUID header.
	UUID     string
	// Type is the node type passed to New; it is sent as the X-TYPE header.
	Type     string
	// Version is filled with build.String() during setup.
	Version  string
	// Protocol is sent as Sec-WebSocket-Protocol; "node" is used when it is empty.
	Protocol string

	// Client is the websocket client used for all communication with the server.
	Client websocket.Websocket
	//DisconnectClient context.CancelFunc
	// wg tracks the goroutines started by Connect; Wait blocks on it.
	wg         *sync.WaitGroup
	// Config is loaded during setup and updated with the ports announced by the server.
	Config     *models.Config
	// X509 is the parsed leaf certificate loaded by loadCertificateKeyPair.
	X509       *x509.Certificate
	// TLS is the certificate/key pair used as client certificate.
	TLS        *tls.Certificate
	// CA is the pool built from ca.crt and used as RootCAs for the TLS connection.
	CA         *x509.CertPool
	// callbacks maps a message type to the callbacks registered with On.
	callbacks  map[string][]OnFunc
	// Devices is the local device store of this node.
	Devices    *devices.List
	// shutdown holds the callbacks registered with OnShutdown.
	shutdown   []func()
	// stop is closed by Stop or when a termination signal is received.
	stop       chan struct{}
	// sendUpdate carries IDs of devices that syncWorker should send to the server.
	sendUpdate chan devices.ID
}
    53  
    54  // New returns a new Node
    55  func New(t string) *Node {
    56  	client := websocket.New()
    57  	node := NewWithClient(client)
    58  	node.Type = t
    59  
    60  	node.setup()
    61  
    62  	return node
    63  }
    64  
    65  // NewWithClient returns a new Node with a custom websocket client
    66  func NewWithClient(client websocket.Websocket) *Node {
    67  	return &Node{
    68  		Client:     client,
    69  		wg:         &sync.WaitGroup{},
    70  		callbacks:  make(map[string][]OnFunc),
    71  		Devices:    devices.NewList(),
    72  		stop:       make(chan struct{}),
    73  		sendUpdate: make(chan devices.ID),
    74  	}
    75  }
    76  
    77  // Stop will shutdown the node similar to a SIGTERM
    78  func (n *Node) Stop() {
    79  	close(n.stop)
    80  }
    81  
    82  // Stopped is closed when the node is stopped by n.Stop or os signal
    83  func (n *Node) Stopped() <-chan struct{} {
    84  	return n.stop
    85  }
    86  
    87  // Wait for node to be done after shutdown
    88  func (n *Node) Wait() {
    89  	n.Client.Wait()
    90  	n.wg.Wait()
    91  }
    92  
    93  func (n *Node) setup() {
    94  	logrus.SetReportCaller(true)
    95  	logrus.SetFormatter(&logrus.TextFormatter{TimestampFormat: time.RFC3339Nano, FullTimestamp: true})
    96  
    97  	//Make sure we have a config
    98  	n.Config = &models.Config{}
    99  	n.Config.MustLoad()
   100  
   101  	if n.Config.Version {
   102  		fmt.Println(build.String())
   103  		os.Exit(1)
   104  	}
   105  	n.Version = build.String()
   106  
   107  	if n.Config.LogLevel != "" {
   108  		lvl, err := logrus.ParseLevel(n.Config.LogLevel)
   109  		if err != nil {
   110  			logrus.Fatal(err)
   111  			return
   112  		}
   113  		logrus.SetLevel(lvl)
   114  	}
   115  
   116  	//n.Config.Save("config.json")
   117  }
   118  
   119  // WriteMessage writes a message to the server over websocket client
   120  func (n *Node) WriteMessage(msgType string, data interface{}) error {
   121  	msg, err := models.NewMessage(msgType, data)
   122  	logrus.WithFields(logrus.Fields{
   123  		"type": msgType,
   124  		"body": data,
   125  	}).Tracef("Send to server")
   126  	if err != nil {
   127  		return err
   128  	}
   129  	return n.Client.WriteJSON(msg)
   130  }
   131  
   132  // WaitForMessage is a helper method to wait for a specific message type
   133  func (n *Node) WaitForMessage(msgType string, dst interface{}) error {
   134  
   135  	for data := range n.Client.Read() {
   136  		msg, err := models.ParseMessage(data)
   137  		if err != nil {
   138  			return err
   139  		}
   140  		if msg.Type == msgType {
   141  			return json.Unmarshal(msg.Body, dst)
   142  		}
   143  	}
   144  	return nil
   145  }
   146  
   147  func (n *Node) fetchCertificate() error {
   148  	// Start with creating a CSR and assign a UUID
   149  	csr, err := n.generateCSR()
   150  	if err != nil {
   151  		return err
   152  	}
   153  
   154  	u := fmt.Sprintf("ws://%s:%s/ws", n.Config.Host, n.Config.Port)
   155  	ctx, cancel := context.WithCancel(context.Background())
   156  	defer cancel()
   157  
   158  	interrupt := make(chan os.Signal, 1)
   159  	signal.Notify(interrupt, os.Interrupt, syscall.SIGQUIT, syscall.SIGTERM)
   160  	go func() {
   161  		select {
   162  		case <-interrupt:
   163  			close(n.stop)
   164  		case <-n.stop:
   165  		}
   166  		cancel()
   167  		go func() {
   168  			<-time.After(time.Second * 10)
   169  			log.Fatal("force shutdown after 10 seconds")
   170  		}()
   171  		for _, f := range n.shutdown {
   172  			f()
   173  		}
   174  	}()
   175  
   176  	n.connect(ctx, u)
   177  
   178  	// wait for server info so we can update our config
   179  	serverInfo := &models.ServerInfo{}
   180  	err = n.WaitForMessage("server-info", serverInfo)
   181  	if err != nil {
   182  		return err
   183  	}
   184  	n.Config.Port = serverInfo.Port
   185  	n.Config.TLSPort = serverInfo.TLSPort
   186  
   187  	n.WriteMessage("certificate-signing-request", models.Request{
   188  		Type:    n.Type,
   189  		Version: n.Version,
   190  		CSR:     string(csr),
   191  	})
   192  	if err != nil {
   193  		return err
   194  	}
   195  
   196  	// wait for our new certificate
   197  
   198  	var rawCert string
   199  	err = n.WaitForMessage("approved-certificate-signing-request", &rawCert)
   200  
   201  	err = ioutil.WriteFile("crt.crt", []byte(rawCert), 0644)
   202  	if err != nil {
   203  		return err
   204  	}
   205  
   206  	var caCert string
   207  	err = n.WaitForMessage("certificate-authority", &caCert)
   208  
   209  	err = ioutil.WriteFile("ca.crt", []byte(caCert), 0644)
   210  	if err != nil {
   211  		return err
   212  	}
   213  
   214  	logrus.Info("Disconnect inseure connection")
   215  	cancel()
   216  	n.Wait()
   217  
   218  	// We should have a certificate now. Try to load it
   219  	return n.loadCertificateKeyPair("crt")
   220  }
   221  
   222  // Connect starts the node and makes connection to the server. Normally discovered using mdns but can be configured aswell.
   223  func (n *Node) Connect() error {
   224  
   225  	if n.Config.Host == "" {
   226  		ip, port := queryMDNS()
   227  		n.Config.Host = ip
   228  		n.Config.Port = port
   229  	}
   230  
   231  	// Load our signed certificate and get our UUID
   232  	err := n.loadCertificateKeyPair("crt")
   233  
   234  	if err != nil {
   235  		logrus.Error("Error trying to load certificate: ", err)
   236  		err = n.fetchCertificate()
   237  		if err != nil {
   238  			return err
   239  		}
   240  	}
   241  
   242  	//If we have certificate we can connect to TLS immedietly
   243  	tlsConfig := &tls.Config{
   244  		Certificates: []tls.Certificate{*n.TLS},
   245  		RootCAs:      n.CA,
   246  		ServerName:   "localhost",
   247  	}
   248  
   249  	n.Client.SetTLSConfig(tlsConfig)
   250  
   251  	u := fmt.Sprintf("wss://%s:%s/ws", n.Config.Host, n.Config.TLSPort)
   252  	ctx, cancel := context.WithCancel(context.Background())
   253  
   254  	interrupt := make(chan os.Signal, 1)
   255  	signal.Notify(interrupt, os.Interrupt, syscall.SIGQUIT, syscall.SIGTERM)
   256  	n.wg.Add(1)
   257  	go func() {
   258  		defer n.wg.Done()
   259  		select {
   260  		case <-interrupt:
   261  			close(n.stop)
   262  		case <-n.stop:
   263  		}
   264  		cancel()
   265  		go func() {
   266  			<-time.After(time.Second * 10)
   267  			log.Fatal("force shutdown after 10 seconds")
   268  		}()
   269  		for _, f := range n.shutdown {
   270  			f()
   271  		}
   272  	}()
   273  
   274  	n.Client.OnConnect(func() {
   275  		for what := range n.callbacks {
   276  			n.Subscribe(what)
   277  		}
   278  		n.SyncDevices()
   279  	})
   280  	n.connect(ctx, u)
   281  	n.wg.Add(1)
   282  	go n.reader(ctx)
   283  	go n.syncWorker()
   284  	return nil
   285  }
   286  
   287  func (n *Node) reader(ctx context.Context) {
   288  	defer n.wg.Done()
   289  	for {
   290  		select {
   291  		case <-ctx.Done():
   292  			logrus.Info("Stopping node reader because:", ctx.Err())
   293  			return
   294  		case data := <-n.Client.Read():
   295  			msg, err := models.ParseMessage(data)
   296  			if err != nil {
   297  				logrus.Error("node:", err)
   298  				continue
   299  			}
   300  			for _, cb := range n.callbacks[msg.Type] {
   301  				err := cb(msg.Body)
   302  				if err != nil {
   303  					logrus.Error(err)
   304  					continue
   305  				}
   306  			}
   307  			if n.callbacks[msg.Type] == nil || len(n.callbacks[msg.Type]) == 0 {
   308  				logrus.WithFields(logrus.Fields{
   309  					"type": msg.Type,
   310  				}).Warn("Received message but no one cared")
   311  			}
   312  		}
   313  	}
   314  }
   315  
   316  func (n *Node) connect(ctx context.Context, addr string) {
   317  	headers := http.Header{}
   318  	headers.Add("X-UUID", n.UUID)
   319  	headers.Add("X-TYPE", n.Type)
   320  	headers.Set("Sec-WebSocket-Protocol", n.Protocol)
   321  	if n.Protocol == "" {
   322  		headers.Set("Sec-WebSocket-Protocol", "node")
   323  	}
   324  	n.Client.ConnectWithRetry(ctx, addr, headers)
   325  }
   326  
   327  func (n *Node) loadCertificateKeyPair(name string) error {
   328  	certTLS, err := tls.LoadX509KeyPair(name+".crt", name+".key")
   329  	if err != nil {
   330  		return err
   331  	}
   332  	certX509, err := x509.ParseCertificate(certTLS.Certificate[0])
   333  	if err != nil {
   334  		return err
   335  	}
   336  
   337  	n.TLS = &certTLS
   338  	n.X509 = certX509
   339  	n.UUID = certX509.Subject.CommonName
   340  
   341  	// Load CA cert
   342  	caCert, err := ioutil.ReadFile("ca.crt")
   343  	if err != nil {
   344  		log.Fatal(err)
   345  	}
   346  	caCertPool := x509.NewCertPool()
   347  	caCertPool.AppendCertsFromPEM(caCert)
   348  	n.CA = caCertPool
   349  
   350  	return nil
   351  }
   352  
   353  func (n *Node) loadOrGenerateKey() (*rsa.PrivateKey, error) {
   354  	data, err := ioutil.ReadFile("crt.key")
   355  	if err != nil {
   356  		if os.IsNotExist(err) {
   357  			return n.generateKey()
   358  		}
   359  		return nil, err
   360  	}
   361  	block, _ := pem.Decode(data)
   362  	return x509.ParsePKCS1PrivateKey(block.Bytes)
   363  }
   364  
   365  func (n *Node) generateKey() (*rsa.PrivateKey, error) {
   366  	priv, _ := rsa.GenerateKey(rand.Reader, 2048)
   367  	keyOut, err := os.OpenFile("crt.key", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
   368  	if err != nil {
   369  		return nil, err
   370  	}
   371  	err = pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
   372  	keyOut.Close()
   373  	return priv, err
   374  }
   375  
   376  func (n *Node) generateCSR() ([]byte, error) {
   377  	hostname, err := os.Hostname()
   378  	if err != nil {
   379  		return nil, err
   380  	}
   381  
   382  	subj := pkix.Name{
   383  		CommonName:         uuid.New().String(),
   384  		Organization:       []string{"stampzilla-go"},
   385  		OrganizationalUnit: []string{hostname, n.Type},
   386  	}
   387  
   388  	template := x509.CertificateRequest{
   389  		Subject:            subj,
   390  		SignatureAlgorithm: x509.SHA256WithRSA,
   391  	}
   392  
   393  	priv, err := n.loadOrGenerateKey()
   394  	if err != nil {
   395  		return nil, err
   396  	}
   397  
   398  	csrBytes, _ := x509.CreateCertificateRequest(rand.Reader, &template, priv)
   399  	d := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csrBytes})
   400  
   401  	n.UUID = template.Subject.CommonName
   402  
   403  	return d, nil
   404  }
   405  
   406  // On sets up a callback that is run when a message recieved with type what
   407  func (n *Node) On(what string, cb OnFunc) {
   408  	n.Subscribe(what)
   409  	n.callbacks[what] = append(n.callbacks[what], cb)
   410  }
   411  
   412  //OnConfig is run when node recieves updated configuration from the server
   413  func (n *Node) OnConfig(cb OnFunc) {
   414  	n.On("setup", func(data json.RawMessage) error {
   415  		conf := &models.Node{}
   416  		err := json.Unmarshal(data, conf)
   417  		if err != nil {
   418  			return err
   419  		}
   420  
   421  		if len(conf.Config) == 0 {
   422  			return nil
   423  		}
   424  
   425  		return cb(conf.Config)
   426  	})
   427  }
   428  
   429  // WaitForFirstConfig blocks until we recieve the first config from server
   430  func (n *Node) WaitForFirstConfig() func() error {
   431  	var once sync.Once
   432  	waitForConfig := make(chan struct{})
   433  	n.OnConfig(func(data json.RawMessage) error {
   434  		var err error
   435  		once.Do(func() {
   436  			close(waitForConfig)
   437  		})
   438  		return err
   439  	})
   440  
   441  	return func() error {
   442  		select {
   443  		case <-waitForConfig:
   444  			return nil
   445  		case <-n.stop:
   446  			return fmt.Errorf("node: stopped before first config")
   447  		}
   448  	}
   449  }
   450  
   451  // OnShutdown registeres a callback that is run before the server shuts down
   452  func (n *Node) OnShutdown(cb func()) {
   453  	n.shutdown = append(n.shutdown, cb)
   454  }
   455  
   456  // OnRequestStateChange is run if we get a state-change request from the server to update our devices (for example we are requested to turn on a light)
   457  func (n *Node) OnRequestStateChange(cb func(state devices.State, device *devices.Device) error) {
   458  	n.On("state-change", func(data json.RawMessage) error {
   459  		//devs := devices.NewList()
   460  		devs := make(map[devices.ID]devices.State)
   461  		err := json.Unmarshal(data, &devs)
   462  		if err != nil {
   463  			return err
   464  		}
   465  
   466  		for devID, state := range devs {
   467  			// loop over all devices and compare state
   468  			stateChange := make(devices.State)
   469  			foundChange := false
   470  			oldDev := n.Devices.Get(devID)
   471  			for s, newState := range state {
   472  				oldState := oldDev.State[s]
   473  				if newState != oldState {
   474  					//fmt.Printf("oldstate %T %#v\n", oldState, newState)
   475  					//fmt.Printf("newState %T %#v\n", newState, newState)
   476  					stateChange[s] = newState
   477  					foundChange = true
   478  				}
   479  			}
   480  			if foundChange {
   481  				err := cb(stateChange, oldDev)
   482  				if err != nil {
   483  					// set state back to before. we could not change it as requested
   484  					// continue to next device
   485  					if err == ErrSkipSync { // skip sync without logging error if needed
   486  						continue
   487  					}
   488  					logrus.Error(err)
   489  					continue
   490  				}
   491  
   492  				// set the new state and send it to the server
   493  				err = n.Devices.SetState(devID, state.Merge(stateChange))
   494  				if err != nil {
   495  					logrus.Error(err)
   496  					continue
   497  				}
   498  
   499  				err = n.WriteMessage("update-device", n.Devices.Get(devID))
   500  				if err != nil {
   501  					logrus.Error(err)
   502  					continue
   503  				}
   504  			}
   505  		}
   506  
   507  		return nil
   508  	})
   509  }
   510  
// ErrSkipSync can be returned from an OnRequestStateChange callback to skip
// storing and syncing the device without the error being logged.
var ErrSkipSync = fmt.Errorf("skipping device sync after RequestStateChange")
   512  
   513  // AddOrUpdate adds or updates a device in our local device store and notifies the server about the new state of the device.
   514  func (n *Node) AddOrUpdate(d *devices.Device) {
   515  	d.ID.Node = n.UUID
   516  	n.Devices.Add(d)
   517  	n.sendUpdate <- d.ID
   518  }
   519  
   520  // syncWorker is a debouncer to send multiple devices to the server if we change many rapidly
   521  func (n *Node) syncWorker() {
   522  	for {
   523  		que := make([]devices.ID, 0)
   524  		id := <-n.sendUpdate
   525  		que = append(que, id)
   526  
   527  		max := time.NewTimer(10 * time.Millisecond)
   528  	outer:
   529  		for {
   530  			select {
   531  			case id := <-n.sendUpdate:
   532  				que = append(que, id)
   533  			case <-time.After(1 * time.Millisecond):
   534  				break outer
   535  			case <-max.C:
   536  				break outer
   537  
   538  			}
   539  		}
   540  
   541  		// send message to server
   542  		devs := make(devices.DeviceMap)
   543  		for _, id := range que {
   544  			d := n.GetDevice(id.ID)
   545  			devs[d.ID] = d.Copy()
   546  		}
   547  		err := n.WriteMessage("update-devices", devs)
   548  		if err != nil {
   549  			logrus.Error(err)
   550  			return
   551  		}
   552  	}
   553  }
   554  
   555  func (n *Node) GetDevice(id string) *devices.Device {
   556  	return n.Devices.Get(devices.ID{Node: n.UUID, ID: id})
   557  }
   558  
   559  // UpdateState updates the new state on the node if if differs and sends update to server if there was a diff
   560  func (n *Node) UpdateState(id string, newState devices.State) {
   561  	device := n.GetDevice(id)
   562  
   563  	if device == nil {
   564  		return
   565  	}
   566  
   567  	if len(newState) == 0 {
   568  		return
   569  	}
   570  
   571  	if diff := device.State.Diff(newState); len(diff) != 0 {
   572  		device.Lock()
   573  		device.State.MergeWith(diff)
   574  		device.Unlock()
   575  		n.SyncDevice(id)
   576  	}
   577  }
   578  
   579  //SyncDevices notifies the server about the state of all our known devices.
   580  func (n *Node) SyncDevices() error {
   581  	return n.WriteMessage("update-devices", n.Devices)
   582  }
   583  
   584  // SyncDevice sync single device
   585  func (n *Node) SyncDevice(id string) {
   586  	n.sendUpdate <- devices.ID{ID: id}
   587  }
   588  
   589  //Subscribe subscribes to a topic in the server
   590  func (n *Node) Subscribe(what ...string) error {
   591  	return n.WriteMessage("subscribe", what)
   592  }