git.sr.ht/~pingoo/stdx@v0.0.0-20240218134121-094174641f6e/selfupdate/update.go (about)

     1  package selfupdate
     2  
     3  import (
     4  	"archive/tar"
     5  	"archive/zip"
     6  	"bytes"
     7  	"compress/gzip"
     8  	"context"
     9  	"encoding/hex"
    10  	"encoding/json"
    11  	"errors"
    12  	"fmt"
    13  	"io"
    14  	"log/slog"
    15  	"net/http"
    16  	"os"
    17  	"path/filepath"
    18  	"runtime"
    19  	"strings"
    20  
    21  	"git.sr.ht/~pingoo/stdx/httpx"
    22  	"git.sr.ht/~pingoo/stdx/log/slogx"
    23  	"git.sr.ht/~pingoo/stdx/semver"
    24  	"git.sr.ht/~pingoo/stdx/zign"
    25  )
    26  
    27  func (updater *Updater) CheckUpdate(ctx context.Context) (manifest ChannelManifest, err error) {
    28  	logger := slogx.FromCtx(ctx)
    29  
    30  	manifestUrl := fmt.Sprintf("%s/%s.json", updater.baseUrl, updater.releaseChannel)
    31  
    32  	logger.Debug("selfupdate.CheckUpdate: fetching channel manifest", slog.String("url", manifestUrl))
    33  
    34  	req, err := http.NewRequestWithContext(ctx, http.MethodGet, manifestUrl, nil)
    35  	if err != nil {
    36  		err = fmt.Errorf("selfupdate.CheckUpdate: creating channel manifest HTTP request: %w", err)
    37  		return
    38  	}
    39  
    40  	req.Header.Add(httpx.HeaderAccept, httpx.MediaTypeJson)
    41  	req.Header.Add(httpx.HeaderUserAgent, updater.userAgent)
    42  
    43  	res, err := updater.httpClient.Do(req)
    44  	if err != nil {
    45  		err = fmt.Errorf("selfupdate.CheckUpdate: fetching channel manifest: %w", err)
    46  		return
    47  	}
    48  	defer res.Body.Close()
    49  
    50  	if res.StatusCode != http.StatusOK {
    51  		err = fmt.Errorf("selfupdate.CheckUpdate: Status code is not 200 when fetching channel manifest: %d", res.StatusCode)
    52  		return
    53  
    54  	}
    55  
    56  	mainfestData, err := io.ReadAll(res.Body)
    57  	if err != nil {
    58  		err = fmt.Errorf("selfupdate.CheckUpdate: Reading manifest response: %d", res.StatusCode)
    59  		return
    60  	}
    61  
    62  	err = json.Unmarshal(mainfestData, &manifest)
    63  	if err != nil {
    64  		err = fmt.Errorf("selfupdate.CheckUpdate: parsing manifest: %w", err)
    65  		return
    66  	}
    67  
    68  	if !semver.IsValid(manifest.Version) {
    69  		err = fmt.Errorf("selfupdate.CheckUpdate: version (%s) is not a valid semantic version string", manifest.Version)
    70  		return
    71  	}
    72  
    73  	updater.latestVersionAvailable = manifest.Version
    74  
    75  	return
    76  }
    77  
    78  func (updater *Updater) Update(ctx context.Context, channelManifest ChannelManifest) (err error) {
    79  	updater.updateInProgress.Lock()
    80  	defer updater.updateInProgress.Unlock()
    81  
    82  	zignManifest, err := updater.fetchZignManifest(ctx, channelManifest)
    83  	if err != nil {
    84  		return
    85  	}
    86  
    87  	tmpDir, err := os.MkdirTemp("", channelManifest.Name+"_autoupdate_"+channelManifest.Version)
    88  	if err != nil {
    89  		err = fmt.Errorf("selfupdate: creating temporary directory: %w", err)
    90  		return
    91  	}
    92  	destPath := filepath.Join(tmpDir, channelManifest.Name)
    93  
    94  	platform := runtime.GOOS + "_" + runtime.GOARCH
    95  
    96  	artifactExists := false
    97  	var artifactToDownload zign.SignOutput
    98  	for _, artifact := range zignManifest.Files {
    99  		if strings.Contains(artifact.Filename, platform) {
   100  			artifactExists = true
   101  			artifactToDownload = artifact
   102  		}
   103  	}
   104  	if !artifactExists {
   105  		err = fmt.Errorf("selfupdate: No file found for platform: %s", platform)
   106  		return
   107  	}
   108  
   109  	artifactUrl := updater.baseUrl + "/" + channelManifest.Version + "/" + artifactToDownload.Filename
   110  
   111  	res, err := updater.httpClient.Get(artifactUrl)
   112  	if err != nil {
   113  		err = fmt.Errorf("selfupdate: fetching artifact: %w", err)
   114  		return
   115  	}
   116  	defer res.Body.Close()
   117  
   118  	if res.StatusCode != http.StatusOK {
   119  		err = fmt.Errorf("selfupdate: Status code is not 200 when fetching artifact: %d", res.StatusCode)
   120  		return
   121  	}
   122  
   123  	artifactFile, err := io.ReadAll(res.Body)
   124  	if err != nil {
   125  		err = fmt.Errorf("selfupdate: reading artifact's response: %d", res.StatusCode)
   126  		return
   127  	}
   128  
   129  	artifactFileReader := bytes.NewReader(artifactFile)
   130  
   131  	hash, err := hex.DecodeString(artifactToDownload.HashBlake3)
   132  	if err != nil {
   133  		err = fmt.Errorf("selfupdate: decoding hash for file %s: %w", artifactToDownload.Filename, err)
   134  		return
   135  	}
   136  
   137  	zignVeryfiInput := zign.VerifyInput{
   138  		Reader:     artifactFileReader,
   139  		HashBlake3: hash,
   140  		Signature:  artifactToDownload.Signature,
   141  	}
   142  	err = zign.Verify(updater.zingPublicKey, zignVeryfiInput)
   143  	if err != nil {
   144  		err = fmt.Errorf("selfupdate: verifying signature: %w", err)
   145  		return
   146  	}
   147  
   148  	artifactFileReader.Seek(0, io.SeekStart)
   149  
   150  	// handle both .tar.gz and .zip artifacts
   151  	if strings.HasSuffix(artifactToDownload.Filename, ".tar.gz") {
   152  		err = updater.extractTarGzArchive(artifactFileReader, destPath)
   153  	} else if strings.HasSuffix(artifactToDownload.Filename, ".zip") {
   154  		err = updater.extractZipArchive(artifactFileReader, int64(artifactFileReader.Len()), destPath)
   155  	} else {
   156  		err = fmt.Errorf("selfupdate: unsupported archive format: %s", filepath.Ext(artifactToDownload.Filename))
   157  	}
   158  	if err != nil {
   159  		return
   160  	}
   161  
   162  	execPath, err := os.Executable()
   163  	if err != nil {
   164  		err = fmt.Errorf("selfupdate: getting current executable path: %w", err)
   165  		return
   166  	}
   167  
   168  	err = os.Rename(destPath, execPath)
   169  	if err != nil {
   170  		err = fmt.Errorf("selfupdate: moving update to executable path: %w", err)
   171  		return
   172  	}
   173  
   174  	updater.latestVersionInstalled = channelManifest.Version
   175  
   176  	_ = os.RemoveAll(tmpDir)
   177  
   178  	return
   179  }
   180  
   181  func (updater *Updater) extractTarGzArchive(dataReader io.Reader, destPath string) (err error) {
   182  	gzipReader, err := gzip.NewReader(dataReader)
   183  	if err != nil {
   184  		err = fmt.Errorf("selfupdate: creating gzip reader: %w", err)
   185  		return
   186  	}
   187  	defer gzipReader.Close()
   188  
   189  	tarReader := tar.NewReader(gzipReader)
   190  
   191  	fileToExtractHeader, err := tarReader.Next()
   192  	if fileToExtractHeader == nil || err == io.EOF {
   193  		err = errors.New("selfupdate: no file inside .tar.gz archive")
   194  		return
   195  	} else if err != nil {
   196  		err = fmt.Errorf("selfupdate: reading .tar.gz archive: %w", err)
   197  		return
   198  	}
   199  
   200  	if fileToExtractHeader.Typeflag != tar.TypeReg {
   201  		err = fmt.Errorf("selfupdate: reading .tar.gz archive: %s is not a regular file", fileToExtractHeader.Name)
   202  		return
   203  	}
   204  
   205  	updatedExecutable, err := os.OpenFile(destPath, updatedExecutableOpenFlags, fileToExtractHeader.FileInfo().Mode())
   206  	if err != nil {
   207  		err = fmt.Errorf("selfupdate: creating dest file (%s): %w", destPath, err)
   208  		return
   209  	}
   210  	defer updatedExecutable.Close()
   211  
   212  	_, err = io.Copy(updatedExecutable, tarReader)
   213  	if err != nil {
   214  		err = fmt.Errorf("selfupdate: extracting .tar.gzipped file (%s): %w", fileToExtractHeader.Name, err)
   215  		return
   216  	}
   217  
   218  	return
   219  }
   220  
   221  func (updater *Updater) extractZipArchive(dataReader io.ReaderAt, dataLen int64, destPath string) (err error) {
   222  	zipReader, err := zip.NewReader(dataReader, dataLen)
   223  	if err != nil {
   224  		err = fmt.Errorf("selfupdate: creating zip reader: %w", err)
   225  		return
   226  	}
   227  
   228  	zippedFiles := zipReader.File
   229  	if len(zippedFiles) != 1 {
   230  		err = fmt.Errorf("selfupdate: zip archive contains more than 1 file (%d)", len(zippedFiles))
   231  		return
   232  	}
   233  
   234  	zippedFileToExtract := zippedFiles[0]
   235  
   236  	srcFile, err := zippedFileToExtract.Open()
   237  	if err != nil {
   238  		err = fmt.Errorf("selfupdate: Opening zipped file (%s): %w", zippedFileToExtract.Name, err)
   239  		return
   240  	}
   241  	defer srcFile.Close()
   242  
   243  	updatedExecutable, err := os.OpenFile(destPath, updatedExecutableOpenFlags, zippedFileToExtract.Mode())
   244  	if err != nil {
   245  		err = fmt.Errorf("selfupdate: creating dest file (%s): %w", destPath, err)
   246  		return
   247  	}
   248  	defer updatedExecutable.Close()
   249  
   250  	_, err = io.Copy(updatedExecutable, srcFile)
   251  	if err != nil {
   252  		err = fmt.Errorf("selfupdate: extracting zipped file (%s): %w", zippedFileToExtract.Name, err)
   253  		return
   254  	}
   255  
   256  	return
   257  }
   258  
   259  func (updater *Updater) fetchZignManifest(ctx context.Context, channelManifest ChannelManifest) (zignManifest zign.Manifest, err error) {
   260  	logger := slogx.FromCtx(ctx)
   261  
   262  	zignManifestUrl := fmt.Sprintf("%s/%s/zign.json", updater.baseUrl, channelManifest.Version)
   263  
   264  	logger.Debug("selfupdate.fetchZignManifest: fetching zign manifest", slog.String("url", zignManifestUrl))
   265  
   266  	req, err := http.NewRequestWithContext(ctx, http.MethodGet, zignManifestUrl, nil)
   267  	if err != nil {
   268  		err = fmt.Errorf("selfupdate.fetchZignManifest: creating zign manifest HTTP request: %w", err)
   269  		return
   270  	}
   271  
   272  	req.Header.Add(httpx.HeaderAccept, httpx.MediaTypeJson)
   273  	req.Header.Add(httpx.HeaderUserAgent, updater.userAgent)
   274  
   275  	res, err := updater.httpClient.Do(req)
   276  	if err != nil {
   277  		err = fmt.Errorf("selfupdate.fetchZignManifest: fetching zign manifest: %w", err)
   278  		return
   279  	}
   280  	defer res.Body.Close()
   281  
   282  	if res.StatusCode != http.StatusOK {
   283  		err = fmt.Errorf("selfupdate.fetchZignManifest: Status code is not 200 when fetching zign manifest: %d", res.StatusCode)
   284  		return
   285  
   286  	}
   287  
   288  	mainfestData, err := io.ReadAll(res.Body)
   289  	if err != nil {
   290  		err = fmt.Errorf("selfupdate.fetchZignManifest: Reading manifest response: %d", res.StatusCode)
   291  		return
   292  	}
   293  
   294  	err = json.Unmarshal(mainfestData, &zignManifest)
   295  	if err != nil {
   296  		err = fmt.Errorf("selfupdate.fetchZignManifest: parsing zign manifest: %w", err)
   297  		return
   298  	}
   299  	return
   300  }