github.com/arduino/arduino-cloud-cli@v0.0.0-20240517070944-e7a449561083/command/ota/massupload.go (about)

     1  // This file is part of arduino-cloud-cli.
     2  //
     3  // Copyright (C) 2021 ARDUINO SA (http://www.arduino.cc/)
     4  //
     5  // This program is free software: you can redistribute it and/or modify
     6  // it under the terms of the GNU Affero General Public License as published
     7  // by the Free Software Foundation, either version 3 of the License, or
     8  // (at your option) any later version.
     9  //
    10  // This program is distributed in the hope that it will be useful,
    11  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    12  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13  // GNU Affero General Public License for more details.
    14  //
    15  // You should have received a copy of the GNU Affero General Public License
    16  // along with this program.  If not, see <https://www.gnu.org/licenses/>.
    17  
    18  package ota
    19  
    20  import (
    21  	"context"
    22  	"errors"
    23  	"fmt"
    24  	"github.com/sirupsen/logrus"
    25  	"os"
    26  	"path/filepath"
    27  
    28  	"github.com/arduino/arduino-cloud-cli/config"
    29  	"github.com/arduino/arduino-cloud-cli/internal/iot"
    30  	"github.com/arduino/arduino-cloud-cli/internal/ota"
    31  	otaapi "github.com/arduino/arduino-cloud-cli/internal/ota-api"
    32  
    33  	iotclient "github.com/arduino/iot-client-go"
    34  )
    35  
    36  const (
    37  	numConcurrentUploads = 10
    38  )
    39  
// MassUploadParams contains the parameters needed to
// perform a Mass OTA upload.
//
// MassUpload requires exactly one of DeviceIDs and Tags to be non-nil.
type MassUploadParams struct {
	DeviceIDs        []string          // IDs of the devices to update
	Tags             map[string]string // tags passed to the device listing to select the devices to update
	File             string            // path of the firmware file to upload
	Deferred         bool              // when true, the deferred OTA expiration is used instead of the default one
	DoNotApplyHeader bool              // when true, File is uploaded as-is instead of generating an .ota file from it
	FQBN             string            // FQBN passed to the .ota generation and required to match each target device
}
    50  
// Result of an ota upload on a device.
type Result struct {
	ID        string     // ID of the device this result refers to
	Err       error      // validation or upload error for the device; nil when none occurred
	OtaStatus otaapi.Ota // first OTA entry returned by the status lookup; zero value when the lookup fails or is empty
}
    57  
    58  func buildOtaFile(params *MassUploadParams) (string, string, error) {
    59  	var otaFile string
    60  	var otaDir string
    61  	var err error
    62  	if params.DoNotApplyHeader {
    63  		otaFile = params.File
    64  	} else {
    65  		otaDir, err = os.MkdirTemp("", "")
    66  		if err != nil {
    67  			return "", "", fmt.Errorf("%s: %w", "cannot create temporary folder", err)
    68  		}
    69  		otaFile = filepath.Join(otaDir, "temp.ota")
    70  
    71  		err = Generate(params.File, otaFile, params.FQBN)
    72  		if err != nil {
    73  			return "", "", fmt.Errorf("%s: %w", "cannot generate .ota file", err)
    74  		}
    75  	}
    76  	return otaFile, otaDir, nil
    77  }
    78  
    79  // MassUpload command is used to mass upload a firmware OTA,
    80  // on devices of Arduino IoT Cloud.
    81  func MassUpload(ctx context.Context, params *MassUploadParams, cred *config.Credentials) ([]Result, error) {
    82  	if params.DeviceIDs == nil && params.Tags == nil {
    83  		return nil, errors.New("provide either DeviceIDs or Tags")
    84  	} else if params.DeviceIDs != nil && params.Tags != nil {
    85  		return nil, errors.New("cannot use both DeviceIDs and Tags. only one of them should be not nil")
    86  	}
    87  
    88  	// Generate .ota file
    89  	logrus.Infoln("Uploading binary", params.File)
    90  	_, err := os.Stat(params.File)
    91  	if err != nil {
    92  		return nil, fmt.Errorf("file %s does not exists: %w", params.File, err)
    93  	}
    94  
    95  	if !params.DoNotApplyHeader {
    96  		//Verify if file has already an OTA header
    97  		header, _ := ota.DecodeOtaFirmwareHeaderFromFile(params.File)
    98  		if header != nil {
    99  			params.DoNotApplyHeader = true
   100  		}
   101  	}
   102  
   103  	// Generate .ota file
   104  	otaFile, otaDir, err := buildOtaFile(params)
   105  	if err != nil {
   106  		return nil, err
   107  	}
   108  	if otaDir != "" {
   109  		defer os.RemoveAll(otaDir)
   110  	}
   111  
   112  	iotClient, err := iot.NewClient(cred)
   113  	if err != nil {
   114  		return nil, err
   115  	}
   116  	otapi := otaapi.NewClient(cred)
   117  
   118  	// Prepare the list of device-ids to update
   119  	d, err := idsGivenTags(ctx, iotClient, params.Tags)
   120  	if err != nil {
   121  		return nil, err
   122  	}
   123  	d = append(params.DeviceIDs, d...)
   124  	valid, invalid, err := validateDevices(ctx, iotClient, d, params.FQBN)
   125  	if err != nil {
   126  		return nil, fmt.Errorf("failed to validate devices: %w", err)
   127  	}
   128  	if len(valid) == 0 {
   129  		return invalid, nil
   130  	}
   131  
   132  	expiration := otaExpirationMins
   133  	if params.Deferred {
   134  		expiration = otaDeferredExpirationMins
   135  	}
   136  
   137  	res := run(ctx, iotClient, otapi, valid, otaFile, expiration)
   138  	res = append(res, invalid...)
   139  	return res, nil
   140  }
   141  
// deviceLister lists the devices available on the cloud.
// In this file DeviceList is called both with a tag filter and with nil tags;
// presumably nil tags means "no filtering" — verify against the implementation.
type deviceLister interface {
	DeviceList(ctx context.Context, tags map[string]string) ([]iotclient.ArduinoDevicev2, error)
}
   145  
   146  func idsGivenTags(ctx context.Context, lister deviceLister, tags map[string]string) ([]string, error) {
   147  	if tags == nil {
   148  		return nil, nil
   149  	}
   150  	devs, err := lister.DeviceList(ctx, tags)
   151  	if err != nil {
   152  		return nil, fmt.Errorf("%s: %w", "cannot retrieve devices from cloud", err)
   153  	}
   154  	devices := make([]string, 0, len(devs))
   155  	for _, d := range devs {
   156  		devices = append(devices, d.Id)
   157  	}
   158  	return devices, nil
   159  }
   160  
   161  func validateDevices(ctx context.Context, lister deviceLister, ids []string, fqbn string) (valid []string, invalid []Result, err error) {
   162  	devs, err := lister.DeviceList(ctx, nil)
   163  	if err != nil {
   164  		return nil, nil, fmt.Errorf("%s: %w", "cannot retrieve devices from cloud", err)
   165  	}
   166  
   167  	for _, id := range ids {
   168  		var found *iotclient.ArduinoDevicev2
   169  		for _, d := range devs {
   170  			if d.Id == id {
   171  				found = &d
   172  				break
   173  			}
   174  		}
   175  		// Device not found on the cloud
   176  		if found == nil {
   177  			inv := Result{ID: id, Err: fmt.Errorf("not found")}
   178  			invalid = append(invalid, inv)
   179  			continue
   180  		}
   181  		// Device FQBN doesn't match the passed one
   182  		if found.Fqbn != fqbn {
   183  			inv := Result{ID: id, Err: fmt.Errorf("has FQBN '%s' instead of '%s'", found.Fqbn, fqbn)}
   184  			invalid = append(invalid, inv)
   185  			continue
   186  		}
   187  		valid = append(valid, id)
   188  	}
   189  	return valid, invalid, nil
   190  }
   191  
// otaUploader starts an OTA on a single device.
// run calls DeviceOTA with the device ID, the opened OTA file and the
// expiration it was given (expireMins is presumably expressed in minutes —
// confirm against the implementation).
type otaUploader interface {
	DeviceOTA(ctx context.Context, id string, file *os.File, expireMins int) error
}
   195  
// otaStatusGetter retrieves OTA status information for a device.
// run uses the first entry of the returned list as the OtaStatus of a Result.
type otaStatusGetter interface {
	GetOtaLastStatusByDeviceID(deviceID string) (*otaapi.OtaStatusList, error)
}
   199  
   200  func run(ctx context.Context, uploader otaUploader, otapi otaStatusGetter, ids []string, otaFile string, expiration int) []Result {
   201  	type job struct {
   202  		id   string
   203  		file *os.File
   204  	}
   205  	jobs := make(chan job, len(ids))
   206  
   207  	resCh := make(chan Result, len(ids))
   208  	results := make([]Result, 0, len(ids))
   209  
   210  	for _, id := range ids {
   211  		file, err := os.Open(otaFile)
   212  		if err != nil {
   213  			logrus.Error("cannot open ota file:", otaFile)
   214  			r := Result{ID: id, Err: fmt.Errorf("cannot open ota file")}
   215  			results = append(results, r)
   216  			continue
   217  		}
   218  		defer file.Close()
   219  		jobs <- job{id: id, file: file}
   220  	}
   221  	close(jobs)
   222  
   223  	logrus.Infoln("Uploading firmware to devices...")
   224  	for i := 0; i < numConcurrentUploads; i++ {
   225  		go func() {
   226  			for job := range jobs {
   227  				err := uploader.DeviceOTA(ctx, job.id, job.file, expiration)
   228  				otaResult := Result{ID: job.id, Err: err}
   229  
   230  				otaID, otaapierr := otapi.GetOtaLastStatusByDeviceID(job.id)
   231  				if otaapierr == nil && otaID != nil && len(otaID.Ota) > 0 {
   232  					otaResult.OtaStatus = otaID.Ota[0]
   233  				}
   234  
   235  				resCh <- otaResult
   236  			}
   237  		}()
   238  	}
   239  
   240  	for range ids {
   241  		r := <-resCh
   242  		results = append(results, r)
   243  	}
   244  	return results
   245  }