github.com/stampzilla/stampzilla-go@v2.0.0-rc9+incompatible/pkg/node/node.go (about) 1 package node 2 3 import ( 4 "context" 5 "crypto/rand" 6 "crypto/rsa" 7 "crypto/tls" 8 "crypto/x509" 9 "crypto/x509/pkix" 10 "encoding/json" 11 "encoding/pem" 12 "fmt" 13 "io/ioutil" 14 "log" 15 "net/http" 16 "os" 17 "os/signal" 18 "sync" 19 "syscall" 20 "time" 21 22 "github.com/google/uuid" 23 "github.com/sirupsen/logrus" 24 "github.com/stampzilla/stampzilla-go/nodes/stampzilla-server/models" 25 "github.com/stampzilla/stampzilla-go/nodes/stampzilla-server/models/devices" 26 "github.com/stampzilla/stampzilla-go/pkg/build" 27 "github.com/stampzilla/stampzilla-go/pkg/websocket" 28 ) 29 30 // OnFunc is used in all the callbacks 31 type OnFunc func(json.RawMessage) error 32 33 // Node is the main struct 34 type Node struct { 35 UUID string 36 Type string 37 Version string 38 Protocol string 39 40 Client websocket.Websocket 41 //DisconnectClient context.CancelFunc 42 wg *sync.WaitGroup 43 Config *models.Config 44 X509 *x509.Certificate 45 TLS *tls.Certificate 46 CA *x509.CertPool 47 callbacks map[string][]OnFunc 48 Devices *devices.List 49 shutdown []func() 50 stop chan struct{} 51 sendUpdate chan devices.ID 52 } 53 54 // New returns a new Node 55 func New(t string) *Node { 56 client := websocket.New() 57 node := NewWithClient(client) 58 node.Type = t 59 60 node.setup() 61 62 return node 63 } 64 65 // NewWithClient returns a new Node with a custom websocket client 66 func NewWithClient(client websocket.Websocket) *Node { 67 return &Node{ 68 Client: client, 69 wg: &sync.WaitGroup{}, 70 callbacks: make(map[string][]OnFunc), 71 Devices: devices.NewList(), 72 stop: make(chan struct{}), 73 sendUpdate: make(chan devices.ID), 74 } 75 } 76 77 // Stop will shutdown the node similar to a SIGTERM 78 func (n *Node) Stop() { 79 close(n.stop) 80 } 81 82 // Stopped is closed when the node is stopped by n.Stop or os signal 83 func (n *Node) Stopped() <-chan struct{} { 84 return n.stop 85 } 86 87 // Wait for node to be done after shutdown 88 func (n *Node) Wait() { 89 n.Client.Wait() 90 n.wg.Wait() 91 } 92 93 func (n *Node) setup() { 94 logrus.SetReportCaller(true) 95 logrus.SetFormatter(&logrus.TextFormatter{TimestampFormat: time.RFC3339Nano, FullTimestamp: true}) 96 97 //Make sure we have a config 98 n.Config = &models.Config{} 99 n.Config.MustLoad() 100 101 if n.Config.Version { 102 fmt.Println(build.String()) 103 os.Exit(1) 104 } 105 n.Version = build.String() 106 107 if n.Config.LogLevel != "" { 108 lvl, err := logrus.ParseLevel(n.Config.LogLevel) 109 if err != nil { 110 logrus.Fatal(err) 111 return 112 } 113 logrus.SetLevel(lvl) 114 } 115 116 //n.Config.Save("config.json") 117 } 118 119 // WriteMessage writes a message to the server over websocket client 120 func (n *Node) WriteMessage(msgType string, data interface{}) error { 121 msg, err := models.NewMessage(msgType, data) 122 logrus.WithFields(logrus.Fields{ 123 "type": msgType, 124 "body": data, 125 }).Tracef("Send to server") 126 if err != nil { 127 return err 128 } 129 return n.Client.WriteJSON(msg) 130 } 131 132 // WaitForMessage is a helper method to wait for a specific message type 133 func (n *Node) WaitForMessage(msgType string, dst interface{}) error { 134 135 for data := range n.Client.Read() { 136 msg, err := models.ParseMessage(data) 137 if err != nil { 138 return err 139 } 140 if msg.Type == msgType { 141 return json.Unmarshal(msg.Body, dst) 142 } 143 } 144 return nil 145 } 146 147 func (n *Node) fetchCertificate() error { 148 // Start with creating a CSR and assign a UUID 149 csr, err := n.generateCSR() 150 if err != nil { 151 return err 152 } 153 154 u := fmt.Sprintf("ws://%s:%s/ws", n.Config.Host, n.Config.Port) 155 ctx, cancel := context.WithCancel(context.Background()) 156 defer cancel() 157 158 interrupt := make(chan os.Signal, 1) 159 signal.Notify(interrupt, os.Interrupt, syscall.SIGQUIT, syscall.SIGTERM) 160 go func() { 161 select { 162 case <-interrupt: 163 close(n.stop) 164 case <-n.stop: 165 } 166 cancel() 167 go func() { 168 <-time.After(time.Second * 10) 169 log.Fatal("force shutdown after 10 seconds") 170 }() 171 for _, f := range n.shutdown { 172 f() 173 } 174 }() 175 176 n.connect(ctx, u) 177 178 // wait for server info so we can update our config 179 serverInfo := &models.ServerInfo{} 180 err = n.WaitForMessage("server-info", serverInfo) 181 if err != nil { 182 return err 183 } 184 n.Config.Port = serverInfo.Port 185 n.Config.TLSPort = serverInfo.TLSPort 186 187 n.WriteMessage("certificate-signing-request", models.Request{ 188 Type: n.Type, 189 Version: n.Version, 190 CSR: string(csr), 191 }) 192 if err != nil { 193 return err 194 } 195 196 // wait for our new certificate 197 198 var rawCert string 199 err = n.WaitForMessage("approved-certificate-signing-request", &rawCert) 200 201 err = ioutil.WriteFile("crt.crt", []byte(rawCert), 0644) 202 if err != nil { 203 return err 204 } 205 206 var caCert string 207 err = n.WaitForMessage("certificate-authority", &caCert) 208 209 err = ioutil.WriteFile("ca.crt", []byte(caCert), 0644) 210 if err != nil { 211 return err 212 } 213 214 logrus.Info("Disconnect inseure connection") 215 cancel() 216 n.Wait() 217 218 // We should have a certificate now. Try to load it 219 return n.loadCertificateKeyPair("crt") 220 } 221 222 // Connect starts the node and makes connection to the server. Normally discovered using mdns but can be configured aswell. 223 func (n *Node) Connect() error { 224 225 if n.Config.Host == "" { 226 ip, port := queryMDNS() 227 n.Config.Host = ip 228 n.Config.Port = port 229 } 230 231 // Load our signed certificate and get our UUID 232 err := n.loadCertificateKeyPair("crt") 233 234 if err != nil { 235 logrus.Error("Error trying to load certificate: ", err) 236 err = n.fetchCertificate() 237 if err != nil { 238 return err 239 } 240 } 241 242 //If we have certificate we can connect to TLS immedietly 243 tlsConfig := &tls.Config{ 244 Certificates: []tls.Certificate{*n.TLS}, 245 RootCAs: n.CA, 246 ServerName: "localhost", 247 } 248 249 n.Client.SetTLSConfig(tlsConfig) 250 251 u := fmt.Sprintf("wss://%s:%s/ws", n.Config.Host, n.Config.TLSPort) 252 ctx, cancel := context.WithCancel(context.Background()) 253 254 interrupt := make(chan os.Signal, 1) 255 signal.Notify(interrupt, os.Interrupt, syscall.SIGQUIT, syscall.SIGTERM) 256 n.wg.Add(1) 257 go func() { 258 defer n.wg.Done() 259 select { 260 case <-interrupt: 261 close(n.stop) 262 case <-n.stop: 263 } 264 cancel() 265 go func() { 266 <-time.After(time.Second * 10) 267 log.Fatal("force shutdown after 10 seconds") 268 }() 269 for _, f := range n.shutdown { 270 f() 271 } 272 }() 273 274 n.Client.OnConnect(func() { 275 for what := range n.callbacks { 276 n.Subscribe(what) 277 } 278 n.SyncDevices() 279 }) 280 n.connect(ctx, u) 281 n.wg.Add(1) 282 go n.reader(ctx) 283 go n.syncWorker() 284 return nil 285 } 286 287 func (n *Node) reader(ctx context.Context) { 288 defer n.wg.Done() 289 for { 290 select { 291 case <-ctx.Done(): 292 logrus.Info("Stopping node reader because:", ctx.Err()) 293 return 294 case data := <-n.Client.Read(): 295 msg, err := models.ParseMessage(data) 296 if err != nil { 297 logrus.Error("node:", err) 298 continue 299 } 300 for _, cb := range n.callbacks[msg.Type] { 301 err := cb(msg.Body) 302 if err != nil { 303 logrus.Error(err) 304 continue 305 } 306 } 307 if n.callbacks[msg.Type] == nil || len(n.callbacks[msg.Type]) == 0 { 308 logrus.WithFields(logrus.Fields{ 309 "type": msg.Type, 310 }).Warn("Received message but no one cared") 311 } 312 } 313 } 314 } 315 316 func (n *Node) connect(ctx context.Context, addr string) { 317 headers := http.Header{} 318 headers.Add("X-UUID", n.UUID) 319 headers.Add("X-TYPE", n.Type) 320 headers.Set("Sec-WebSocket-Protocol", n.Protocol) 321 if n.Protocol == "" { 322 headers.Set("Sec-WebSocket-Protocol", "node") 323 } 324 n.Client.ConnectWithRetry(ctx, addr, headers) 325 } 326 327 func (n *Node) loadCertificateKeyPair(name string) error { 328 certTLS, err := tls.LoadX509KeyPair(name+".crt", name+".key") 329 if err != nil { 330 return err 331 } 332 certX509, err := x509.ParseCertificate(certTLS.Certificate[0]) 333 if err != nil { 334 return err 335 } 336 337 n.TLS = &certTLS 338 n.X509 = certX509 339 n.UUID = certX509.Subject.CommonName 340 341 // Load CA cert 342 caCert, err := ioutil.ReadFile("ca.crt") 343 if err != nil { 344 log.Fatal(err) 345 } 346 caCertPool := x509.NewCertPool() 347 caCertPool.AppendCertsFromPEM(caCert) 348 n.CA = caCertPool 349 350 return nil 351 } 352 353 func (n *Node) loadOrGenerateKey() (*rsa.PrivateKey, error) { 354 data, err := ioutil.ReadFile("crt.key") 355 if err != nil { 356 if os.IsNotExist(err) { 357 return n.generateKey() 358 } 359 return nil, err 360 } 361 block, _ := pem.Decode(data) 362 return x509.ParsePKCS1PrivateKey(block.Bytes) 363 } 364 365 func (n *Node) generateKey() (*rsa.PrivateKey, error) { 366 priv, _ := rsa.GenerateKey(rand.Reader, 2048) 367 keyOut, err := os.OpenFile("crt.key", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) 368 if err != nil { 369 return nil, err 370 } 371 err = pem.Encode(keyOut, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) 372 keyOut.Close() 373 return priv, err 374 } 375 376 func (n *Node) generateCSR() ([]byte, error) { 377 hostname, err := os.Hostname() 378 if err != nil { 379 return nil, err 380 } 381 382 subj := pkix.Name{ 383 CommonName: uuid.New().String(), 384 Organization: []string{"stampzilla-go"}, 385 OrganizationalUnit: []string{hostname, n.Type}, 386 } 387 388 template := x509.CertificateRequest{ 389 Subject: subj, 390 SignatureAlgorithm: x509.SHA256WithRSA, 391 } 392 393 priv, err := n.loadOrGenerateKey() 394 if err != nil { 395 return nil, err 396 } 397 398 csrBytes, _ := x509.CreateCertificateRequest(rand.Reader, &template, priv) 399 d := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csrBytes}) 400 401 n.UUID = template.Subject.CommonName 402 403 return d, nil 404 } 405 406 // On sets up a callback that is run when a message recieved with type what 407 func (n *Node) On(what string, cb OnFunc) { 408 n.Subscribe(what) 409 n.callbacks[what] = append(n.callbacks[what], cb) 410 } 411 412 //OnConfig is run when node recieves updated configuration from the server 413 func (n *Node) OnConfig(cb OnFunc) { 414 n.On("setup", func(data json.RawMessage) error { 415 conf := &models.Node{} 416 err := json.Unmarshal(data, conf) 417 if err != nil { 418 return err 419 } 420 421 if len(conf.Config) == 0 { 422 return nil 423 } 424 425 return cb(conf.Config) 426 }) 427 } 428 429 // WaitForFirstConfig blocks until we recieve the first config from server 430 func (n *Node) WaitForFirstConfig() func() error { 431 var once sync.Once 432 waitForConfig := make(chan struct{}) 433 n.OnConfig(func(data json.RawMessage) error { 434 var err error 435 once.Do(func() { 436 close(waitForConfig) 437 }) 438 return err 439 }) 440 441 return func() error { 442 select { 443 case <-waitForConfig: 444 return nil 445 case <-n.stop: 446 return fmt.Errorf("node: stopped before first config") 447 } 448 } 449 } 450 451 // OnShutdown registeres a callback that is run before the server shuts down 452 func (n *Node) OnShutdown(cb func()) { 453 n.shutdown = append(n.shutdown, cb) 454 } 455 456 // OnRequestStateChange is run if we get a state-change request from the server to update our devices (for example we are requested to turn on a light) 457 func (n *Node) OnRequestStateChange(cb func(state devices.State, device *devices.Device) error) { 458 n.On("state-change", func(data json.RawMessage) error { 459 //devs := devices.NewList() 460 devs := make(map[devices.ID]devices.State) 461 err := json.Unmarshal(data, &devs) 462 if err != nil { 463 return err 464 } 465 466 for devID, state := range devs { 467 // loop over all devices and compare state 468 stateChange := make(devices.State) 469 foundChange := false 470 oldDev := n.Devices.Get(devID) 471 for s, newState := range state { 472 oldState := oldDev.State[s] 473 if newState != oldState { 474 //fmt.Printf("oldstate %T %#v\n", oldState, newState) 475 //fmt.Printf("newState %T %#v\n", newState, newState) 476 stateChange[s] = newState 477 foundChange = true 478 } 479 } 480 if foundChange { 481 err := cb(stateChange, oldDev) 482 if err != nil { 483 // set state back to before. we could not change it as requested 484 // continue to next device 485 if err == ErrSkipSync { // skip sync without logging error if needed 486 continue 487 } 488 logrus.Error(err) 489 continue 490 } 491 492 // set the new state and send it to the server 493 err = n.Devices.SetState(devID, state.Merge(stateChange)) 494 if err != nil { 495 logrus.Error(err) 496 continue 497 } 498 499 err = n.WriteMessage("update-device", n.Devices.Get(devID)) 500 if err != nil { 501 logrus.Error(err) 502 continue 503 } 504 } 505 } 506 507 return nil 508 }) 509 } 510 511 var ErrSkipSync = fmt.Errorf("skipping device sync after RequestStateChange") 512 513 // AddOrUpdate adds or updates a device in our local device store and notifies the server about the new state of the device. 514 func (n *Node) AddOrUpdate(d *devices.Device) { 515 d.ID.Node = n.UUID 516 n.Devices.Add(d) 517 n.sendUpdate <- d.ID 518 } 519 520 // syncWorker is a debouncer to send multiple devices to the server if we change many rapidly 521 func (n *Node) syncWorker() { 522 for { 523 que := make([]devices.ID, 0) 524 id := <-n.sendUpdate 525 que = append(que, id) 526 527 max := time.NewTimer(10 * time.Millisecond) 528 outer: 529 for { 530 select { 531 case id := <-n.sendUpdate: 532 que = append(que, id) 533 case <-time.After(1 * time.Millisecond): 534 break outer 535 case <-max.C: 536 break outer 537 538 } 539 } 540 541 // send message to server 542 devs := make(devices.DeviceMap) 543 for _, id := range que { 544 d := n.GetDevice(id.ID) 545 devs[d.ID] = d.Copy() 546 } 547 err := n.WriteMessage("update-devices", devs) 548 if err != nil { 549 logrus.Error(err) 550 return 551 } 552 } 553 } 554 555 func (n *Node) GetDevice(id string) *devices.Device { 556 return n.Devices.Get(devices.ID{Node: n.UUID, ID: id}) 557 } 558 559 // UpdateState updates the new state on the node if if differs and sends update to server if there was a diff 560 func (n *Node) UpdateState(id string, newState devices.State) { 561 device := n.GetDevice(id) 562 563 if device == nil { 564 return 565 } 566 567 if len(newState) == 0 { 568 return 569 } 570 571 if diff := device.State.Diff(newState); len(diff) != 0 { 572 device.Lock() 573 device.State.MergeWith(diff) 574 device.Unlock() 575 n.SyncDevice(id) 576 } 577 } 578 579 //SyncDevices notifies the server about the state of all our known devices. 580 func (n *Node) SyncDevices() error { 581 return n.WriteMessage("update-devices", n.Devices) 582 } 583 584 // SyncDevice sync single device 585 func (n *Node) SyncDevice(id string) { 586 n.sendUpdate <- devices.ID{ID: id} 587 } 588 589 //Subscribe subscribes to a topic in the server 590 func (n *Node) Subscribe(what ...string) error { 591 return n.WriteMessage("subscribe", what) 592 }