github.com/boomhut/fiber/v2@v2.0.0-20230603160335-b65c856e57d3/internal/wmi/wmi.go (about)

     1  //go:build windows
     2  // +build windows
     3  
     4  /*
     5  Package wmi provides a WQL interface for WMI on Windows.
     6  
     7  Example code to print names of running processes:
     8  
     9  	type Win32_Process struct {
    10  		Name string
    11  	}
    12  
    13  	func main() {
    14  		var dst []Win32_Process
    15  		q := wmi.CreateQuery(&dst, "")
    16  		err := wmi.Query(q, &dst)
    17  		if err != nil {
    18  			log.Fatal(err)
    19  		}
    20  		for i, v := range dst {
    21  			println(i, v.Name)
    22  		}
    23  	}
    24  */
    25  package wmi
    26  
    27  import (
    28  	"bytes"
    29  	"errors"
    30  	"fmt"
    31  	"log"
    32  	"os"
    33  	"reflect"
    34  	"runtime"
    35  	"strconv"
    36  	"strings"
    37  	"sync"
    38  	"time"
    39  
    40  	"github.com/boomhut/fiber/v2/internal/go-ole"
    41  	"github.com/boomhut/fiber/v2/internal/go-ole/oleutil"
    42  )
    43  
    44  var l = log.New(os.Stdout, "", log.LstdFlags)
    45  
    46  var (
    47  	ErrInvalidEntityType = errors.New("wmi: invalid entity type")
    48  	// ErrNilCreateObject is the error returned if CreateObject returns nil even
    49  	// if the error was nil.
    50  	ErrNilCreateObject = errors.New("wmi: create object returned nil")
    51  	lock               sync.Mutex
    52  )
    53  
    54  // S_FALSE is returned by CoInitializeEx if it was already called on this thread.
    55  const S_FALSE = 0x00000001
    56  
    57  // QueryNamespace invokes Query with the given namespace on the local machine.
    58  func QueryNamespace(query string, dst interface{}, namespace string) error {
    59  	return Query(query, dst, nil, namespace)
    60  }
    61  
    62  // Query runs the WQL query and appends the values to dst.
    63  //
    64  // dst must have type *[]S or *[]*S, for some struct type S. Fields selected in
    65  // the query must have the same name in dst. Supported types are all signed and
    66  // unsigned integers, time.Time, string, bool, or a pointer to one of those.
    67  // Array types are not supported.
    68  //
    69  // By default, the local machine and default namespace are used. These can be
    70  // changed using connectServerArgs. See
    71  // http://msdn.microsoft.com/en-us/library/aa393720.aspx for details.
    72  //
    73  // Query is a wrapper around DefaultClient.Query.
    74  func Query(query string, dst interface{}, connectServerArgs ...interface{}) error {
    75  	if DefaultClient.SWbemServicesClient == nil {
    76  		return DefaultClient.Query(query, dst, connectServerArgs...)
    77  	}
    78  	return DefaultClient.SWbemServicesClient.Query(query, dst, connectServerArgs...)
    79  }
    80  
    81  // A Client is an WMI query client.
    82  //
    83  // Its zero value (DefaultClient) is a usable client.
    84  type Client struct {
    85  	// NonePtrZero specifies if nil values for fields which aren't pointers
    86  	// should be returned as the field types zero value.
    87  	//
    88  	// Setting this to true allows stucts without pointer fields to be used
    89  	// without the risk failure should a nil value returned from WMI.
    90  	NonePtrZero bool
    91  
    92  	// PtrNil specifies if nil values for pointer fields should be returned
    93  	// as nil.
    94  	//
    95  	// Setting this to true will set pointer fields to nil where WMI
    96  	// returned nil, otherwise the types zero value will be returned.
    97  	PtrNil bool
    98  
    99  	// AllowMissingFields specifies that struct fields not present in the
   100  	// query result should not result in an error.
   101  	//
   102  	// Setting this to true allows custom queries to be used with full
   103  	// struct definitions instead of having to define multiple structs.
   104  	AllowMissingFields bool
   105  
   106  	// SWbemServiceClient is an optional SWbemServices object that can be
   107  	// initialized and then reused across multiple queries. If it is null
   108  	// then the method will initialize a new temporary client each time.
   109  	SWbemServicesClient *SWbemServices
   110  }
   111  
   112  // DefaultClient is the default Client and is used by Query, QueryNamespace
   113  var DefaultClient = &Client{}
   114  
   115  // Query runs the WQL query and appends the values to dst.
   116  //
   117  // dst must have type *[]S or *[]*S, for some struct type S. Fields selected in
   118  // the query must have the same name in dst. Supported types are all signed and
   119  // unsigned integers, time.Time, string, bool, or a pointer to one of those.
   120  // Array types are not supported.
   121  //
   122  // By default, the local machine and default namespace are used. These can be
   123  // changed using connectServerArgs. See
   124  // http://msdn.microsoft.com/en-us/library/aa393720.aspx for details.
   125  func (c *Client) Query(query string, dst interface{}, connectServerArgs ...interface{}) error {
   126  	dv := reflect.ValueOf(dst)
   127  	if dv.Kind() != reflect.Ptr || dv.IsNil() {
   128  		return ErrInvalidEntityType
   129  	}
   130  	dv = dv.Elem()
   131  	mat, elemType := checkMultiArg(dv)
   132  	if mat == multiArgTypeInvalid {
   133  		return ErrInvalidEntityType
   134  	}
   135  
   136  	lock.Lock()
   137  	defer lock.Unlock()
   138  	runtime.LockOSThread()
   139  	defer runtime.UnlockOSThread()
   140  
   141  	err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED)
   142  	if err != nil {
   143  		oleCode := err.(*ole.OleError).Code()
   144  		if oleCode != ole.S_OK && oleCode != S_FALSE {
   145  			return err
   146  		}
   147  	}
   148  	defer ole.CoUninitialize()
   149  
   150  	unknown, err := oleutil.CreateObject("WbemScripting.SWbemLocator")
   151  	if err != nil {
   152  		return err
   153  	} else if unknown == nil {
   154  		return ErrNilCreateObject
   155  	}
   156  	defer unknown.Release()
   157  
   158  	wmi, err := unknown.QueryInterface(ole.IID_IDispatch)
   159  	if err != nil {
   160  		return err
   161  	}
   162  	defer wmi.Release()
   163  
   164  	// service is a SWbemServices
   165  	serviceRaw, err := oleutil.CallMethod(wmi, "ConnectServer", connectServerArgs...)
   166  	if err != nil {
   167  		return err
   168  	}
   169  	service := serviceRaw.ToIDispatch()
   170  	defer serviceRaw.Clear()
   171  
   172  	// result is a SWBemObjectSet
   173  	resultRaw, err := oleutil.CallMethod(service, "ExecQuery", query)
   174  	if err != nil {
   175  		return err
   176  	}
   177  	result := resultRaw.ToIDispatch()
   178  	defer resultRaw.Clear()
   179  
   180  	count, err := oleInt64(result, "Count")
   181  	if err != nil {
   182  		return err
   183  	}
   184  
   185  	enumProperty, err := result.GetProperty("_NewEnum")
   186  	if err != nil {
   187  		return err
   188  	}
   189  	defer enumProperty.Clear()
   190  
   191  	enum, err := enumProperty.ToIUnknown().IEnumVARIANT(ole.IID_IEnumVariant)
   192  	if err != nil {
   193  		return err
   194  	}
   195  	if enum == nil {
   196  		return fmt.Errorf("can't get IEnumVARIANT, enum is nil")
   197  	}
   198  	defer enum.Release()
   199  
   200  	// Initialize a slice with Count capacity
   201  	dv.Set(reflect.MakeSlice(dv.Type(), 0, int(count)))
   202  
   203  	var errFieldMismatch error
   204  	for itemRaw, length, err := enum.Next(1); length > 0; itemRaw, length, err = enum.Next(1) {
   205  		if err != nil {
   206  			return err
   207  		}
   208  
   209  		err := func() error {
   210  			// item is a SWbemObject, but really a Win32_Process
   211  			item := itemRaw.ToIDispatch()
   212  			defer item.Release()
   213  
   214  			ev := reflect.New(elemType)
   215  			if err = c.loadEntity(ev.Interface(), item); err != nil {
   216  				if _, ok := err.(*ErrFieldMismatch); ok {
   217  					// We continue loading entities even in the face of field mismatch errors.
   218  					// If we encounter any other error, that other error is returned. Otherwise,
   219  					// an ErrFieldMismatch is returned.
   220  					errFieldMismatch = err
   221  				} else {
   222  					return err
   223  				}
   224  			}
   225  			if mat != multiArgTypeStructPtr {
   226  				ev = ev.Elem()
   227  			}
   228  			dv.Set(reflect.Append(dv, ev))
   229  			return nil
   230  		}()
   231  		if err != nil {
   232  			return err
   233  		}
   234  	}
   235  	return errFieldMismatch
   236  }
   237  
   238  // ErrFieldMismatch is returned when a field is to be loaded into a different
   239  // type than the one it was stored from, or when a field is missing or
   240  // unexported in the destination struct.
   241  // StructType is the type of the struct pointed to by the destination argument.
   242  type ErrFieldMismatch struct {
   243  	StructType reflect.Type
   244  	FieldName  string
   245  	Reason     string
   246  }
   247  
   248  func (e *ErrFieldMismatch) Error() string {
   249  	return fmt.Sprintf("wmi: cannot load field %q into a %q: %s",
   250  		e.FieldName, e.StructType, e.Reason)
   251  }
   252  
   253  var timeType = reflect.TypeOf(time.Time{})
   254  
   255  // loadEntity loads a SWbemObject into a struct pointer.
   256  func (c *Client) loadEntity(dst interface{}, src *ole.IDispatch) (errFieldMismatch error) {
   257  	v := reflect.ValueOf(dst).Elem()
   258  	for i := 0; i < v.NumField(); i++ {
   259  		f := v.Field(i)
   260  		of := f
   261  		isPtr := f.Kind() == reflect.Ptr
   262  		if isPtr {
   263  			ptr := reflect.New(f.Type().Elem())
   264  			f.Set(ptr)
   265  			f = f.Elem()
   266  		}
   267  		n := v.Type().Field(i).Name
   268  		if !f.CanSet() {
   269  			return &ErrFieldMismatch{
   270  				StructType: of.Type(),
   271  				FieldName:  n,
   272  				Reason:     "CanSet() is false",
   273  			}
   274  		}
   275  		prop, err := oleutil.GetProperty(src, n)
   276  		if err != nil {
   277  			if !c.AllowMissingFields {
   278  				errFieldMismatch = &ErrFieldMismatch{
   279  					StructType: of.Type(),
   280  					FieldName:  n,
   281  					Reason:     "no such struct field",
   282  				}
   283  			}
   284  			continue
   285  		}
   286  		defer prop.Clear()
   287  
   288  		if prop.VT == 0x1 { //VT_NULL
   289  			continue
   290  		}
   291  
   292  		switch val := prop.Value().(type) {
   293  		case int8, int16, int32, int64, int:
   294  			v := reflect.ValueOf(val).Int()
   295  			switch f.Kind() {
   296  			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   297  				f.SetInt(v)
   298  			case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   299  				f.SetUint(uint64(v))
   300  			default:
   301  				return &ErrFieldMismatch{
   302  					StructType: of.Type(),
   303  					FieldName:  n,
   304  					Reason:     "not an integer class",
   305  				}
   306  			}
   307  		case uint8, uint16, uint32, uint64:
   308  			v := reflect.ValueOf(val).Uint()
   309  			switch f.Kind() {
   310  			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   311  				f.SetInt(int64(v))
   312  			case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   313  				f.SetUint(v)
   314  			default:
   315  				return &ErrFieldMismatch{
   316  					StructType: of.Type(),
   317  					FieldName:  n,
   318  					Reason:     "not an integer class",
   319  				}
   320  			}
   321  		case string:
   322  			switch f.Kind() {
   323  			case reflect.String:
   324  				f.SetString(val)
   325  			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   326  				iv, err := strconv.ParseInt(val, 10, 64)
   327  				if err != nil {
   328  					return err
   329  				}
   330  				f.SetInt(iv)
   331  			case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   332  				uv, err := strconv.ParseUint(val, 10, 64)
   333  				if err != nil {
   334  					return err
   335  				}
   336  				f.SetUint(uv)
   337  			case reflect.Struct:
   338  				switch f.Type() {
   339  				case timeType:
   340  					if len(val) == 25 {
   341  						mins, err := strconv.Atoi(val[22:])
   342  						if err != nil {
   343  							return err
   344  						}
   345  						val = val[:22] + fmt.Sprintf("%02d%02d", mins/60, mins%60)
   346  					}
   347  					t, err := time.Parse("20060102150405.000000-0700", val)
   348  					if err != nil {
   349  						return err
   350  					}
   351  					f.Set(reflect.ValueOf(t))
   352  				}
   353  			}
   354  		case bool:
   355  			switch f.Kind() {
   356  			case reflect.Bool:
   357  				f.SetBool(val)
   358  			default:
   359  				return &ErrFieldMismatch{
   360  					StructType: of.Type(),
   361  					FieldName:  n,
   362  					Reason:     "not a bool",
   363  				}
   364  			}
   365  		case float32:
   366  			switch f.Kind() {
   367  			case reflect.Float32:
   368  				f.SetFloat(float64(val))
   369  			default:
   370  				return &ErrFieldMismatch{
   371  					StructType: of.Type(),
   372  					FieldName:  n,
   373  					Reason:     "not a Float32",
   374  				}
   375  			}
   376  		default:
   377  			if f.Kind() == reflect.Slice {
   378  				switch f.Type().Elem().Kind() {
   379  				case reflect.String:
   380  					safeArray := prop.ToArray()
   381  					if safeArray != nil {
   382  						arr := safeArray.ToValueArray()
   383  						fArr := reflect.MakeSlice(f.Type(), len(arr), len(arr))
   384  						for i, v := range arr {
   385  							s := fArr.Index(i)
   386  							s.SetString(v.(string))
   387  						}
   388  						f.Set(fArr)
   389  					}
   390  				case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
   391  					safeArray := prop.ToArray()
   392  					if safeArray != nil {
   393  						arr := safeArray.ToValueArray()
   394  						fArr := reflect.MakeSlice(f.Type(), len(arr), len(arr))
   395  						for i, v := range arr {
   396  							s := fArr.Index(i)
   397  							s.SetUint(reflect.ValueOf(v).Uint())
   398  						}
   399  						f.Set(fArr)
   400  					}
   401  				case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
   402  					safeArray := prop.ToArray()
   403  					if safeArray != nil {
   404  						arr := safeArray.ToValueArray()
   405  						fArr := reflect.MakeSlice(f.Type(), len(arr), len(arr))
   406  						for i, v := range arr {
   407  							s := fArr.Index(i)
   408  							s.SetInt(reflect.ValueOf(v).Int())
   409  						}
   410  						f.Set(fArr)
   411  					}
   412  				default:
   413  					return &ErrFieldMismatch{
   414  						StructType: of.Type(),
   415  						FieldName:  n,
   416  						Reason:     fmt.Sprintf("unsupported slice type (%T)", val),
   417  					}
   418  				}
   419  			} else {
   420  				typeof := reflect.TypeOf(val)
   421  				if typeof == nil && (isPtr || c.NonePtrZero) {
   422  					if (isPtr && c.PtrNil) || (!isPtr && c.NonePtrZero) {
   423  						of.Set(reflect.Zero(of.Type()))
   424  					}
   425  					break
   426  				}
   427  				return &ErrFieldMismatch{
   428  					StructType: of.Type(),
   429  					FieldName:  n,
   430  					Reason:     fmt.Sprintf("unsupported type (%T)", val),
   431  				}
   432  			}
   433  		}
   434  	}
   435  	return errFieldMismatch
   436  }
   437  
   438  type multiArgType int
   439  
   440  const (
   441  	multiArgTypeInvalid multiArgType = iota
   442  	multiArgTypeStruct
   443  	multiArgTypeStructPtr
   444  )
   445  
   446  // checkMultiArg checks that v has type []S, []*S for some struct type S.
   447  //
   448  // It returns what category the slice's elements are, and the reflect.Type
   449  // that represents S.
   450  func checkMultiArg(v reflect.Value) (m multiArgType, elemType reflect.Type) {
   451  	if v.Kind() != reflect.Slice {
   452  		return multiArgTypeInvalid, nil
   453  	}
   454  	elemType = v.Type().Elem()
   455  	switch elemType.Kind() {
   456  	case reflect.Struct:
   457  		return multiArgTypeStruct, elemType
   458  	case reflect.Ptr:
   459  		elemType = elemType.Elem()
   460  		if elemType.Kind() == reflect.Struct {
   461  			return multiArgTypeStructPtr, elemType
   462  		}
   463  	}
   464  	return multiArgTypeInvalid, nil
   465  }
   466  
   467  func oleInt64(item *ole.IDispatch, prop string) (int64, error) {
   468  	v, err := oleutil.GetProperty(item, prop)
   469  	if err != nil {
   470  		return 0, err
   471  	}
   472  	defer v.Clear()
   473  
   474  	i := int64(v.Val)
   475  	return i, nil
   476  }
   477  
   478  // CreateQuery returns a WQL query string that queries all columns of src. where
   479  // is an optional string that is appended to the query, to be used with WHERE
   480  // clauses. In such a case, the "WHERE" string should appear at the beginning.
   481  func CreateQuery(src interface{}, where string) string {
   482  	var b bytes.Buffer
   483  	b.WriteString("SELECT ")
   484  	s := reflect.Indirect(reflect.ValueOf(src))
   485  	t := s.Type()
   486  	if s.Kind() == reflect.Slice {
   487  		t = t.Elem()
   488  	}
   489  	if t.Kind() != reflect.Struct {
   490  		return ""
   491  	}
   492  	var fields []string
   493  	for i := 0; i < t.NumField(); i++ {
   494  		fields = append(fields, t.Field(i).Name)
   495  	}
   496  	b.WriteString(strings.Join(fields, ", "))
   497  	b.WriteString(" FROM ")
   498  	b.WriteString(t.Name())
   499  	b.WriteString(" " + where)
   500  	return b.String()
   501  }