golang.zx2c4.com/wireguard/windows@v0.5.4-0.20230123132234-dcc0eb72a04b/manager/ipc_client.go (about) 1 /* SPDX-License-Identifier: MIT 2 * 3 * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved. 4 */ 5 6 package manager 7 8 import ( 9 "encoding/gob" 10 "errors" 11 "os" 12 "sync" 13 14 "golang.zx2c4.com/wireguard/windows/conf" 15 "golang.zx2c4.com/wireguard/windows/updater" 16 ) 17 18 type Tunnel struct { 19 Name string 20 } 21 22 type TunnelState int 23 24 const ( 25 TunnelUnknown TunnelState = iota 26 TunnelStarted 27 TunnelStopped 28 TunnelStarting 29 TunnelStopping 30 ) 31 32 type NotificationType int 33 34 const ( 35 TunnelChangeNotificationType NotificationType = iota 36 TunnelsChangeNotificationType 37 ManagerStoppingNotificationType 38 UpdateFoundNotificationType 39 UpdateProgressNotificationType 40 ) 41 42 type MethodType int 43 44 const ( 45 StoredConfigMethodType MethodType = iota 46 RuntimeConfigMethodType 47 StartMethodType 48 StopMethodType 49 WaitForStopMethodType 50 DeleteMethodType 51 StateMethodType 52 GlobalStateMethodType 53 CreateMethodType 54 TunnelsMethodType 55 QuitMethodType 56 UpdateStateMethodType 57 UpdateMethodType 58 ) 59 60 var ( 61 rpcEncoder *gob.Encoder 62 rpcDecoder *gob.Decoder 63 rpcMutex sync.Mutex 64 ) 65 66 type TunnelChangeCallback struct { 67 cb func(tunnel *Tunnel, state, globalState TunnelState, err error) 68 } 69 70 var tunnelChangeCallbacks = make(map[*TunnelChangeCallback]bool) 71 72 type TunnelsChangeCallback struct { 73 cb func() 74 } 75 76 var tunnelsChangeCallbacks = make(map[*TunnelsChangeCallback]bool) 77 78 type ManagerStoppingCallback struct { 79 cb func() 80 } 81 82 var managerStoppingCallbacks = make(map[*ManagerStoppingCallback]bool) 83 84 type UpdateFoundCallback struct { 85 cb func(updateState UpdateState) 86 } 87 88 var updateFoundCallbacks = make(map[*UpdateFoundCallback]bool) 89 90 type UpdateProgressCallback struct { 91 cb func(dp updater.DownloadProgress) 92 } 93 94 var updateProgressCallbacks = make(map[*UpdateProgressCallback]bool) 95 96 func InitializeIPCClient(reader, writer, events *os.File) { 97 rpcDecoder = gob.NewDecoder(reader) 98 rpcEncoder = gob.NewEncoder(writer) 99 go func() { 100 decoder := gob.NewDecoder(events) 101 for { 102 var notificationType NotificationType 103 err := decoder.Decode(¬ificationType) 104 if err != nil { 105 return 106 } 107 switch notificationType { 108 case TunnelChangeNotificationType: 109 var tunnel string 110 err := decoder.Decode(&tunnel) 111 if err != nil || len(tunnel) == 0 { 112 continue 113 } 114 var state TunnelState 115 err = decoder.Decode(&state) 116 if err != nil { 117 continue 118 } 119 var globalState TunnelState 120 err = decoder.Decode(&globalState) 121 if err != nil { 122 continue 123 } 124 var errStr string 125 err = decoder.Decode(&errStr) 126 if err != nil { 127 continue 128 } 129 var retErr error 130 if len(errStr) > 0 { 131 retErr = errors.New(errStr) 132 } 133 if state == TunnelUnknown { 134 continue 135 } 136 t := &Tunnel{tunnel} 137 for cb := range tunnelChangeCallbacks { 138 cb.cb(t, state, globalState, retErr) 139 } 140 case TunnelsChangeNotificationType: 141 for cb := range tunnelsChangeCallbacks { 142 cb.cb() 143 } 144 case ManagerStoppingNotificationType: 145 for cb := range managerStoppingCallbacks { 146 cb.cb() 147 } 148 case UpdateFoundNotificationType: 149 var state UpdateState 150 err = decoder.Decode(&state) 151 if err != nil { 152 continue 153 } 154 for cb := range updateFoundCallbacks { 155 cb.cb(state) 156 } 157 case UpdateProgressNotificationType: 158 var dp updater.DownloadProgress 159 err = decoder.Decode(&dp.Activity) 160 if err != nil { 161 continue 162 } 163 err = decoder.Decode(&dp.BytesDownloaded) 164 if err != nil { 165 continue 166 } 167 err = decoder.Decode(&dp.BytesTotal) 168 if err != nil { 169 continue 170 } 171 var errStr string 172 err = decoder.Decode(&errStr) 173 if err != nil { 174 continue 175 } 176 if len(errStr) > 0 { 177 dp.Error = errors.New(errStr) 178 } 179 err = decoder.Decode(&dp.Complete) 180 if err != nil { 181 continue 182 } 183 for cb := range updateProgressCallbacks { 184 cb.cb(dp) 185 } 186 } 187 } 188 }() 189 } 190 191 func rpcDecodeError() error { 192 var str string 193 err := rpcDecoder.Decode(&str) 194 if err != nil { 195 return err 196 } 197 if len(str) == 0 { 198 return nil 199 } 200 return errors.New(str) 201 } 202 203 func (t *Tunnel) StoredConfig() (c conf.Config, err error) { 204 rpcMutex.Lock() 205 defer rpcMutex.Unlock() 206 207 err = rpcEncoder.Encode(StoredConfigMethodType) 208 if err != nil { 209 return 210 } 211 err = rpcEncoder.Encode(t.Name) 212 if err != nil { 213 return 214 } 215 err = rpcDecoder.Decode(&c) 216 if err != nil { 217 return 218 } 219 err = rpcDecodeError() 220 return 221 } 222 223 func (t *Tunnel) RuntimeConfig() (c conf.Config, err error) { 224 rpcMutex.Lock() 225 defer rpcMutex.Unlock() 226 227 err = rpcEncoder.Encode(RuntimeConfigMethodType) 228 if err != nil { 229 return 230 } 231 err = rpcEncoder.Encode(t.Name) 232 if err != nil { 233 return 234 } 235 err = rpcDecoder.Decode(&c) 236 if err != nil { 237 return 238 } 239 err = rpcDecodeError() 240 return 241 } 242 243 func (t *Tunnel) Start() (err error) { 244 rpcMutex.Lock() 245 defer rpcMutex.Unlock() 246 247 err = rpcEncoder.Encode(StartMethodType) 248 if err != nil { 249 return 250 } 251 err = rpcEncoder.Encode(t.Name) 252 if err != nil { 253 return 254 } 255 err = rpcDecodeError() 256 return 257 } 258 259 func (t *Tunnel) Stop() (err error) { 260 rpcMutex.Lock() 261 defer rpcMutex.Unlock() 262 263 err = rpcEncoder.Encode(StopMethodType) 264 if err != nil { 265 return 266 } 267 err = rpcEncoder.Encode(t.Name) 268 if err != nil { 269 return 270 } 271 err = rpcDecodeError() 272 return 273 } 274 275 func (t *Tunnel) Toggle() (oldState TunnelState, err error) { 276 oldState, err = t.State() 277 if err != nil { 278 oldState = TunnelUnknown 279 return 280 } 281 if oldState == TunnelStarted { 282 err = t.Stop() 283 } else if oldState == TunnelStopped { 284 err = t.Start() 285 } 286 return 287 } 288 289 func (t *Tunnel) WaitForStop() (err error) { 290 rpcMutex.Lock() 291 defer rpcMutex.Unlock() 292 293 err = rpcEncoder.Encode(WaitForStopMethodType) 294 if err != nil { 295 return 296 } 297 err = rpcEncoder.Encode(t.Name) 298 if err != nil { 299 return 300 } 301 err = rpcDecodeError() 302 return 303 } 304 305 func (t *Tunnel) Delete() (err error) { 306 rpcMutex.Lock() 307 defer rpcMutex.Unlock() 308 309 err = rpcEncoder.Encode(DeleteMethodType) 310 if err != nil { 311 return 312 } 313 err = rpcEncoder.Encode(t.Name) 314 if err != nil { 315 return 316 } 317 err = rpcDecodeError() 318 return 319 } 320 321 func (t *Tunnel) State() (tunnelState TunnelState, err error) { 322 rpcMutex.Lock() 323 defer rpcMutex.Unlock() 324 325 err = rpcEncoder.Encode(StateMethodType) 326 if err != nil { 327 return 328 } 329 err = rpcEncoder.Encode(t.Name) 330 if err != nil { 331 return 332 } 333 err = rpcDecoder.Decode(&tunnelState) 334 if err != nil { 335 return 336 } 337 err = rpcDecodeError() 338 return 339 } 340 341 func IPCClientGlobalState() (tunnelState TunnelState, err error) { 342 rpcMutex.Lock() 343 defer rpcMutex.Unlock() 344 345 err = rpcEncoder.Encode(GlobalStateMethodType) 346 if err != nil { 347 return 348 } 349 err = rpcDecoder.Decode(&tunnelState) 350 if err != nil { 351 return 352 } 353 return 354 } 355 356 func IPCClientNewTunnel(conf *conf.Config) (tunnel Tunnel, err error) { 357 rpcMutex.Lock() 358 defer rpcMutex.Unlock() 359 360 err = rpcEncoder.Encode(CreateMethodType) 361 if err != nil { 362 return 363 } 364 err = rpcEncoder.Encode(*conf) 365 if err != nil { 366 return 367 } 368 err = rpcDecoder.Decode(&tunnel) 369 if err != nil { 370 return 371 } 372 err = rpcDecodeError() 373 return 374 } 375 376 func IPCClientTunnels() (tunnels []Tunnel, err error) { 377 rpcMutex.Lock() 378 defer rpcMutex.Unlock() 379 380 err = rpcEncoder.Encode(TunnelsMethodType) 381 if err != nil { 382 return 383 } 384 err = rpcDecoder.Decode(&tunnels) 385 if err != nil { 386 return 387 } 388 err = rpcDecodeError() 389 return 390 } 391 392 func IPCClientQuit(stopTunnelsOnQuit bool) (alreadyQuit bool, err error) { 393 rpcMutex.Lock() 394 defer rpcMutex.Unlock() 395 396 err = rpcEncoder.Encode(QuitMethodType) 397 if err != nil { 398 return 399 } 400 err = rpcEncoder.Encode(stopTunnelsOnQuit) 401 if err != nil { 402 return 403 } 404 err = rpcDecoder.Decode(&alreadyQuit) 405 if err != nil { 406 return 407 } 408 err = rpcDecodeError() 409 return 410 } 411 412 func IPCClientUpdateState() (updateState UpdateState, err error) { 413 rpcMutex.Lock() 414 defer rpcMutex.Unlock() 415 416 err = rpcEncoder.Encode(UpdateStateMethodType) 417 if err != nil { 418 return 419 } 420 err = rpcDecoder.Decode(&updateState) 421 if err != nil { 422 return 423 } 424 return 425 } 426 427 func IPCClientUpdate() error { 428 rpcMutex.Lock() 429 defer rpcMutex.Unlock() 430 431 return rpcEncoder.Encode(UpdateMethodType) 432 } 433 434 func IPCClientRegisterTunnelChange(cb func(tunnel *Tunnel, state, globalState TunnelState, err error)) *TunnelChangeCallback { 435 s := &TunnelChangeCallback{cb} 436 tunnelChangeCallbacks[s] = true 437 return s 438 } 439 440 func (cb *TunnelChangeCallback) Unregister() { 441 delete(tunnelChangeCallbacks, cb) 442 } 443 444 func IPCClientRegisterTunnelsChange(cb func()) *TunnelsChangeCallback { 445 s := &TunnelsChangeCallback{cb} 446 tunnelsChangeCallbacks[s] = true 447 return s 448 } 449 450 func (cb *TunnelsChangeCallback) Unregister() { 451 delete(tunnelsChangeCallbacks, cb) 452 } 453 454 func IPCClientRegisterManagerStopping(cb func()) *ManagerStoppingCallback { 455 s := &ManagerStoppingCallback{cb} 456 managerStoppingCallbacks[s] = true 457 return s 458 } 459 460 func (cb *ManagerStoppingCallback) Unregister() { 461 delete(managerStoppingCallbacks, cb) 462 } 463 464 func IPCClientRegisterUpdateFound(cb func(updateState UpdateState)) *UpdateFoundCallback { 465 s := &UpdateFoundCallback{cb} 466 updateFoundCallbacks[s] = true 467 return s 468 } 469 470 func (cb *UpdateFoundCallback) Unregister() { 471 delete(updateFoundCallbacks, cb) 472 } 473 474 func IPCClientRegisterUpdateProgress(cb func(dp updater.DownloadProgress)) *UpdateProgressCallback { 475 s := &UpdateProgressCallback{cb} 476 updateProgressCallbacks[s] = true 477 return s 478 } 479 480 func (cb *UpdateProgressCallback) Unregister() { 481 delete(updateProgressCallbacks, cb) 482 }