github.com/arduino/arduino-cloud-cli@v0.0.0-20240517070944-e7a449561083/internal/ota-api/client.go (about)

     1  // This file is part of arduino-cloud-cli.
     2  //
     3  // Copyright (C) 2021 ARDUINO SA (http://www.arduino.cc/)
     4  //
     5  // This program is free software: you can redistribute it and/or modify
     6  // it under the terms of the GNU Affero General Public License as published
     7  // by the Free Software Foundation, either version 3 of the License, or
     8  // (at your option) any later version.
     9  //
    10  // This program is distributed in the hope that it will be useful,
    11  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    12  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13  // GNU Affero General Public License for more details.
    14  //
    15  // You should have received a copy of the GNU Affero General Public License
    16  // along with this program.  If not, see <https://www.gnu.org/licenses/>.
    17  
    18  package otaapi
    19  
    20  import (
    21  	"encoding/json"
    22  	"errors"
    23  	"fmt"
    24  	"io"
    25  	"net/http"
    26  	"sort"
    27  	"strings"
    28  	"time"
    29  
    30  	"github.com/arduino/arduino-cloud-cli/config"
    31  	"github.com/arduino/arduino-cloud-cli/internal/iot"
    32  	"golang.org/x/oauth2"
    33  )
    34  
    // Sort orders understood by the OTA status queries in this package.
    const (
    	// OrderDesc requests descending (newest-first) ordering of OTA states.
    	OrderDesc = "desc"
    	// OrderAsc requests ascending (oldest-first) ordering of OTA states.
    	OrderAsc = "asc"
    )

    // Sentinel errors returned by OtaApiClient methods; compare with errors.Is.
    var (
    	// ErrAlreadyInProgress is returned by GetOtaStatusByDeviceID when the
    	// server answers HTTP 409.
    	ErrAlreadyInProgress = errors.New("already in progress")
    	// ErrAlreadyCancelled is returned by CancelOta when the server answers
    	// HTTP 409.
    	ErrAlreadyCancelled = errors.New("already cancelled")
    )
    42  
    43  type OtaApiClient struct {
    44  	client       *http.Client
    45  	host         string
    46  	src          oauth2.TokenSource
    47  	organization string
    48  }
    49  
    50  func NewClient(credentials *config.Credentials) *OtaApiClient {
    51  	host := iot.GetArduinoAPIBaseURL()
    52  	tokenSource := iot.NewUserTokenSource(credentials.Client, credentials.Secret, host)
    53  	return &OtaApiClient{
    54  		client:       &http.Client{},
    55  		src:          tokenSource,
    56  		host:         host,
    57  		organization: credentials.Organization,
    58  	}
    59  }
    60  
    61  func (c *OtaApiClient) performGetRequest(endpoint, token string) (*http.Response, error) {
    62  	return c.performRequest(endpoint, "GET", token)
    63  }
    64  
    65  func (c *OtaApiClient) performRequest(endpoint, method, token string) (*http.Response, error) {
    66  	req, err := http.NewRequest(method, endpoint, nil)
    67  	if err != nil {
    68  		return nil, err
    69  	}
    70  	req.Header.Add("Authorization", "Bearer "+token)
    71  	req.Header.Add("Content-Type", "application/json")
    72  	if c.organization != "" {
    73  		req.Header.Add("X-Organization", c.organization)
    74  	}
    75  	res, err := c.client.Do(req)
    76  	if err != nil {
    77  		return nil, err
    78  	}
    79  	return res, nil
    80  }
    81  
    82  func (c *OtaApiClient) GetOtaStatusByOtaID(otaid string, limit int, order string) (*OtaStatusResponse, error) {
    83  
    84  	if otaid == "" {
    85  		return nil, fmt.Errorf("invalid ota-id: empty")
    86  	}
    87  
    88  	userRequestToken, err := c.src.Token()
    89  	if err != nil {
    90  		if strings.Contains(err.Error(), "401") {
    91  			return nil, errors.New("wrong credentials")
    92  		}
    93  		return nil, fmt.Errorf("cannot retrieve a valid token: %w", err)
    94  	}
    95  
    96  	endpoint := c.host + "/ota/v1/ota/" + otaid
    97  	res, err := c.performGetRequest(endpoint, userRequestToken.AccessToken)
    98  	if err != nil {
    99  		return nil, err
   100  	}
   101  	defer res.Body.Close()
   102  	bodyb, err := io.ReadAll(res.Body)
   103  
   104  	if res.StatusCode == 200 {
   105  		var otaResponse OtaStatusResponse
   106  		if err == nil && bodyb != nil {
   107  			err = json.Unmarshal(bodyb, &otaResponse)
   108  			if err != nil {
   109  				return nil, err
   110  			}
   111  		}
   112  
   113  		if len(otaResponse.States) > 0 {
   114  			// Sort output by StartedAt
   115  			sort.Slice(otaResponse.States, func(i, j int) bool {
   116  				t1, err := time.Parse(time.RFC3339, otaResponse.States[i].Timestamp)
   117  				if err != nil {
   118  					return false
   119  				}
   120  				t2, err := time.Parse(time.RFC3339, otaResponse.States[j].Timestamp)
   121  				if err != nil {
   122  					return false
   123  				}
   124  				if order == "asc" {
   125  					return t1.Before(t2)
   126  				}
   127  				return t1.After(t2)
   128  			})
   129  			if limit > 0 && len(otaResponse.States) > limit {
   130  				otaResponse.States = otaResponse.States[:limit]
   131  			}
   132  		}
   133  
   134  		return &otaResponse, nil
   135  	} else if res.StatusCode == 404 || res.StatusCode == 400 {
   136  		return nil, fmt.Errorf("ota-id %s not found", otaid)
   137  	}
   138  
   139  	return nil, err
   140  }
   141  
   142  func (c *OtaApiClient) GetOtaStatusByOtaIDs(otaids string) (*OtaStatusList, error) {
   143  
   144  	ids := strings.Split(otaids, ",")
   145  	if len(ids) == 0 {
   146  		return nil, fmt.Errorf("invalid ota-ids: empty")
   147  	}
   148  
   149  	returnStatus := OtaStatusList{}
   150  	for _, id := range ids {
   151  		if id != "" {
   152  			resp, err := c.GetOtaStatusByOtaID(id, 1, OrderDesc)
   153  			if err != nil {
   154  				return nil, err
   155  			}
   156  			returnStatus.Ota = append(returnStatus.Ota, resp.Ota)
   157  		}
   158  
   159  	}
   160  
   161  	return &returnStatus, nil
   162  }
   163  
   164  func (c *OtaApiClient) GetOtaLastStatusByDeviceID(deviceID string) (*OtaStatusList, error) {
   165  	return c.GetOtaStatusByDeviceID(deviceID, 1, OrderDesc)
   166  }
   167  
   168  func (c *OtaApiClient) GetOtaStatusByDeviceID(deviceID string, limit int, order string) (*OtaStatusList, error) {
   169  
   170  	if deviceID == "" {
   171  		return nil, fmt.Errorf("invalid device-id: empty")
   172  	}
   173  
   174  	userRequestToken, err := c.src.Token()
   175  	if err != nil {
   176  		if strings.Contains(err.Error(), "401") {
   177  			return nil, errors.New("wrong credentials")
   178  		}
   179  		return nil, fmt.Errorf("cannot retrieve a valid token: %w", err)
   180  	}
   181  
   182  	endpoint := c.host + "/ota/v1/ota?device_id=" + deviceID
   183  	if limit > 0 {
   184  		endpoint += "&limit=" + fmt.Sprintf("%d", limit)
   185  	}
   186  	if order != "" && (order == "asc" || order == "desc") {
   187  		endpoint += "&order=" + order
   188  	}
   189  	res, err := c.performGetRequest(endpoint, userRequestToken.AccessToken)
   190  	if err != nil {
   191  		return nil, err
   192  	}
   193  	defer res.Body.Close()
   194  	bodyb, err := io.ReadAll(res.Body)
   195  
   196  	if res.StatusCode == 200 {
   197  		var otaResponse OtaStatusList
   198  		if err == nil && bodyb != nil {
   199  			err = json.Unmarshal(bodyb, &otaResponse)
   200  			if err != nil {
   201  				return nil, err
   202  			}
   203  		}
   204  		return &otaResponse, nil
   205  	} else if res.StatusCode == 404 || res.StatusCode == 400 {
   206  		return nil, fmt.Errorf("device-id %s not found", deviceID)
   207  	} else if res.StatusCode == 409 {
   208  		return nil, ErrAlreadyInProgress
   209  	}
   210  
   211  	return nil, err
   212  }
   213  
   214  func (c *OtaApiClient) CancelOta(otaid string) (bool, error) {
   215  
   216  	if otaid == "" {
   217  		return false, fmt.Errorf("invalid ota-id: empty")
   218  	}
   219  
   220  	userRequestToken, err := c.src.Token()
   221  	if err != nil {
   222  		if strings.Contains(err.Error(), "401") {
   223  			return false, errors.New("wrong credentials")
   224  		}
   225  		return false, fmt.Errorf("cannot retrieve a valid token: %w", err)
   226  	}
   227  
   228  	endpoint := c.host + "/ota/v1/ota/" + otaid + "/cancel"
   229  	res, err := c.performRequest(endpoint, "PUT", userRequestToken.AccessToken)
   230  	if err != nil {
   231  		return false, err
   232  	}
   233  	defer res.Body.Close()
   234  
   235  	if res.StatusCode == 200 {
   236  		return true, nil
   237  	} else if res.StatusCode == 404 || res.StatusCode == 400 {
   238  		return false, fmt.Errorf("ota-id %s not found", otaid)
   239  	} else if res.StatusCode == 409 {
   240  		return false, ErrAlreadyCancelled
   241  	}
   242  
   243  	return false, err
   244  }