golang.zx2c4.com/wireguard/windows@v0.5.4-0.20230123132234-dcc0eb72a04b/installer/fetcher/fetcher.c (about) 1 // SPDX-License-Identifier: GPL-2.0 2 /* 3 * Copyright (C) 2020-2022 Jason A. Donenfeld. All Rights Reserved. 4 */ 5 6 #include <windows.h> 7 #include <delayimp.h> 8 #include <commctrl.h> 9 #include <shlwapi.h> 10 #include <ntsecapi.h> 11 #include <sddl.h> 12 #include <winhttp.h> 13 #include <wintrust.h> 14 #include <softpub.h> 15 #include <msi.h> 16 #include <stdio.h> 17 #include <string.h> 18 #include <stdbool.h> 19 #include <wchar.h> 20 #include "filelist.h" 21 #include "crypto.h" 22 #include "systeminfo.h" 23 #include "constants.h" 24 25 static char msi_filename[MAX_PATH]; 26 static volatile bool msi_filename_is_set, prompts = true; 27 static volatile size_t g_current, g_total; 28 static HWND progress; 29 static HANDLE filehandle = INVALID_HANDLE_VALUE; 30 31 static wchar_t *L(const char *a) 32 { 33 static wchar_t w[0x2000]; 34 if (!MultiByteToWideChar(CP_UTF8, 0, a, -1, w, sizeof(w))) 35 abort(); 36 return w; 37 } 38 39 static bool random_string(char hex[static 65]) 40 { 41 uint8_t bytes[32]; 42 if (!RtlGenRandom(bytes, sizeof(bytes))) 43 return false; 44 for (int i = 0; i < 32; ++i) { 45 hex[i * 2] = 87U + (bytes[i] >> 4) + ((((bytes[i] >> 4) - 10U) >> 8) & ~38U); 46 hex[i * 2 + 1] = 87U + (bytes[i] & 0xf) + ((((bytes[i] & 0xf) - 10U) >> 8) & ~38U); 47 } 48 hex[64] = '\0'; 49 return true; 50 } 51 52 static void set_status(HWND progress, const char *status) 53 { 54 LONG_PTR current_style = GetWindowLongPtrA(progress, GWL_STYLE); 55 char buf[0x1000]; 56 g_total = 0; 57 _snprintf_s(buf, sizeof(buf), _TRUNCATE, "WireGuard: %s...", status); 58 SetWindowTextA(progress, buf); 59 if (!(current_style & PBS_MARQUEE)) { 60 SendMessageA(progress, PBM_SETRANGE32, 0, 100); 61 SendMessageA(progress, PBM_SETPOS, 0, 0); 62 SetWindowLongPtrA(progress, GWL_STYLE, current_style | PBS_MARQUEE); 63 SendMessageA(progress, PBM_SETMARQUEE, TRUE, 0); 64 } 65 } 66 67 static void set_progress(HWND progress, size_t current, size_t total) 68 { 69 g_current = current; 70 g_total = total; 71 PostMessageA(progress, WM_APP, 0, 0); 72 } 73 74 static DWORD __stdcall download_thread(void *param) 75 { 76 DWORD ret = 1, bytes_read, bytes_written, enable_http2 = WINHTTP_PROTOCOL_FLAG_HTTP2; 77 HINTERNET session = NULL, connection = NULL, request = NULL; 78 uint8_t hash[32], computed_hash[32], buf[512 * 1024]; 79 char download_path[MAX_FILENAME_LEN + sizeof(msi_path)], random_filename[65]; 80 wchar_t total_bytes_str[22]; 81 size_t total_bytes, current_bytes; 82 const char *arch; 83 struct blake2b256_state hasher; 84 SECURITY_ATTRIBUTES security_attributes = { .nLength = sizeof(security_attributes) }; 85 WINTRUST_FILE_INFO wintrust_fileinfo = { .cbStruct = sizeof(wintrust_fileinfo) }; 86 WINTRUST_DATA wintrust_data = { 87 .cbStruct = sizeof(wintrust_data), 88 .dwUIChoice = WTD_UI_NONE, 89 .fdwRevocationChecks = WTD_REVOKE_WHOLECHAIN, 90 .dwUnionChoice = WTD_CHOICE_FILE, 91 .dwStateAction = WTD_STATEACTION_VERIFY, 92 .pFile = &wintrust_fileinfo 93 }; 94 95 (void)param; 96 97 set_status(progress, "determining paths"); 98 if (!ConvertStringSecurityDescriptorToSecurityDescriptorA("O:BAD:PAI(A;;FA;;;BA)", SDDL_REVISION_1, &security_attributes.lpSecurityDescriptor, NULL)) 99 goto out; 100 if (!GetWindowsDirectoryA(msi_filename, sizeof(msi_filename)) || !PathAppendA(msi_filename, "Temp")) 101 goto out; 102 if (!random_string(random_filename)) 103 goto out; 104 if (!PathAppendA(msi_filename, random_filename)) 105 goto out; 106 107 set_status(progress, "determining architecture"); 108 arch = architecture(); 109 if (!arch) 110 goto out; 111 112 set_status(progress, "connecting to server"); 113 session = WinHttpOpen(L(useragent()), is_win7() ? WINHTTP_ACCESS_TYPE_DEFAULT_PROXY : WINHTTP_ACCESS_TYPE_AUTOMATIC_PROXY, NULL, NULL, 0); 114 if (!session) 115 goto out; 116 WinHttpSetOption(session, WINHTTP_OPTION_ENABLE_HTTP_PROTOCOL, &enable_http2, sizeof(enable_http2)); // Don't check return value, in case of old Windows 117 if (is_win8dotzero_or_below()) { 118 DWORD enable_tls12 = WINHTTP_FLAG_SECURE_PROTOCOL_TLS1_2; 119 if (!WinHttpSetOption(session, WINHTTP_OPTION_SECURE_PROTOCOLS, &enable_tls12, sizeof(enable_tls12))) 120 goto out; 121 } 122 123 connection = WinHttpConnect(session, L(server), port, 0); 124 if (!connection) 125 goto out; 126 127 set_status(progress, "downloading installer list"); 128 request = WinHttpOpenRequest(connection, L"GET", L(msi_path latest_version_file), NULL, WINHTTP_NO_REFERER, WINHTTP_DEFAULT_ACCEPT_TYPES, WINHTTP_FLAG_REFRESH | WINHTTP_FLAG_SECURE); 129 if (!request) 130 goto out; 131 if (!WinHttpSendRequest(request, WINHTTP_NO_ADDITIONAL_HEADERS, 0, WINHTTP_NO_REQUEST_DATA, 0, 0, 0)) 132 goto out; 133 if (!WinHttpReceiveResponse(request, NULL)) 134 goto out; 135 if (!WinHttpReadData(request, buf, sizeof(buf), &bytes_read)) 136 goto out; 137 WinHttpCloseHandle(request); 138 request = NULL; 139 if (bytes_read <= 0 || bytes_read >= sizeof(buf)) 140 goto out; 141 142 set_status(progress, "verifying installer list"); 143 memcpy(download_path, msi_path, strlen(msi_path)); 144 if (!extract_newest_file(download_path + strlen(msi_path), hash, (const char *)buf, bytes_read, arch)) 145 goto out; 146 147 set_status(progress, "creating temporary file"); 148 filehandle = CreateFileA(msi_filename, GENERIC_WRITE | DELETE, 0, &security_attributes, CREATE_NEW, FILE_ATTRIBUTE_TEMPORARY, NULL); 149 if (filehandle == INVALID_HANDLE_VALUE) 150 goto out; 151 msi_filename_is_set = true; 152 153 set_status(progress, "downloading installer"); 154 request = WinHttpOpenRequest(connection, L"GET", L(download_path), NULL, WINHTTP_NO_REFERER, WINHTTP_DEFAULT_ACCEPT_TYPES, WINHTTP_FLAG_SECURE); 155 if (!request) 156 goto out; 157 if (!WinHttpSendRequest(request, WINHTTP_NO_ADDITIONAL_HEADERS, 0, WINHTTP_NO_REQUEST_DATA, 0, 0, 0)) 158 goto out; 159 if (!WinHttpReceiveResponse(request, NULL)) 160 goto out; 161 bytes_read = sizeof(total_bytes_str); 162 if (!WinHttpQueryHeaders(request, WINHTTP_QUERY_CONTENT_LENGTH, WINHTTP_HEADER_NAME_BY_INDEX, total_bytes_str, &bytes_read, WINHTTP_NO_HEADER_INDEX)) 163 goto out; 164 total_bytes = wcstoul(total_bytes_str, NULL, 10); 165 if (total_bytes > 100 * 1024 * 1024) 166 goto out; 167 blake2b256_init(&hasher); 168 set_progress(progress, 0, total_bytes); 169 for (current_bytes = 0;;) { 170 if (!WinHttpReadData(request, buf, 8192, &bytes_read)) 171 goto out; 172 if (!bytes_read) 173 break; 174 current_bytes += bytes_read; 175 if (current_bytes > 100 * 1024 * 1024) 176 goto out; 177 blake2b256_update(&hasher, buf, bytes_read); 178 if (!WriteFile(filehandle, buf, bytes_read, &bytes_written, NULL) || bytes_read != bytes_written) 179 goto out; 180 set_progress(progress, current_bytes, total_bytes); 181 } 182 183 set_status(progress, "verifying installer"); 184 blake2b256_final(&hasher, computed_hash); 185 if (memcmp(hash, computed_hash, sizeof(hash))) 186 goto out; 187 CloseHandle(filehandle); //TODO: I wish this wasn't required. 188 filehandle = INVALID_HANDLE_VALUE; 189 wintrust_fileinfo.pcwszFilePath = L(msi_filename); 190 ret = WinVerifyTrustEx(INVALID_HANDLE_VALUE, &(GUID)WINTRUST_ACTION_GENERIC_VERIFY_V2, &wintrust_data); 191 wintrust_data.dwStateAction = WTD_STATEACTION_CLOSE; 192 WinVerifyTrustEx(INVALID_HANDLE_VALUE, &(GUID)WINTRUST_ACTION_GENERIC_VERIFY_V2, &wintrust_data); 193 if (ret) 194 goto out; 195 196 set_status(progress, "launching installer"); 197 ShowWindow(progress, SW_HIDE); 198 ret = MsiInstallProductA(msi_filename, NULL); 199 ret = ret == ERROR_INSTALL_USEREXIT ? ERROR_SUCCESS : ret; 200 201 out: 202 if (request) 203 WinHttpCloseHandle(request); 204 if (connection) 205 WinHttpCloseHandle(connection); 206 if (session) 207 WinHttpCloseHandle(session); 208 if (security_attributes.lpSecurityDescriptor) 209 LocalFree(security_attributes.lpSecurityDescriptor); 210 211 if (ret && prompts) { 212 ShowWindow(progress, SW_SHOWDEFAULT); 213 if (MessageBoxA(progress, "Something went wrong when downloading the WireGuard installer. Would you like to open your web browser to the MSI download page?", "Download Error", MB_YESNO | MB_ICONWARNING) == IDYES) 214 ShellExecuteA(progress, NULL, "https://" server msi_path, NULL, NULL, SW_SHOWNORMAL); 215 } 216 exit(ret); 217 return ret; 218 } 219 220 static int cleanup(void) 221 { 222 BOOL did_delete_via_handle = FALSE; 223 FILE_DISPOSITION_INFO disposition = { TRUE }; 224 if (filehandle != INVALID_HANDLE_VALUE) { 225 did_delete_via_handle = SetFileInformationByHandle(filehandle, FileDispositionInfo, &disposition, sizeof(disposition)); 226 CloseHandle(filehandle); 227 filehandle = INVALID_HANDLE_VALUE; 228 } 229 if (msi_filename_is_set && !did_delete_via_handle) { 230 //TODO: how does DeleteFile deal with reparse points? 231 for (int i = 0; i < 200 && !DeleteFileA(msi_filename) && GetLastError() != ERROR_FILE_NOT_FOUND; ++i) 232 Sleep(200); 233 } 234 return 0; 235 } 236 237 static FARPROC WINAPI delayed_load_library_hook(unsigned dliNotify, PDelayLoadInfo pdli) 238 { 239 HMODULE library; 240 if (dliNotify != dliNotePreLoadLibrary) 241 return NULL; 242 library = LoadLibraryExA(pdli->szDll, NULL, LOAD_LIBRARY_SEARCH_SYSTEM32); 243 if (!library) 244 abort(); 245 return (FARPROC)library; 246 } 247 248 PfnDliHook __pfnDliNotifyHook2 = delayed_load_library_hook; 249 250 static LRESULT CALLBACK wndproc(HWND hWnd, UINT uMsg, WPARAM wParam, LPARAM lParam, UINT_PTR uIdSubclass, DWORD_PTR dwRefData) 251 { 252 (void)uIdSubclass; (void)dwRefData; 253 254 switch (uMsg) { 255 case WM_CLOSE: 256 case WM_DESTROY: { 257 LRESULT ret = DefSubclassProc(hWnd, uMsg, wParam, lParam); 258 exit(0); 259 return ret; 260 } 261 case WM_APP: if (g_total) { 262 char buf[0x1000], *start, *paren; 263 LONG_PTR current_style; 264 int chars = GetWindowTextA(progress, buf, sizeof(buf)); 265 if (chars) { 266 start = buf + chars; 267 if (start[-1] == '.' && start[-2] == '.' && start[-3] == '.') 268 start -= 3; 269 else if ((paren = memchr(buf, '(', chars)) && paren > buf) 270 start = paren - 1; 271 *start = '\0'; 272 _snprintf_s(start, sizeof(buf) - (start - buf), _TRUNCATE, " (%.2f%%)", g_current * 100.0f / g_total); 273 SetWindowTextA(progress, buf); 274 } 275 current_style = GetWindowLongPtrA(progress, GWL_STYLE); 276 if (current_style & PBS_MARQUEE) { 277 SetWindowLongPtrA(progress, GWL_STYLE, current_style & ~PBS_MARQUEE); 278 SendMessageA(progress, PBM_SETMARQUEE, FALSE, 0); 279 } 280 SendMessageA(progress, PBM_SETRANGE32, 0, (LPARAM)g_total); 281 SendMessageA(progress, PBM_SETPOS, (WPARAM)g_current, 0); 282 break; 283 } 284 } 285 return DefSubclassProc(hWnd, uMsg, wParam, lParam); 286 } 287 288 static void parse_command_line(void) 289 { 290 LPWSTR *argv; 291 int argc; 292 argv = CommandLineToArgvW(GetCommandLineW(), &argc); 293 if (!argv) 294 return; 295 for (int i = 1; i < argc; ++i) { 296 if (wcsicmp(argv[i], L"/noprompt") == 0) 297 prompts = false; 298 } 299 LocalFree(argv); 300 } 301 302 int WINAPI WinMain(HINSTANCE hInstance, HINSTANCE hPrevInstance, PSTR pCmdLine, int nCmdShow) 303 { 304 MSG msg; 305 HICON icon; 306 HDC dc; 307 float scale; 308 309 (void)hPrevInstance; (void)pCmdLine; (void)nCmdShow; 310 311 if (!SetDllDirectoryA("") || !SetDefaultDllDirectories(LOAD_LIBRARY_SEARCH_SYSTEM32)) 312 return 1; 313 314 parse_command_line(); 315 316 InitCommonControlsEx(&(INITCOMMONCONTROLSEX){ .dwSize = sizeof(INITCOMMONCONTROLSEX), .dwICC = ICC_PROGRESS_CLASS }); 317 318 progress = CreateWindowExA(0, PROGRESS_CLASS, "WireGuard Installer", 319 (WS_OVERLAPPEDWINDOW & ~(WS_BORDER | WS_THICKFRAME | WS_MAXIMIZEBOX)) | PBS_MARQUEE | PBS_SMOOTH, 320 CW_USEDEFAULT, CW_USEDEFAULT, CW_USEDEFAULT, CW_USEDEFAULT, 321 NULL, NULL, hInstance, NULL); 322 SetWindowSubclass(progress, wndproc, 0, 0); 323 dc = GetDC(progress); 324 scale = GetDeviceCaps(dc, LOGPIXELSY) / 96.0f; 325 ReleaseDC(progress, dc); 326 icon = LoadIconA(hInstance, MAKEINTRESOURCE(7)); 327 SendMessageA(progress, WM_SETICON, ICON_BIG, (LPARAM)icon); 328 SendMessageA(progress, WM_SETICON, ICON_SMALL, (LPARAM)icon); 329 SendMessageA(progress, PBM_SETMARQUEE, TRUE, 0); 330 SetWindowPos(progress, HWND_TOPMOST, -1, -1, 500 * scale, 80 * scale, SWP_NOMOVE | SWP_SHOWWINDOW); 331 332 _onexit(cleanup); 333 CreateThread(NULL, 0, download_thread, NULL, 0, NULL); 334 335 while (GetMessage(&msg, NULL, 0, 0)) { 336 TranslateMessage(&msg); 337 DispatchMessage(&msg); 338 } 339 return 0; 340 }