golang.zx2c4.com/wireguard/windows@v0.5.4-0.20230123132234-dcc0eb72a04b/updater/downloader.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package updater 7 8 import ( 9 "crypto/hmac" 10 "errors" 11 "fmt" 12 "hash" 13 "io" 14 "sync/atomic" 15 16 "golang.org/x/crypto/blake2b" 17 18 "golang.zx2c4.com/wireguard/windows/elevate" 19 "golang.zx2c4.com/wireguard/windows/updater/winhttp" 20 "golang.zx2c4.com/wireguard/windows/version" 21 ) 22 23 type DownloadProgress struct { 24 Activity string 25 BytesDownloaded uint64 26 BytesTotal uint64 27 Error error 28 Complete bool 29 } 30 31 type progressHashWatcher struct { 32 dp *DownloadProgress 33 c chan DownloadProgress 34 hashState hash.Hash 35 } 36 37 func (pm *progressHashWatcher) Write(p []byte) (int, error) { 38 bytes := len(p) 39 pm.dp.BytesDownloaded += uint64(bytes) 40 pm.c <- *pm.dp 41 pm.hashState.Write(p) 42 return bytes, nil 43 } 44 45 type UpdateFound struct { 46 name string 47 hash [blake2b.Size256]byte 48 } 49 50 func CheckForUpdate() (updateFound *UpdateFound, err error) { 51 updateFound, _, _, err = checkForUpdate(false) 52 return 53 } 54 55 func checkForUpdate(keepSession bool) (*UpdateFound, *winhttp.Session, *winhttp.Connection, error) { 56 if !version.IsRunningOfficialVersion() { 57 return nil, nil, nil, errors.New("Build is not official, so updates are disabled") 58 } 59 session, err := winhttp.NewSession(version.UserAgent()) 60 if err != nil { 61 return nil, nil, nil, err 62 } 63 defer func() { 64 if err != nil || !keepSession { 65 session.Close() 66 } 67 }() 68 connection, err := session.Connect(updateServerHost, updateServerPort, updateServerUseHttps) 69 if err != nil { 70 return nil, nil, nil, err 71 } 72 defer func() { 73 if err != nil || !keepSession { 74 connection.Close() 75 } 76 }() 77 response, err := connection.Get(latestVersionPath, true) 78 if err != nil { 79 return nil, nil, nil, err 80 } 81 defer response.Close() 82 var fileList [1024 * 512] /* 512 KiB */ byte 83 bytesRead, err := response.Read(fileList[:]) 84 if err != nil && (err != io.EOF || bytesRead == 0) { 85 return nil, nil, nil, err 86 } 87 files, err := readFileList(fileList[:bytesRead]) 88 if err != nil { 89 return nil, nil, nil, err 90 } 91 updateFound, err := findCandidate(files) 92 if err != nil { 93 return nil, nil, nil, err 94 } 95 if keepSession { 96 return updateFound, session, connection, nil 97 } 98 return updateFound, nil, nil, nil 99 } 100 101 var updateInProgress = uint32(0) 102 103 func DownloadVerifyAndExecute(userToken uintptr) (progress chan DownloadProgress) { 104 progress = make(chan DownloadProgress, 128) 105 progress <- DownloadProgress{Activity: "Initializing"} 106 107 if !atomic.CompareAndSwapUint32(&updateInProgress, 0, 1) { 108 progress <- DownloadProgress{Error: errors.New("An update is already in progress")} 109 return 110 } 111 112 doIt := func() { 113 defer atomic.StoreUint32(&updateInProgress, 0) 114 115 progress <- DownloadProgress{Activity: "Checking for update"} 116 update, session, connection, err := checkForUpdate(true) 117 if err != nil { 118 progress <- DownloadProgress{Error: err} 119 return 120 } 121 defer connection.Close() 122 defer session.Close() 123 if update == nil { 124 progress <- DownloadProgress{Error: errors.New("No update was found")} 125 return 126 } 127 128 progress <- DownloadProgress{Activity: "Creating temporary file"} 129 file, err := msiTempFile() 130 if err != nil { 131 progress <- DownloadProgress{Error: err} 132 return 133 } 134 progress <- DownloadProgress{Activity: fmt.Sprintf("Msi destination is %#q", file.Name())} 135 defer func() { 136 if file != nil { 137 file.Delete() 138 } 139 }() 140 141 dp := DownloadProgress{Activity: "Downloading update"} 142 progress <- dp 143 response, err := connection.Get(fmt.Sprintf(msiPath, update.name), false) 144 if err != nil { 145 progress <- DownloadProgress{Error: err} 146 return 147 } 148 defer response.Close() 149 length, err := response.Length() 150 if err == nil && length >= 0 { 151 dp.BytesTotal = length 152 progress <- dp 153 } 154 hasher, err := blake2b.New256(nil) 155 if err != nil { 156 progress <- DownloadProgress{Error: err} 157 return 158 } 159 pm := &progressHashWatcher{&dp, progress, hasher} 160 _, err = io.Copy(file, io.TeeReader(io.LimitReader(response, 1024*1024*100 /* 100 MiB */), pm)) 161 if err != nil { 162 progress <- DownloadProgress{Error: err} 163 return 164 } 165 if !hmac.Equal(hasher.Sum(nil), update.hash[:]) { 166 progress <- DownloadProgress{Error: errors.New("The downloaded update has the wrong hash")} 167 return 168 } 169 170 progress <- DownloadProgress{Activity: "Verifying authenticode signature"} 171 if !verifyAuthenticode(file.ExclusivePath()) { 172 progress <- DownloadProgress{Error: errors.New("The downloaded update does not have an authentic authenticode signature")} 173 return 174 } 175 176 progress <- DownloadProgress{Activity: "Installing update"} 177 err = runMsi(file, userToken) 178 if err != nil { 179 progress <- DownloadProgress{Error: err} 180 return 181 } 182 183 progress <- DownloadProgress{Complete: true} 184 } 185 if userToken == 0 { 186 go func() { 187 err := elevate.DoAsSystem(func() error { 188 doIt() 189 return nil 190 }) 191 if err != nil { 192 progress <- DownloadProgress{Error: err} 193 } 194 }() 195 } else { 196 go doIt() 197 } 198 199 return progress 200 }