github.com/simpleiot/simpleiot@v0.18.3/client/update.go (about)

     1  package client
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"log"
     7  	"net/http"
     8  	"net/url"
     9  	"os"
    10  	"os/exec"
    11  	"path/filepath"
    12  	"regexp"
    13  	"slices"
    14  	"sort"
    15  	"strconv"
    16  	"strings"
    17  	"time"
    18  
    19  	"github.com/blang/semver/v4"
    20  	"github.com/nats-io/nats.go"
    21  	"github.com/simpleiot/simpleiot/data"
    22  	"github.com/simpleiot/simpleiot/system"
    23  )
    24  
// Update represents the config of an update node. It describes where OS
// update files are fetched from, where they are stored locally, and how
// UpdateClient should poll for, download, and apply them.
type Update struct {
	// ID is the node ID; all points sent by UpdateClient target this node.
	ID              string   `node:"id"`
	// Parent is the ID of the parent node.
	Parent          string   `node:"parent"`
	// Description is a free-form description of the node.
	Description     string   `point:"description"`
	// VersionOS is the currently running OS version; Run populates it from
	// the system's VERSION_ID.
	VersionOS       string   `point:"versionOS"`
	// URI is the base URL that hosts files.txt and the update files.
	URI             string   `point:"uri"`
	// OSUpdates lists the update versions found in files.txt that match
	// Prefix, sorted in ascending version order.
	OSUpdates       []string `point:"osUpdate"`
	// DownloadOS is the version that has been requested for download. It is
	// cleared when the download finishes.
	DownloadOS      string   `point:"downloadOS"`
	// OSDownloaded is the newest version found among the .upd files in
	// Directory, or empty if there are none.
	OSDownloaded    string   `point:"osDownloaded"`
	// DiscardDownload: a discardDownload point with a non-zero value makes
	// Run delete the .upd files in Directory.
	DiscardDownload string   `point:"discardDownload"`
	// Prefix is the file name prefix of update files
	// (<prefix>_<version>.upd). Run defaults it to the hostname when empty.
	Prefix          string   `point:"prefix"`
	// Directory is where update files are stored. Run defaults it to /data
	// when empty.
	Directory       string   `point:"directory"`
	// PollPeriod is the interval, in minutes, between checks for updates.
	// Run defaults it to 30 when it is not positive.
	PollPeriod      int      `point:"pollPeriod"`
	// Refresh: a refresh point makes Run re-fetch the list of available
	// updates.
	Refresh         bool     `point:"refresh"`
	// AutoDownload enables automatically downloading the newest available
	// update when it is newer than VersionOS.
	AutoDownload    bool     `point:"autoDownload"`
	// AutoReboot enables rebooting the system after an update download.
	AutoReboot      bool     `point:"autoReboot"`
}
    43  
// UpdateClient is a SIOT client that checks a server for OS updates,
// downloads them, and optionally reboots the system afterwards.
type UpdateClient struct {
	// log writes messages prefixed with "Update: ".
	log           *log.Logger
	// nc is the NATS connection used to send node points.
	nc            *nats.Conn
	// config is the client's current view of the update node.
	config        Update
	// stop is closed by Stop to make Run return.
	stop          chan struct{}
	// newPoints delivers points passed to Points into the Run loop.
	newPoints     chan NewPoints
	// newEdgePoints delivers points passed to EdgePoints into the Run loop.
	newEdgePoints chan NewPoints
}
    53  
    54  // NewUpdateClient ...
    55  func NewUpdateClient(nc *nats.Conn, config Update) Client {
    56  	return &UpdateClient{
    57  		log:           log.New(os.Stderr, "Update: ", log.LstdFlags|log.Lmsgprefix),
    58  		nc:            nc,
    59  		config:        config,
    60  		stop:          make(chan struct{}),
    61  		newPoints:     make(chan NewPoints),
    62  		newEdgePoints: make(chan NewPoints),
    63  	}
    64  }
    65  
    66  func (m *UpdateClient) setError(err error) {
    67  	errS := ""
    68  	if err != nil {
    69  		errS = err.Error()
    70  		m.log.Println(err)
    71  	}
    72  
    73  	p := data.Point{
    74  		Type: data.PointTypeError,
    75  		Time: time.Now(),
    76  		Text: errS,
    77  	}
    78  
    79  	e := SendNodePoint(m.nc, m.config.ID, p, true)
    80  	if e != nil {
    81  		m.log.Println("error sending point:", e)
    82  	}
    83  }
    84  
// reUpd matches update file names of the form <prefix>_<major.minor.patch>.upd.
// Submatch 1 is the prefix and submatch 2 is the version string.
var reUpd = regexp.MustCompile(`(.*)_(\d+\.\d+\.\d+)\.upd`)
    86  
    87  // Run the main logic for this client and blocks until stopped
    88  func (m *UpdateClient) Run() error {
    89  	cDownloadFinished := make(chan struct{})
    90  	// cSetError is used in any goroutines
    91  	cSetError := make(chan error)
    92  
    93  	download := func(v string) error {
    94  		defer func() {
    95  			cDownloadFinished <- struct{}{}
    96  			_ = SendNodePoint(m.nc, m.config.ID,
    97  				data.Point{Time: time.Now(), Type: data.PointTypeDownloadOS, Text: ""},
    98  				false,
    99  			)
   100  			m.config.DownloadOS = ""
   101  		}()
   102  
   103  		u, err := url.JoinPath(m.config.URI, m.config.Prefix+"_"+v+".upd")
   104  		if err != nil {
   105  			return fmt.Errorf("URI error: %w", err)
   106  		}
   107  
   108  		m.log.Println("Downloading update: ", u)
   109  
   110  		fileName := filepath.Base(u)
   111  		destPath := filepath.Join(m.config.Directory, fileName)
   112  
   113  		out, err := os.Create(destPath)
   114  		if err != nil {
   115  			return fmt.Errorf("Error creating OS update file: %w", err)
   116  		}
   117  		defer out.Close()
   118  
   119  		resp, err := http.Get(u)
   120  		if err != nil {
   121  			return fmt.Errorf("Error http get fetching OS update: %w", err)
   122  		}
   123  		defer resp.Body.Close()
   124  
   125  		c, err := io.Copy(out, resp.Body)
   126  		if err != nil {
   127  			return fmt.Errorf("io.Copy error: %w", err)
   128  		}
   129  
   130  		if c <= 0 {
   131  			os.Remove(destPath)
   132  			return fmt.Errorf("Failed to download: %v", u)
   133  		}
   134  
   135  		return nil
   136  	}
   137  
   138  	// fill in default prefix
   139  	if m.config.Prefix == "" {
   140  		p, err := os.Hostname()
   141  		if err != nil {
   142  			m.log.Println("Error getting hostname: ", err)
   143  		} else {
   144  			m.log.Println("Setting update prefix to: ", p)
   145  			err := SendNodePoint(m.nc, m.config.ID, data.Point{
   146  				Time: time.Now(),
   147  				Type: data.PointTypePrefix,
   148  				Key:  "0",
   149  				Text: p}, false)
   150  			if err != nil {
   151  				m.log.Println("Error sending point: ", err)
   152  			} else {
   153  				m.config.Prefix = p
   154  			}
   155  		}
   156  	}
   157  
   158  	if m.config.Directory == "" {
   159  		d := "/data"
   160  		m.log.Println("Setting directory to: ", d)
   161  		err := SendNodePoint(m.nc, m.config.ID, data.Point{
   162  			Time: time.Now(),
   163  			Type: data.PointTypeDirectory,
   164  			Key:  "0",
   165  			Text: d}, false)
   166  		if err != nil {
   167  			m.log.Println("Error sending point: ", err)
   168  		} else {
   169  			m.config.Directory = d
   170  		}
   171  	}
   172  
   173  	if m.config.PollPeriod <= 0 {
   174  		p := 30
   175  		m.log.Println("Setting poll period to: ", p)
   176  		err := SendNodePoint(m.nc, m.config.ID, data.Point{
   177  			Time:  time.Now(),
   178  			Type:  data.PointTypePollPeriod,
   179  			Key:   "0",
   180  			Value: float64(p)}, false)
   181  		if err != nil {
   182  			m.log.Println("Error sending point: ", err)
   183  		} else {
   184  			m.config.PollPeriod = p
   185  		}
   186  	}
   187  
   188  	getUpdates := func() error {
   189  		clearUpdateList := func() {
   190  			cnt := len(m.config.OSUpdates)
   191  
   192  			if cnt > 0 {
   193  				pts := data.Points{}
   194  				for i := 0; i < cnt; i++ {
   195  					pts = append(pts, data.Point{
   196  						Time: time.Now(), Type: data.PointTypeOSUpdate, Key: strconv.Itoa(i), Tombstone: 1,
   197  					})
   198  				}
   199  
   200  				err := SendNodePoints(m.nc, m.config.ID, pts, false)
   201  				if err != nil {
   202  					m.log.Println("Error sending version points: ", err)
   203  				}
   204  			}
   205  		}
   206  
   207  		p, err := url.JoinPath(m.config.URI, "files.txt")
   208  		if err != nil {
   209  			clearUpdateList()
   210  			return fmt.Errorf("URI error: %w", err)
   211  		}
   212  		resp, err := http.Get(p)
   213  		if err != nil {
   214  			clearUpdateList()
   215  			return fmt.Errorf("Error getting updates: %w", err)
   216  		}
   217  
   218  		if resp.StatusCode != 200 {
   219  			clearUpdateList()
   220  			return fmt.Errorf("Error getting updates: %v", resp.Status)
   221  		}
   222  
   223  		defer resp.Body.Close()
   224  
   225  		body, err := io.ReadAll(resp.Body)
   226  		if err != nil {
   227  			return fmt.Errorf("Error reading http response: %w", err)
   228  		}
   229  
   230  		updates := strings.Split(string(body), "\n")
   231  
   232  		updates = slices.DeleteFunc(updates, func(u string) bool {
   233  			return !strings.HasPrefix(u, m.config.Prefix)
   234  		})
   235  
   236  		versions := semver.Versions{}
   237  
   238  		for _, u := range updates {
   239  			matches := reUpd.FindStringSubmatch(u)
   240  			if len(matches) > 1 {
   241  				prefix := matches[1]
   242  				version := matches[2]
   243  				sv, err := semver.Parse(version)
   244  				if err != nil {
   245  					m.log.Printf("Error parsing version %v: %v\n", version, err)
   246  				}
   247  				if prefix == m.config.Prefix {
   248  					versions = append(versions, sv)
   249  				}
   250  			} else {
   251  				m.log.Println("Version not found in filename: ", u)
   252  			}
   253  		}
   254  
   255  		sort.Sort(versions)
   256  
   257  		// need to update versions available
   258  		pts := data.Points{}
   259  		now := time.Now()
   260  		for i, v := range versions {
   261  			pts = append(pts, data.Point{
   262  				Time: now, Type: data.PointTypeOSUpdate, Text: v.String(), Key: strconv.Itoa(i),
   263  			})
   264  		}
   265  
   266  		err = SendNodePoints(m.nc, m.config.ID, pts, false)
   267  		if err != nil {
   268  			m.log.Println("Error sending version points: ", err)
   269  
   270  		}
   271  
   272  		err = data.MergePoints(m.config.ID, pts, &m.config)
   273  		if err != nil {
   274  			log.Println("error merging new points:", err)
   275  		}
   276  
   277  		underflowCount := len(m.config.OSUpdates) - len(versions)
   278  
   279  		if underflowCount > 0 {
   280  			pts := data.Points{}
   281  			for i := len(versions); i < len(versions)+underflowCount; i++ {
   282  				pts = append(pts, data.Point{
   283  					Time: now, Type: data.PointTypeOSUpdate, Key: strconv.Itoa(i), Tombstone: 1,
   284  				})
   285  			}
   286  
   287  			err = SendNodePoints(m.nc, m.config.ID, pts, false)
   288  			if err != nil {
   289  				m.log.Println("Error sending version points: ", err)
   290  			}
   291  		}
   292  		return nil
   293  	}
   294  
   295  	cleanDownloads := func() error {
   296  		files, err := os.ReadDir(m.config.Directory)
   297  		var errRet error
   298  		if err != nil {
   299  			return fmt.Errorf("Error getting files in data dir: %w", err)
   300  		}
   301  
   302  		for _, file := range files {
   303  			if !file.IsDir() && filepath.Ext(file.Name()) == ".upd" {
   304  				p := filepath.Join(m.config.Directory, file.Name())
   305  				err = os.Remove(p)
   306  				if err != nil {
   307  					m.log.Printf("Error removing %v: %v\n", file.Name(), err)
   308  					errRet = err
   309  				}
   310  			}
   311  		}
   312  
   313  		m.config.OSDownloaded = ""
   314  		err = SendNodePoint(m.nc, m.config.ID, data.Point{
   315  			Time: time.Now(),
   316  			Type: data.PointTypeOSDownloaded,
   317  			Text: "",
   318  			Key:  "0",
   319  		}, true)
   320  		if err != nil {
   321  			m.log.Println("Error clearing downloaded point: ", err)
   322  		}
   323  
   324  		err = SendNodePoints(m.nc, m.config.ID, data.Points{
   325  			{Time: time.Now(), Type: data.PointTypeDiscardDownload, Value: 0},
   326  		}, true)
   327  		if err != nil {
   328  			m.log.Println("Error discarding download: ", err)
   329  		}
   330  
   331  		return errRet
   332  	}
   333  
   334  	checkDownloads := func() error {
   335  		files, err := os.ReadDir(m.config.Directory)
   336  		if err != nil {
   337  			return fmt.Errorf("Error getting files in data dir: %w", err)
   338  		}
   339  
   340  		updFiles := []string{}
   341  		for _, file := range files {
   342  			if !file.IsDir() && filepath.Ext(file.Name()) == ".upd" {
   343  				updFiles = append(updFiles, file.Name())
   344  			}
   345  		}
   346  
   347  		versions := semver.Versions{}
   348  		for _, u := range updFiles {
   349  
   350  			matches := reUpd.FindStringSubmatch(u)
   351  			if len(matches) > 1 {
   352  				prefix := matches[1]
   353  				version := matches[2]
   354  				sv, err := semver.Parse(version)
   355  				if err != nil {
   356  					m.log.Printf("Error parsing version %v: %v\n", version, err)
   357  				}
   358  				if prefix == m.config.Prefix {
   359  					versions = append(versions, sv)
   360  				}
   361  			} else {
   362  				m.log.Println("Version not found in filename: ", u)
   363  			}
   364  		}
   365  
   366  		sort.Sort(versions)
   367  
   368  		if len(versions) > 0 {
   369  			m.config.OSDownloaded = versions[len(versions)-1].String()
   370  			err := SendNodePoint(m.nc, m.config.ID, data.Point{
   371  				Time: time.Now(),
   372  				Type: data.PointTypeOSDownloaded,
   373  				Key:  "0",
   374  				Text: m.config.OSDownloaded}, true)
   375  
   376  			if err != nil {
   377  				m.log.Println("Error sending point: ", err)
   378  			}
   379  		} else {
   380  			m.config.OSDownloaded = ""
   381  			err = SendNodePoint(m.nc, m.config.ID, data.Point{
   382  				Time: time.Now(),
   383  				Type: data.PointTypeOSDownloaded,
   384  				Text: "",
   385  				Key:  "0",
   386  			}, true)
   387  			if err != nil {
   388  				m.log.Println("Error clearing downloaded point: ", err)
   389  			}
   390  		}
   391  		return nil
   392  	}
   393  
   394  	reboot := func() {
   395  		err := exec.Command("reboot").Run()
   396  		if err != nil {
   397  			m.log.Println("Error rebooting: ", err)
   398  		} else {
   399  			m.log.Println("Rebooting ...")
   400  		}
   401  	}
   402  
   403  	autoDownload := func() error {
   404  		newestUpdate := ""
   405  		if len(m.config.OSUpdates) > 0 {
   406  			newestUpdate = m.config.OSUpdates[len(m.config.OSUpdates)-1]
   407  		} else {
   408  			return nil
   409  		}
   410  
   411  		currentOSV, err := semver.Parse(m.config.VersionOS)
   412  		if err != nil {
   413  			return fmt.Errorf("Autodownload, Error parsing current OS version: %w", err)
   414  		}
   415  
   416  		newestUpdateV, err := semver.Parse(newestUpdate)
   417  		if err != nil {
   418  			return fmt.Errorf("autodownload: Error parsing newest OS update version: %w", err)
   419  		}
   420  
   421  		if newestUpdateV.GT(currentOSV) &&
   422  			newestUpdate != m.config.OSDownloaded &&
   423  			newestUpdate != m.config.DownloadOS {
   424  			// download a newer update
   425  			err := SendNodePoint(m.nc, m.config.ID, data.Point{
   426  				Time: time.Now(),
   427  				Type: data.PointTypeDownloadOS,
   428  				Text: newestUpdate,
   429  			}, true)
   430  			if err != nil {
   431  				return fmt.Errorf("Error sending point: %w", err)
   432  			}
   433  			m.config.DownloadOS = newestUpdate
   434  
   435  			go func(f string) {
   436  				err := download(f)
   437  				if err != nil {
   438  					cSetError <- fmt.Errorf("error downloading update: %w", err)
   439  				}
   440  			}(newestUpdate)
   441  		}
   442  		return nil
   443  	}
   444  
   445  	m.setError(nil)
   446  	err := getUpdates()
   447  	if err != nil {
   448  		m.setError(err)
   449  	}
   450  	err = checkDownloads()
   451  	if err != nil {
   452  		m.setError(err)
   453  	}
   454  
   455  	osVersion, err := system.ReadOSVersion("VERSION_ID")
   456  	if err != nil {
   457  		m.log.Println("Error reading OS version: ", err)
   458  	} else {
   459  		err := SendNodePoint(m.nc, m.config.ID, data.Point{
   460  			Time: time.Now(),
   461  			Type: data.PointTypeVersionOS,
   462  			Key:  "0",
   463  			Text: osVersion.String(),
   464  		}, true)
   465  
   466  		if err != nil {
   467  			m.log.Println("Error sending OS version point: ", err)
   468  		}
   469  
   470  		m.config.VersionOS = osVersion.String()
   471  	}
   472  
   473  	if m.config.DownloadOS != "" {
   474  		go func() {
   475  			err := download(m.config.DownloadOS)
   476  			if err != nil {
   477  				cSetError <- fmt.Errorf("Error downloading file: %w", err)
   478  			}
   479  		}()
   480  	}
   481  
   482  	checkTickerTime := time.Minute * time.Duration(m.config.PollPeriod)
   483  	checkTicker := time.NewTicker(checkTickerTime)
   484  	if m.config.AutoDownload {
   485  		m.setError(nil)
   486  		err := getUpdates()
   487  		if err != nil {
   488  			m.setError(err)
   489  		} else {
   490  			err := autoDownload()
   491  			if err != nil {
   492  				m.setError(err)
   493  			}
   494  		}
   495  	}
   496  
   497  done:
   498  	for {
   499  		select {
   500  		case <-m.stop:
   501  			break done
   502  
   503  		case pts := <-m.newPoints:
   504  			err := data.MergePoints(pts.ID, pts.Points, &m.config)
   505  			if err != nil {
   506  				log.Println("error merging new points:", err)
   507  			}
   508  
   509  			for _, p := range pts.Points {
   510  				switch p.Type {
   511  				case data.PointTypeDownloadOS:
   512  					if p.Text != "" {
   513  						go func(f string) {
   514  							err := download(f)
   515  							if err != nil {
   516  								cSetError <- fmt.Errorf("Error downloading update: %w", err)
   517  							}
   518  						}(p.Text)
   519  					}
   520  				case data.PointTypeDiscardDownload:
   521  					if p.Value != 0 {
   522  						m.setError(nil)
   523  						err := cleanDownloads()
   524  						if err != nil {
   525  							m.setError(fmt.Errorf("Error cleaning downloads: %w", err))
   526  						}
   527  						err = checkDownloads()
   528  						if err != nil {
   529  							m.setError(err)
   530  						}
   531  					}
   532  				case data.PointTypeReboot:
   533  					err := SendNodePoints(m.nc, m.config.ID, data.Points{
   534  						{Time: time.Now(), Type: data.PointTypeReboot, Value: 0},
   535  					}, true)
   536  					if err != nil {
   537  						m.log.Println("Error clearing reboot point: ", err)
   538  					}
   539  
   540  					reboot()
   541  
   542  				case data.PointTypeRefresh:
   543  					err := SendNodePoints(m.nc, m.config.ID, data.Points{
   544  						{Time: time.Now(), Type: data.PointTypeRefresh, Value: 0},
   545  					}, true)
   546  					if err != nil {
   547  						m.log.Println("Error clearing reboot reboot point: ", err)
   548  					}
   549  
   550  					m.setError(nil)
   551  					err = getUpdates()
   552  					if err != nil {
   553  						m.setError(err)
   554  					}
   555  
   556  				case data.PointTypePollPeriod:
   557  					checkTickerTime := time.Minute * time.Duration(p.Value)
   558  					checkTicker.Reset(checkTickerTime)
   559  
   560  				case data.PointTypeAutoDownload:
   561  					if p.Value == 1 {
   562  						m.setError(nil)
   563  						err := getUpdates()
   564  						if err != nil {
   565  							m.setError(err)
   566  						} else {
   567  							err :=
   568  								autoDownload()
   569  							if err != nil {
   570  								m.setError(err)
   571  							}
   572  						}
   573  					}
   574  
   575  				case data.PointTypePrefix:
   576  					m.setError(nil)
   577  					err := cleanDownloads()
   578  					if err != nil {
   579  						m.setError(fmt.Errorf("Error cleaning downloads: %w", err))
   580  					}
   581  					err = checkDownloads()
   582  					if err != nil {
   583  						m.setError(err)
   584  					}
   585  					err = getUpdates()
   586  					if err != nil {
   587  						m.setError(err)
   588  					}
   589  				case data.PointTypeURI:
   590  					m.setError(nil)
   591  					err := getUpdates()
   592  					if err != nil {
   593  						m.setError(err)
   594  					}
   595  				}
   596  			}
   597  
   598  		case pts := <-m.newEdgePoints:
   599  			err := data.MergeEdgePoints(pts.ID, pts.Parent, pts.Points, &m.config)
   600  			if err != nil {
   601  				log.Println("error merging new points:", err)
   602  			}
   603  
   604  		case <-cDownloadFinished:
   605  			now := time.Now()
   606  			err := checkDownloads()
   607  			if err != nil {
   608  				m.setError(err)
   609  			}
   610  
   611  			pts := data.Points{
   612  				{Time: now, Type: data.PointTypeDownloadOS, Text: ""},
   613  				{Time: now, Type: data.PointTypeOSDownloaded, Text: m.config.OSDownloaded},
   614  			}
   615  			err = SendNodePoints(m.nc, m.config.ID, pts, true)
   616  			if err != nil {
   617  				m.log.Println("Error sending node points: ", err)
   618  			}
   619  			m.log.Println("Download process finished")
   620  
   621  			if m.config.AutoReboot {
   622  				// make sure points have time to stick
   623  				time.Sleep(2 * time.Second)
   624  				reboot()
   625  			}
   626  
   627  		case <-checkTicker.C:
   628  			m.setError(nil)
   629  			err := getUpdates()
   630  			if err != nil {
   631  				m.setError(err)
   632  				break
   633  			}
   634  			if m.config.AutoDownload {
   635  				err := autoDownload()
   636  				if err != nil {
   637  					m.setError(err)
   638  				}
   639  			}
   640  			err = checkDownloads()
   641  			if err != nil {
   642  				m.setError(err)
   643  			}
   644  
   645  		case err := <-cSetError:
   646  			m.setError(err)
   647  		}
   648  	}
   649  
   650  	close(cDownloadFinished)
   651  	close(cSetError)
   652  
   653  	return nil
   654  }
   655  
// Stop signals the Run function to exit by closing the stop channel. The
// error argument is ignored.
func (m *UpdateClient) Stop(_ error) {
	close(m.stop)
}
   660  
   661  // Points is called by the Manager when new points for this
   662  // node are received.
   663  func (m *UpdateClient) Points(nodeID string, points []data.Point) {
   664  	m.newPoints <- NewPoints{nodeID, "", points}
   665  }
   666  
   667  // EdgePoints is called by the Manager when new edge points for this
   668  // node are received.
   669  func (m *UpdateClient) EdgePoints(nodeID, parentID string, points []data.Point) {
   670  	m.newEdgePoints <- NewPoints{nodeID, parentID, points}
   671  }
   672  
   673  // below is code that used to be in the store and is in process of being
   674  // ported to a client
   675  
   676  // StartUpdate starts an update
   677  /*
   678  func StartUpdate(id, url string) error {
   679  	if _, ok := st.updates[id]; ok {
   680  		return fmt.Errorf("Update already in process for dev: %v", id)
   681  	}
   682  
   683  	st.updates[id] = time.Now()
   684  
   685  	err := st.setSwUpdateState(id, data.SwUpdateState{
   686  		Running: true,
   687  	})
   688  
   689  	if err != nil {
   690  		delete(st.updates, id)
   691  		return err
   692  	}
   693  
   694  	go func() {
   695  		err := NatsSendFileFromHTTP(st.nc, id, url, func(bytesTx int) {
   696  			err := st.setSwUpdateState(id, data.SwUpdateState{
   697  				Running:     true,
   698  				PercentDone: bytesTx,
   699  			})
   700  
   701  			if err != nil {
   702  				log.Println("Error setting update status in DB:", err)
   703  			}
   704  		})
   705  
   706  		state := data.SwUpdateState{
   707  			Running: false,
   708  		}
   709  
   710  		if err != nil {
   711  			state.Error = "Error updating software"
   712  			state.PercentDone = 0
   713  		} else {
   714  			state.PercentDone = 100
   715  		}
   716  
   717  		st.lock.Lock()
   718  		delete(st.updates, id)
   719  		st.lock.Unlock()
   720  
   721  		err = st.setSwUpdateState(id, state)
   722  		if err != nil {
   723  			log.Println("Error setting sw update state:", err)
   724  		}
   725  	}()
   726  
   727  	return nil
   728  }
   729  */