golang.zx2c4.com/wireguard/windows@v0.5.4-0.20230123132234-dcc0eb72a04b/updater/winhttp/winhttp.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package winhttp 7 8 import ( 9 "errors" 10 "fmt" 11 "io" 12 "runtime" 13 "strconv" 14 "strings" 15 "sync/atomic" 16 "unsafe" 17 18 "golang.org/x/sys/windows" 19 ) 20 21 type Session struct { 22 handle _HINTERNET 23 } 24 25 type Connection struct { 26 handle _HINTERNET 27 session *Session 28 https bool 29 } 30 31 type Response struct { 32 handle _HINTERNET 33 connection *Connection 34 } 35 36 func convertError(err *error) { 37 if *err == nil { 38 return 39 } 40 var errno windows.Errno 41 if errors.As(*err, &errno) { 42 if errno > _WINHTTP_ERROR_BASE && errno <= _WINHTTP_ERROR_LAST { 43 *err = Error(errno) 44 } 45 } 46 } 47 48 func isWin7() bool { 49 maj, min, _ := windows.RtlGetNtVersionNumbers() 50 return maj < 6 || (maj == 6 && min <= 1) 51 } 52 53 func isWin8DotZeroOrBelow() bool { 54 maj, min, _ := windows.RtlGetNtVersionNumbers() 55 return maj < 6 || (maj == 6 && min <= 2) 56 } 57 58 func NewSession(userAgent string) (session *Session, err error) { 59 session = new(Session) 60 defer convertError(&err) 61 defer func() { 62 if err != nil { 63 session.Close() 64 session = nil 65 } 66 }() 67 userAgent16, err := windows.UTF16PtrFromString(userAgent) 68 if err != nil { 69 return 70 } 71 var proxyFlag uint32 = _WINHTTP_ACCESS_TYPE_AUTOMATIC_PROXY 72 if isWin7() { 73 proxyFlag = _WINHTTP_ACCESS_TYPE_DEFAULT_PROXY 74 } 75 session.handle, err = winHttpOpen(userAgent16, proxyFlag, nil, nil, 0) 76 if err != nil { 77 return 78 } 79 var enableHttp2 uint32 = _WINHTTP_PROTOCOL_FLAG_HTTP2 80 _ = winHttpSetOption(session.handle, _WINHTTP_OPTION_ENABLE_HTTP_PROTOCOL, unsafe.Pointer(&enableHttp2), uint32(unsafe.Sizeof(enableHttp2))) // Don't check return value, in case of old Windows 81 82 if isWin8DotZeroOrBelow() { 83 var enableTLS12 uint32 = _WINHTTP_FLAG_SECURE_PROTOCOL_TLS1_2 84 err = winHttpSetOption(session.handle, _WINHTTP_OPTION_SECURE_PROTOCOLS, unsafe.Pointer(&enableTLS12), uint32(unsafe.Sizeof(enableTLS12))) 85 if err != nil { 86 return 87 } 88 } 89 90 runtime.SetFinalizer(session, func(session *Session) { 91 session.Close() 92 }) 93 return 94 } 95 96 func (session *Session) Close() (err error) { 97 defer convertError(&err) 98 handle := (_HINTERNET)(atomic.SwapUintptr((*uintptr)(&session.handle), 0)) 99 if handle == 0 { 100 return 101 } 102 return winHttpCloseHandle(handle) 103 } 104 105 func (session *Session) Connect(server string, port uint16, https bool) (connection *Connection, err error) { 106 connection = &Connection{session: session} 107 defer convertError(&err) 108 defer func() { 109 if err != nil { 110 connection.Close() 111 connection = nil 112 } 113 }() 114 server16, err := windows.UTF16PtrFromString(server) 115 if err != nil { 116 return 117 } 118 connection.handle, err = winHttpConnect(session.handle, server16, port, 0) 119 if err != nil { 120 return 121 } 122 connection.https = https 123 124 runtime.SetFinalizer(connection, func(connection *Connection) { 125 connection.Close() 126 }) 127 return 128 } 129 130 func (connection *Connection) Close() (err error) { 131 defer convertError(&err) 132 handle := (_HINTERNET)(atomic.SwapUintptr((*uintptr)(&connection.handle), 0)) 133 if handle == 0 { 134 return 135 } 136 return winHttpCloseHandle(handle) 137 } 138 139 func (connection *Connection) Get(path string, refresh bool) (response *Response, err error) { 140 response = &Response{connection: connection} 141 defer convertError(&err) 142 defer func() { 143 if err != nil { 144 response.Close() 145 response = nil 146 } 147 }() 148 var flags uint32 149 if refresh { 150 flags |= _WINHTTP_FLAG_REFRESH 151 } 152 if connection.https { 153 flags |= _WINHTTP_FLAG_SECURE 154 } 155 path16, err := windows.UTF16PtrFromString(path) 156 if err != nil { 157 return 158 } 159 get16, err := windows.UTF16PtrFromString("GET") 160 if err != nil { 161 return 162 } 163 response.handle, err = winHttpOpenRequest(connection.handle, get16, path16, nil, nil, nil, flags) 164 if err != nil { 165 return 166 } 167 err = winHttpSendRequest(response.handle, nil, 0, nil, 0, 0, 0) 168 if err != nil { 169 return 170 } 171 err = winHttpReceiveResponse(response.handle, 0) 172 if err != nil { 173 return 174 } 175 176 runtime.SetFinalizer(response, func(response *Response) { 177 response.Close() 178 }) 179 return 180 } 181 182 func (response *Response) Length() (length uint64, err error) { 183 defer convertError(&err) 184 numBuf := make([]uint16, 22) 185 numLen := uint32(len(numBuf) * 2) 186 err = winHttpQueryHeaders(response.handle, _WINHTTP_QUERY_CONTENT_LENGTH, nil, unsafe.Pointer(&numBuf[0]), &numLen, nil) 187 if err != nil { 188 return 189 } 190 length, err = strconv.ParseUint(windows.UTF16ToString(numBuf[:numLen]), 10, 64) 191 if err != nil { 192 return 193 } 194 return 195 } 196 197 func (response *Response) Read(p []byte) (n int, err error) { 198 defer convertError(&err) 199 if len(p) == 0 { 200 return 0, nil 201 } 202 var bytesRead uint32 203 err = winHttpReadData(response.handle, &p[0], uint32(len(p)), &bytesRead) 204 if err != nil { 205 return 0, nil 206 } 207 if bytesRead == 0 || int(bytesRead) < 0 { 208 return 0, io.EOF 209 } 210 return int(bytesRead), nil 211 } 212 213 func (response *Response) Close() (err error) { 214 defer convertError(&err) 215 handle := (_HINTERNET)(atomic.SwapUintptr((*uintptr)(&response.handle), 0)) 216 if handle == 0 { 217 return 218 } 219 return winHttpCloseHandle(handle) 220 } 221 222 func (error Error) Error() string { 223 var message [2048]uint16 224 n, err := windows.FormatMessage(windows.FORMAT_MESSAGE_FROM_HMODULE|windows.FORMAT_MESSAGE_IGNORE_INSERTS|windows.FORMAT_MESSAGE_MAX_WIDTH_MASK, 225 modwinhttp.Handle(), uint32(error), 0, message[:], nil) 226 if err != nil { 227 return fmt.Sprintf("WinHTTP error #%d", error) 228 } 229 return strings.TrimSpace(windows.UTF16ToString(message[:n])) 230 }