golang.zx2c4.com/wireguard/windows@v0.5.4-0.20230123132234-dcc0eb72a04b/updater/downloader.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package updater
     7  
     8  import (
     9  	"crypto/hmac"
    10  	"errors"
    11  	"fmt"
    12  	"hash"
    13  	"io"
    14  	"sync/atomic"
    15  
    16  	"golang.org/x/crypto/blake2b"
    17  
    18  	"golang.zx2c4.com/wireguard/windows/elevate"
    19  	"golang.zx2c4.com/wireguard/windows/updater/winhttp"
    20  	"golang.zx2c4.com/wireguard/windows/version"
    21  )
    22  
    23  type DownloadProgress struct {
    24  	Activity        string
    25  	BytesDownloaded uint64
    26  	BytesTotal      uint64
    27  	Error           error
    28  	Complete        bool
    29  }
    30  
// progressHashWatcher is the writer side of the io.TeeReader used while the
// update is downloaded. Every chunk written to it adds to dp.BytesDownloaded,
// sends a copy of *dp on c, and is fed into hashState so the final digest
// covers every byte that was copied.
//
// NOTE: DownloadVerifyAndExecute builds this with a positional composite
// literal ({&dp, progress, hasher}), so the field order is significant.
type progressHashWatcher struct {
	dp        *DownloadProgress
	c         chan DownloadProgress
	hashState hash.Hash
}
    36  
    37  func (pm *progressHashWatcher) Write(p []byte) (int, error) {
    38  	bytes := len(p)
    39  	pm.dp.BytesDownloaded += uint64(bytes)
    40  	pm.c <- *pm.dp
    41  	pm.hashState.Write(p)
    42  	return bytes, nil
    43  }
    44  
// UpdateFound is the update candidate returned by findCandidate.
//
// name is substituted into msiPath to form the download path, and hash is the
// expected BLAKE2b-256 digest of the downloaded file. DownloadVerifyAndExecute
// compares the digest with hmac.Equal before installing.
type UpdateFound struct {
	name string
	hash [blake2b.Size256]byte
}
    49  
    50  func CheckForUpdate() (updateFound *UpdateFound, err error) {
    51  	updateFound, _, _, err = checkForUpdate(false)
    52  	return
    53  }
    54  
    55  func checkForUpdate(keepSession bool) (*UpdateFound, *winhttp.Session, *winhttp.Connection, error) {
    56  	if !version.IsRunningOfficialVersion() {
    57  		return nil, nil, nil, errors.New("Build is not official, so updates are disabled")
    58  	}
    59  	session, err := winhttp.NewSession(version.UserAgent())
    60  	if err != nil {
    61  		return nil, nil, nil, err
    62  	}
    63  	defer func() {
    64  		if err != nil || !keepSession {
    65  			session.Close()
    66  		}
    67  	}()
    68  	connection, err := session.Connect(updateServerHost, updateServerPort, updateServerUseHttps)
    69  	if err != nil {
    70  		return nil, nil, nil, err
    71  	}
    72  	defer func() {
    73  		if err != nil || !keepSession {
    74  			connection.Close()
    75  		}
    76  	}()
    77  	response, err := connection.Get(latestVersionPath, true)
    78  	if err != nil {
    79  		return nil, nil, nil, err
    80  	}
    81  	defer response.Close()
    82  	var fileList [1024 * 512] /* 512 KiB */ byte
    83  	bytesRead, err := response.Read(fileList[:])
    84  	if err != nil && (err != io.EOF || bytesRead == 0) {
    85  		return nil, nil, nil, err
    86  	}
    87  	files, err := readFileList(fileList[:bytesRead])
    88  	if err != nil {
    89  		return nil, nil, nil, err
    90  	}
    91  	updateFound, err := findCandidate(files)
    92  	if err != nil {
    93  		return nil, nil, nil, err
    94  	}
    95  	if keepSession {
    96  		return updateFound, session, connection, nil
    97  	}
    98  	return updateFound, nil, nil, nil
    99  }
   100  
   101  var updateInProgress = uint32(0)
   102  
   103  func DownloadVerifyAndExecute(userToken uintptr) (progress chan DownloadProgress) {
   104  	progress = make(chan DownloadProgress, 128)
   105  	progress <- DownloadProgress{Activity: "Initializing"}
   106  
   107  	if !atomic.CompareAndSwapUint32(&updateInProgress, 0, 1) {
   108  		progress <- DownloadProgress{Error: errors.New("An update is already in progress")}
   109  		return
   110  	}
   111  
   112  	doIt := func() {
   113  		defer atomic.StoreUint32(&updateInProgress, 0)
   114  
   115  		progress <- DownloadProgress{Activity: "Checking for update"}
   116  		update, session, connection, err := checkForUpdate(true)
   117  		if err != nil {
   118  			progress <- DownloadProgress{Error: err}
   119  			return
   120  		}
   121  		defer connection.Close()
   122  		defer session.Close()
   123  		if update == nil {
   124  			progress <- DownloadProgress{Error: errors.New("No update was found")}
   125  			return
   126  		}
   127  
   128  		progress <- DownloadProgress{Activity: "Creating temporary file"}
   129  		file, err := msiTempFile()
   130  		if err != nil {
   131  			progress <- DownloadProgress{Error: err}
   132  			return
   133  		}
   134  		progress <- DownloadProgress{Activity: fmt.Sprintf("Msi destination is %#q", file.Name())}
   135  		defer func() {
   136  			if file != nil {
   137  				file.Delete()
   138  			}
   139  		}()
   140  
   141  		dp := DownloadProgress{Activity: "Downloading update"}
   142  		progress <- dp
   143  		response, err := connection.Get(fmt.Sprintf(msiPath, update.name), false)
   144  		if err != nil {
   145  			progress <- DownloadProgress{Error: err}
   146  			return
   147  		}
   148  		defer response.Close()
   149  		length, err := response.Length()
   150  		if err == nil && length >= 0 {
   151  			dp.BytesTotal = length
   152  			progress <- dp
   153  		}
   154  		hasher, err := blake2b.New256(nil)
   155  		if err != nil {
   156  			progress <- DownloadProgress{Error: err}
   157  			return
   158  		}
   159  		pm := &progressHashWatcher{&dp, progress, hasher}
   160  		_, err = io.Copy(file, io.TeeReader(io.LimitReader(response, 1024*1024*100 /* 100 MiB */), pm))
   161  		if err != nil {
   162  			progress <- DownloadProgress{Error: err}
   163  			return
   164  		}
   165  		if !hmac.Equal(hasher.Sum(nil), update.hash[:]) {
   166  			progress <- DownloadProgress{Error: errors.New("The downloaded update has the wrong hash")}
   167  			return
   168  		}
   169  
   170  		progress <- DownloadProgress{Activity: "Verifying authenticode signature"}
   171  		if !verifyAuthenticode(file.ExclusivePath()) {
   172  			progress <- DownloadProgress{Error: errors.New("The downloaded update does not have an authentic authenticode signature")}
   173  			return
   174  		}
   175  
   176  		progress <- DownloadProgress{Activity: "Installing update"}
   177  		err = runMsi(file, userToken)
   178  		if err != nil {
   179  			progress <- DownloadProgress{Error: err}
   180  			return
   181  		}
   182  
   183  		progress <- DownloadProgress{Complete: true}
   184  	}
   185  	if userToken == 0 {
   186  		go func() {
   187  			err := elevate.DoAsSystem(func() error {
   188  				doIt()
   189  				return nil
   190  			})
   191  			if err != nil {
   192  				progress <- DownloadProgress{Error: err}
   193  			}
   194  		}()
   195  	} else {
   196  		go doIt()
   197  	}
   198  
   199  	return progress
   200  }