github.com/boomhut/fiber/v2@v2.0.0-20230603160335-b65c856e57d3/internal/wmi/wmi.go (about) 1 //go:build windows 2 // +build windows 3 4 /* 5 Package wmi provides a WQL interface for WMI on Windows. 6 7 Example code to print names of running processes: 8 9 type Win32_Process struct { 10 Name string 11 } 12 13 func main() { 14 var dst []Win32_Process 15 q := wmi.CreateQuery(&dst, "") 16 err := wmi.Query(q, &dst) 17 if err != nil { 18 log.Fatal(err) 19 } 20 for i, v := range dst { 21 println(i, v.Name) 22 } 23 } 24 */ 25 package wmi 26 27 import ( 28 "bytes" 29 "errors" 30 "fmt" 31 "log" 32 "os" 33 "reflect" 34 "runtime" 35 "strconv" 36 "strings" 37 "sync" 38 "time" 39 40 "github.com/boomhut/fiber/v2/internal/go-ole" 41 "github.com/boomhut/fiber/v2/internal/go-ole/oleutil" 42 ) 43 44 var l = log.New(os.Stdout, "", log.LstdFlags) 45 46 var ( 47 ErrInvalidEntityType = errors.New("wmi: invalid entity type") 48 // ErrNilCreateObject is the error returned if CreateObject returns nil even 49 // if the error was nil. 50 ErrNilCreateObject = errors.New("wmi: create object returned nil") 51 lock sync.Mutex 52 ) 53 54 // S_FALSE is returned by CoInitializeEx if it was already called on this thread. 55 const S_FALSE = 0x00000001 56 57 // QueryNamespace invokes Query with the given namespace on the local machine. 58 func QueryNamespace(query string, dst interface{}, namespace string) error { 59 return Query(query, dst, nil, namespace) 60 } 61 62 // Query runs the WQL query and appends the values to dst. 63 // 64 // dst must have type *[]S or *[]*S, for some struct type S. Fields selected in 65 // the query must have the same name in dst. Supported types are all signed and 66 // unsigned integers, time.Time, string, bool, or a pointer to one of those. 67 // Array types are not supported. 68 // 69 // By default, the local machine and default namespace are used. These can be 70 // changed using connectServerArgs. See 71 // http://msdn.microsoft.com/en-us/library/aa393720.aspx for details. 72 // 73 // Query is a wrapper around DefaultClient.Query. 74 func Query(query string, dst interface{}, connectServerArgs ...interface{}) error { 75 if DefaultClient.SWbemServicesClient == nil { 76 return DefaultClient.Query(query, dst, connectServerArgs...) 77 } 78 return DefaultClient.SWbemServicesClient.Query(query, dst, connectServerArgs...) 79 } 80 81 // A Client is an WMI query client. 82 // 83 // Its zero value (DefaultClient) is a usable client. 84 type Client struct { 85 // NonePtrZero specifies if nil values for fields which aren't pointers 86 // should be returned as the field types zero value. 87 // 88 // Setting this to true allows stucts without pointer fields to be used 89 // without the risk failure should a nil value returned from WMI. 90 NonePtrZero bool 91 92 // PtrNil specifies if nil values for pointer fields should be returned 93 // as nil. 94 // 95 // Setting this to true will set pointer fields to nil where WMI 96 // returned nil, otherwise the types zero value will be returned. 97 PtrNil bool 98 99 // AllowMissingFields specifies that struct fields not present in the 100 // query result should not result in an error. 101 // 102 // Setting this to true allows custom queries to be used with full 103 // struct definitions instead of having to define multiple structs. 104 AllowMissingFields bool 105 106 // SWbemServiceClient is an optional SWbemServices object that can be 107 // initialized and then reused across multiple queries. If it is null 108 // then the method will initialize a new temporary client each time. 109 SWbemServicesClient *SWbemServices 110 } 111 112 // DefaultClient is the default Client and is used by Query, QueryNamespace 113 var DefaultClient = &Client{} 114 115 // Query runs the WQL query and appends the values to dst. 116 // 117 // dst must have type *[]S or *[]*S, for some struct type S. Fields selected in 118 // the query must have the same name in dst. Supported types are all signed and 119 // unsigned integers, time.Time, string, bool, or a pointer to one of those. 120 // Array types are not supported. 121 // 122 // By default, the local machine and default namespace are used. These can be 123 // changed using connectServerArgs. See 124 // http://msdn.microsoft.com/en-us/library/aa393720.aspx for details. 125 func (c *Client) Query(query string, dst interface{}, connectServerArgs ...interface{}) error { 126 dv := reflect.ValueOf(dst) 127 if dv.Kind() != reflect.Ptr || dv.IsNil() { 128 return ErrInvalidEntityType 129 } 130 dv = dv.Elem() 131 mat, elemType := checkMultiArg(dv) 132 if mat == multiArgTypeInvalid { 133 return ErrInvalidEntityType 134 } 135 136 lock.Lock() 137 defer lock.Unlock() 138 runtime.LockOSThread() 139 defer runtime.UnlockOSThread() 140 141 err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED) 142 if err != nil { 143 oleCode := err.(*ole.OleError).Code() 144 if oleCode != ole.S_OK && oleCode != S_FALSE { 145 return err 146 } 147 } 148 defer ole.CoUninitialize() 149 150 unknown, err := oleutil.CreateObject("WbemScripting.SWbemLocator") 151 if err != nil { 152 return err 153 } else if unknown == nil { 154 return ErrNilCreateObject 155 } 156 defer unknown.Release() 157 158 wmi, err := unknown.QueryInterface(ole.IID_IDispatch) 159 if err != nil { 160 return err 161 } 162 defer wmi.Release() 163 164 // service is a SWbemServices 165 serviceRaw, err := oleutil.CallMethod(wmi, "ConnectServer", connectServerArgs...) 166 if err != nil { 167 return err 168 } 169 service := serviceRaw.ToIDispatch() 170 defer serviceRaw.Clear() 171 172 // result is a SWBemObjectSet 173 resultRaw, err := oleutil.CallMethod(service, "ExecQuery", query) 174 if err != nil { 175 return err 176 } 177 result := resultRaw.ToIDispatch() 178 defer resultRaw.Clear() 179 180 count, err := oleInt64(result, "Count") 181 if err != nil { 182 return err 183 } 184 185 enumProperty, err := result.GetProperty("_NewEnum") 186 if err != nil { 187 return err 188 } 189 defer enumProperty.Clear() 190 191 enum, err := enumProperty.ToIUnknown().IEnumVARIANT(ole.IID_IEnumVariant) 192 if err != nil { 193 return err 194 } 195 if enum == nil { 196 return fmt.Errorf("can't get IEnumVARIANT, enum is nil") 197 } 198 defer enum.Release() 199 200 // Initialize a slice with Count capacity 201 dv.Set(reflect.MakeSlice(dv.Type(), 0, int(count))) 202 203 var errFieldMismatch error 204 for itemRaw, length, err := enum.Next(1); length > 0; itemRaw, length, err = enum.Next(1) { 205 if err != nil { 206 return err 207 } 208 209 err := func() error { 210 // item is a SWbemObject, but really a Win32_Process 211 item := itemRaw.ToIDispatch() 212 defer item.Release() 213 214 ev := reflect.New(elemType) 215 if err = c.loadEntity(ev.Interface(), item); err != nil { 216 if _, ok := err.(*ErrFieldMismatch); ok { 217 // We continue loading entities even in the face of field mismatch errors. 218 // If we encounter any other error, that other error is returned. Otherwise, 219 // an ErrFieldMismatch is returned. 220 errFieldMismatch = err 221 } else { 222 return err 223 } 224 } 225 if mat != multiArgTypeStructPtr { 226 ev = ev.Elem() 227 } 228 dv.Set(reflect.Append(dv, ev)) 229 return nil 230 }() 231 if err != nil { 232 return err 233 } 234 } 235 return errFieldMismatch 236 } 237 238 // ErrFieldMismatch is returned when a field is to be loaded into a different 239 // type than the one it was stored from, or when a field is missing or 240 // unexported in the destination struct. 241 // StructType is the type of the struct pointed to by the destination argument. 242 type ErrFieldMismatch struct { 243 StructType reflect.Type 244 FieldName string 245 Reason string 246 } 247 248 func (e *ErrFieldMismatch) Error() string { 249 return fmt.Sprintf("wmi: cannot load field %q into a %q: %s", 250 e.FieldName, e.StructType, e.Reason) 251 } 252 253 var timeType = reflect.TypeOf(time.Time{}) 254 255 // loadEntity loads a SWbemObject into a struct pointer. 256 func (c *Client) loadEntity(dst interface{}, src *ole.IDispatch) (errFieldMismatch error) { 257 v := reflect.ValueOf(dst).Elem() 258 for i := 0; i < v.NumField(); i++ { 259 f := v.Field(i) 260 of := f 261 isPtr := f.Kind() == reflect.Ptr 262 if isPtr { 263 ptr := reflect.New(f.Type().Elem()) 264 f.Set(ptr) 265 f = f.Elem() 266 } 267 n := v.Type().Field(i).Name 268 if !f.CanSet() { 269 return &ErrFieldMismatch{ 270 StructType: of.Type(), 271 FieldName: n, 272 Reason: "CanSet() is false", 273 } 274 } 275 prop, err := oleutil.GetProperty(src, n) 276 if err != nil { 277 if !c.AllowMissingFields { 278 errFieldMismatch = &ErrFieldMismatch{ 279 StructType: of.Type(), 280 FieldName: n, 281 Reason: "no such struct field", 282 } 283 } 284 continue 285 } 286 defer prop.Clear() 287 288 if prop.VT == 0x1 { //VT_NULL 289 continue 290 } 291 292 switch val := prop.Value().(type) { 293 case int8, int16, int32, int64, int: 294 v := reflect.ValueOf(val).Int() 295 switch f.Kind() { 296 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 297 f.SetInt(v) 298 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 299 f.SetUint(uint64(v)) 300 default: 301 return &ErrFieldMismatch{ 302 StructType: of.Type(), 303 FieldName: n, 304 Reason: "not an integer class", 305 } 306 } 307 case uint8, uint16, uint32, uint64: 308 v := reflect.ValueOf(val).Uint() 309 switch f.Kind() { 310 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 311 f.SetInt(int64(v)) 312 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 313 f.SetUint(v) 314 default: 315 return &ErrFieldMismatch{ 316 StructType: of.Type(), 317 FieldName: n, 318 Reason: "not an integer class", 319 } 320 } 321 case string: 322 switch f.Kind() { 323 case reflect.String: 324 f.SetString(val) 325 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 326 iv, err := strconv.ParseInt(val, 10, 64) 327 if err != nil { 328 return err 329 } 330 f.SetInt(iv) 331 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 332 uv, err := strconv.ParseUint(val, 10, 64) 333 if err != nil { 334 return err 335 } 336 f.SetUint(uv) 337 case reflect.Struct: 338 switch f.Type() { 339 case timeType: 340 if len(val) == 25 { 341 mins, err := strconv.Atoi(val[22:]) 342 if err != nil { 343 return err 344 } 345 val = val[:22] + fmt.Sprintf("%02d%02d", mins/60, mins%60) 346 } 347 t, err := time.Parse("20060102150405.000000-0700", val) 348 if err != nil { 349 return err 350 } 351 f.Set(reflect.ValueOf(t)) 352 } 353 } 354 case bool: 355 switch f.Kind() { 356 case reflect.Bool: 357 f.SetBool(val) 358 default: 359 return &ErrFieldMismatch{ 360 StructType: of.Type(), 361 FieldName: n, 362 Reason: "not a bool", 363 } 364 } 365 case float32: 366 switch f.Kind() { 367 case reflect.Float32: 368 f.SetFloat(float64(val)) 369 default: 370 return &ErrFieldMismatch{ 371 StructType: of.Type(), 372 FieldName: n, 373 Reason: "not a Float32", 374 } 375 } 376 default: 377 if f.Kind() == reflect.Slice { 378 switch f.Type().Elem().Kind() { 379 case reflect.String: 380 safeArray := prop.ToArray() 381 if safeArray != nil { 382 arr := safeArray.ToValueArray() 383 fArr := reflect.MakeSlice(f.Type(), len(arr), len(arr)) 384 for i, v := range arr { 385 s := fArr.Index(i) 386 s.SetString(v.(string)) 387 } 388 f.Set(fArr) 389 } 390 case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint: 391 safeArray := prop.ToArray() 392 if safeArray != nil { 393 arr := safeArray.ToValueArray() 394 fArr := reflect.MakeSlice(f.Type(), len(arr), len(arr)) 395 for i, v := range arr { 396 s := fArr.Index(i) 397 s.SetUint(reflect.ValueOf(v).Uint()) 398 } 399 f.Set(fArr) 400 } 401 case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int: 402 safeArray := prop.ToArray() 403 if safeArray != nil { 404 arr := safeArray.ToValueArray() 405 fArr := reflect.MakeSlice(f.Type(), len(arr), len(arr)) 406 for i, v := range arr { 407 s := fArr.Index(i) 408 s.SetInt(reflect.ValueOf(v).Int()) 409 } 410 f.Set(fArr) 411 } 412 default: 413 return &ErrFieldMismatch{ 414 StructType: of.Type(), 415 FieldName: n, 416 Reason: fmt.Sprintf("unsupported slice type (%T)", val), 417 } 418 } 419 } else { 420 typeof := reflect.TypeOf(val) 421 if typeof == nil && (isPtr || c.NonePtrZero) { 422 if (isPtr && c.PtrNil) || (!isPtr && c.NonePtrZero) { 423 of.Set(reflect.Zero(of.Type())) 424 } 425 break 426 } 427 return &ErrFieldMismatch{ 428 StructType: of.Type(), 429 FieldName: n, 430 Reason: fmt.Sprintf("unsupported type (%T)", val), 431 } 432 } 433 } 434 } 435 return errFieldMismatch 436 } 437 438 type multiArgType int 439 440 const ( 441 multiArgTypeInvalid multiArgType = iota 442 multiArgTypeStruct 443 multiArgTypeStructPtr 444 ) 445 446 // checkMultiArg checks that v has type []S, []*S for some struct type S. 447 // 448 // It returns what category the slice's elements are, and the reflect.Type 449 // that represents S. 450 func checkMultiArg(v reflect.Value) (m multiArgType, elemType reflect.Type) { 451 if v.Kind() != reflect.Slice { 452 return multiArgTypeInvalid, nil 453 } 454 elemType = v.Type().Elem() 455 switch elemType.Kind() { 456 case reflect.Struct: 457 return multiArgTypeStruct, elemType 458 case reflect.Ptr: 459 elemType = elemType.Elem() 460 if elemType.Kind() == reflect.Struct { 461 return multiArgTypeStructPtr, elemType 462 } 463 } 464 return multiArgTypeInvalid, nil 465 } 466 467 func oleInt64(item *ole.IDispatch, prop string) (int64, error) { 468 v, err := oleutil.GetProperty(item, prop) 469 if err != nil { 470 return 0, err 471 } 472 defer v.Clear() 473 474 i := int64(v.Val) 475 return i, nil 476 } 477 478 // CreateQuery returns a WQL query string that queries all columns of src. where 479 // is an optional string that is appended to the query, to be used with WHERE 480 // clauses. In such a case, the "WHERE" string should appear at the beginning. 481 func CreateQuery(src interface{}, where string) string { 482 var b bytes.Buffer 483 b.WriteString("SELECT ") 484 s := reflect.Indirect(reflect.ValueOf(src)) 485 t := s.Type() 486 if s.Kind() == reflect.Slice { 487 t = t.Elem() 488 } 489 if t.Kind() != reflect.Struct { 490 return "" 491 } 492 var fields []string 493 for i := 0; i < t.NumField(); i++ { 494 fields = append(fields, t.Field(i).Name) 495 } 496 b.WriteString(strings.Join(fields, ", ")) 497 b.WriteString(" FROM ") 498 b.WriteString(t.Name()) 499 b.WriteString(" " + where) 500 return b.String() 501 }