golang.zx2c4.com/wireguard/windows@v0.5.4-0.20230123132234-dcc0eb72a04b/updater/winhttp/winhttp.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package winhttp
     7  
     8  import (
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"runtime"
    13  	"strconv"
    14  	"strings"
    15  	"sync/atomic"
    16  	"unsafe"
    17  
    18  	"golang.org/x/sys/windows"
    19  )
    20  
    21  type Session struct {
    22  	handle _HINTERNET
    23  }
    24  
    25  type Connection struct {
    26  	handle  _HINTERNET
    27  	session *Session
    28  	https   bool
    29  }
    30  
    31  type Response struct {
    32  	handle     _HINTERNET
    33  	connection *Connection
    34  }
    35  
    36  func convertError(err *error) {
    37  	if *err == nil {
    38  		return
    39  	}
    40  	var errno windows.Errno
    41  	if errors.As(*err, &errno) {
    42  		if errno > _WINHTTP_ERROR_BASE && errno <= _WINHTTP_ERROR_LAST {
    43  			*err = Error(errno)
    44  		}
    45  	}
    46  }
    47  
    48  func isWin7() bool {
    49  	maj, min, _ := windows.RtlGetNtVersionNumbers()
    50  	return maj < 6 || (maj == 6 && min <= 1)
    51  }
    52  
    53  func isWin8DotZeroOrBelow() bool {
    54  	maj, min, _ := windows.RtlGetNtVersionNumbers()
    55  	return maj < 6 || (maj == 6 && min <= 2)
    56  }
    57  
    58  func NewSession(userAgent string) (session *Session, err error) {
    59  	session = new(Session)
    60  	defer convertError(&err)
    61  	defer func() {
    62  		if err != nil {
    63  			session.Close()
    64  			session = nil
    65  		}
    66  	}()
    67  	userAgent16, err := windows.UTF16PtrFromString(userAgent)
    68  	if err != nil {
    69  		return
    70  	}
    71  	var proxyFlag uint32 = _WINHTTP_ACCESS_TYPE_AUTOMATIC_PROXY
    72  	if isWin7() {
    73  		proxyFlag = _WINHTTP_ACCESS_TYPE_DEFAULT_PROXY
    74  	}
    75  	session.handle, err = winHttpOpen(userAgent16, proxyFlag, nil, nil, 0)
    76  	if err != nil {
    77  		return
    78  	}
    79  	var enableHttp2 uint32 = _WINHTTP_PROTOCOL_FLAG_HTTP2
    80  	_ = winHttpSetOption(session.handle, _WINHTTP_OPTION_ENABLE_HTTP_PROTOCOL, unsafe.Pointer(&enableHttp2), uint32(unsafe.Sizeof(enableHttp2))) // Don't check return value, in case of old Windows
    81  
    82  	if isWin8DotZeroOrBelow() {
    83  		var enableTLS12 uint32 = _WINHTTP_FLAG_SECURE_PROTOCOL_TLS1_2
    84  		err = winHttpSetOption(session.handle, _WINHTTP_OPTION_SECURE_PROTOCOLS, unsafe.Pointer(&enableTLS12), uint32(unsafe.Sizeof(enableTLS12)))
    85  		if err != nil {
    86  			return
    87  		}
    88  	}
    89  
    90  	runtime.SetFinalizer(session, func(session *Session) {
    91  		session.Close()
    92  	})
    93  	return
    94  }
    95  
    96  func (session *Session) Close() (err error) {
    97  	defer convertError(&err)
    98  	handle := (_HINTERNET)(atomic.SwapUintptr((*uintptr)(&session.handle), 0))
    99  	if handle == 0 {
   100  		return
   101  	}
   102  	return winHttpCloseHandle(handle)
   103  }
   104  
   105  func (session *Session) Connect(server string, port uint16, https bool) (connection *Connection, err error) {
   106  	connection = &Connection{session: session}
   107  	defer convertError(&err)
   108  	defer func() {
   109  		if err != nil {
   110  			connection.Close()
   111  			connection = nil
   112  		}
   113  	}()
   114  	server16, err := windows.UTF16PtrFromString(server)
   115  	if err != nil {
   116  		return
   117  	}
   118  	connection.handle, err = winHttpConnect(session.handle, server16, port, 0)
   119  	if err != nil {
   120  		return
   121  	}
   122  	connection.https = https
   123  
   124  	runtime.SetFinalizer(connection, func(connection *Connection) {
   125  		connection.Close()
   126  	})
   127  	return
   128  }
   129  
   130  func (connection *Connection) Close() (err error) {
   131  	defer convertError(&err)
   132  	handle := (_HINTERNET)(atomic.SwapUintptr((*uintptr)(&connection.handle), 0))
   133  	if handle == 0 {
   134  		return
   135  	}
   136  	return winHttpCloseHandle(handle)
   137  }
   138  
   139  func (connection *Connection) Get(path string, refresh bool) (response *Response, err error) {
   140  	response = &Response{connection: connection}
   141  	defer convertError(&err)
   142  	defer func() {
   143  		if err != nil {
   144  			response.Close()
   145  			response = nil
   146  		}
   147  	}()
   148  	var flags uint32
   149  	if refresh {
   150  		flags |= _WINHTTP_FLAG_REFRESH
   151  	}
   152  	if connection.https {
   153  		flags |= _WINHTTP_FLAG_SECURE
   154  	}
   155  	path16, err := windows.UTF16PtrFromString(path)
   156  	if err != nil {
   157  		return
   158  	}
   159  	get16, err := windows.UTF16PtrFromString("GET")
   160  	if err != nil {
   161  		return
   162  	}
   163  	response.handle, err = winHttpOpenRequest(connection.handle, get16, path16, nil, nil, nil, flags)
   164  	if err != nil {
   165  		return
   166  	}
   167  	err = winHttpSendRequest(response.handle, nil, 0, nil, 0, 0, 0)
   168  	if err != nil {
   169  		return
   170  	}
   171  	err = winHttpReceiveResponse(response.handle, 0)
   172  	if err != nil {
   173  		return
   174  	}
   175  
   176  	runtime.SetFinalizer(response, func(response *Response) {
   177  		response.Close()
   178  	})
   179  	return
   180  }
   181  
   182  func (response *Response) Length() (length uint64, err error) {
   183  	defer convertError(&err)
   184  	numBuf := make([]uint16, 22)
   185  	numLen := uint32(len(numBuf) * 2)
   186  	err = winHttpQueryHeaders(response.handle, _WINHTTP_QUERY_CONTENT_LENGTH, nil, unsafe.Pointer(&numBuf[0]), &numLen, nil)
   187  	if err != nil {
   188  		return
   189  	}
   190  	length, err = strconv.ParseUint(windows.UTF16ToString(numBuf[:numLen]), 10, 64)
   191  	if err != nil {
   192  		return
   193  	}
   194  	return
   195  }
   196  
   197  func (response *Response) Read(p []byte) (n int, err error) {
   198  	defer convertError(&err)
   199  	if len(p) == 0 {
   200  		return 0, nil
   201  	}
   202  	var bytesRead uint32
   203  	err = winHttpReadData(response.handle, &p[0], uint32(len(p)), &bytesRead)
   204  	if err != nil {
   205  		return 0, nil
   206  	}
   207  	if bytesRead == 0 || int(bytesRead) < 0 {
   208  		return 0, io.EOF
   209  	}
   210  	return int(bytesRead), nil
   211  }
   212  
   213  func (response *Response) Close() (err error) {
   214  	defer convertError(&err)
   215  	handle := (_HINTERNET)(atomic.SwapUintptr((*uintptr)(&response.handle), 0))
   216  	if handle == 0 {
   217  		return
   218  	}
   219  	return winHttpCloseHandle(handle)
   220  }
   221  
   222  func (error Error) Error() string {
   223  	var message [2048]uint16
   224  	n, err := windows.FormatMessage(windows.FORMAT_MESSAGE_FROM_HMODULE|windows.FORMAT_MESSAGE_IGNORE_INSERTS|windows.FORMAT_MESSAGE_MAX_WIDTH_MASK,
   225  		modwinhttp.Handle(), uint32(error), 0, message[:], nil)
   226  	if err != nil {
   227  		return fmt.Sprintf("WinHTTP error #%d", error)
   228  	}
   229  	return strings.TrimSpace(windows.UTF16ToString(message[:n]))
   230  }