golang.zx2c4.com/wireguard/windows@v0.5.4-0.20230123132234-dcc0eb72a04b/manager/ipc_client.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package manager
     7  
     8  import (
     9  	"encoding/gob"
    10  	"errors"
    11  	"os"
    12  	"sync"
    13  
    14  	"golang.zx2c4.com/wireguard/windows/conf"
    15  	"golang.zx2c4.com/wireguard/windows/updater"
    16  )
    17  
    18  type Tunnel struct {
    19  	Name string
    20  }
    21  
    22  type TunnelState int
    23  
    24  const (
    25  	TunnelUnknown TunnelState = iota
    26  	TunnelStarted
    27  	TunnelStopped
    28  	TunnelStarting
    29  	TunnelStopping
    30  )
    31  
    32  type NotificationType int
    33  
    34  const (
    35  	TunnelChangeNotificationType NotificationType = iota
    36  	TunnelsChangeNotificationType
    37  	ManagerStoppingNotificationType
    38  	UpdateFoundNotificationType
    39  	UpdateProgressNotificationType
    40  )
    41  
    42  type MethodType int
    43  
    44  const (
    45  	StoredConfigMethodType MethodType = iota
    46  	RuntimeConfigMethodType
    47  	StartMethodType
    48  	StopMethodType
    49  	WaitForStopMethodType
    50  	DeleteMethodType
    51  	StateMethodType
    52  	GlobalStateMethodType
    53  	CreateMethodType
    54  	TunnelsMethodType
    55  	QuitMethodType
    56  	UpdateStateMethodType
    57  	UpdateMethodType
    58  )
    59  
    60  var (
    61  	rpcEncoder *gob.Encoder
    62  	rpcDecoder *gob.Decoder
    63  	rpcMutex   sync.Mutex
    64  )
    65  
    66  type TunnelChangeCallback struct {
    67  	cb func(tunnel *Tunnel, state, globalState TunnelState, err error)
    68  }
    69  
    70  var tunnelChangeCallbacks = make(map[*TunnelChangeCallback]bool)
    71  
    72  type TunnelsChangeCallback struct {
    73  	cb func()
    74  }
    75  
    76  var tunnelsChangeCallbacks = make(map[*TunnelsChangeCallback]bool)
    77  
    78  type ManagerStoppingCallback struct {
    79  	cb func()
    80  }
    81  
    82  var managerStoppingCallbacks = make(map[*ManagerStoppingCallback]bool)
    83  
    84  type UpdateFoundCallback struct {
    85  	cb func(updateState UpdateState)
    86  }
    87  
    88  var updateFoundCallbacks = make(map[*UpdateFoundCallback]bool)
    89  
    90  type UpdateProgressCallback struct {
    91  	cb func(dp updater.DownloadProgress)
    92  }
    93  
    94  var updateProgressCallbacks = make(map[*UpdateProgressCallback]bool)
    95  
    96  func InitializeIPCClient(reader, writer, events *os.File) {
    97  	rpcDecoder = gob.NewDecoder(reader)
    98  	rpcEncoder = gob.NewEncoder(writer)
    99  	go func() {
   100  		decoder := gob.NewDecoder(events)
   101  		for {
   102  			var notificationType NotificationType
   103  			err := decoder.Decode(&notificationType)
   104  			if err != nil {
   105  				return
   106  			}
   107  			switch notificationType {
   108  			case TunnelChangeNotificationType:
   109  				var tunnel string
   110  				err := decoder.Decode(&tunnel)
   111  				if err != nil || len(tunnel) == 0 {
   112  					continue
   113  				}
   114  				var state TunnelState
   115  				err = decoder.Decode(&state)
   116  				if err != nil {
   117  					continue
   118  				}
   119  				var globalState TunnelState
   120  				err = decoder.Decode(&globalState)
   121  				if err != nil {
   122  					continue
   123  				}
   124  				var errStr string
   125  				err = decoder.Decode(&errStr)
   126  				if err != nil {
   127  					continue
   128  				}
   129  				var retErr error
   130  				if len(errStr) > 0 {
   131  					retErr = errors.New(errStr)
   132  				}
   133  				if state == TunnelUnknown {
   134  					continue
   135  				}
   136  				t := &Tunnel{tunnel}
   137  				for cb := range tunnelChangeCallbacks {
   138  					cb.cb(t, state, globalState, retErr)
   139  				}
   140  			case TunnelsChangeNotificationType:
   141  				for cb := range tunnelsChangeCallbacks {
   142  					cb.cb()
   143  				}
   144  			case ManagerStoppingNotificationType:
   145  				for cb := range managerStoppingCallbacks {
   146  					cb.cb()
   147  				}
   148  			case UpdateFoundNotificationType:
   149  				var state UpdateState
   150  				err = decoder.Decode(&state)
   151  				if err != nil {
   152  					continue
   153  				}
   154  				for cb := range updateFoundCallbacks {
   155  					cb.cb(state)
   156  				}
   157  			case UpdateProgressNotificationType:
   158  				var dp updater.DownloadProgress
   159  				err = decoder.Decode(&dp.Activity)
   160  				if err != nil {
   161  					continue
   162  				}
   163  				err = decoder.Decode(&dp.BytesDownloaded)
   164  				if err != nil {
   165  					continue
   166  				}
   167  				err = decoder.Decode(&dp.BytesTotal)
   168  				if err != nil {
   169  					continue
   170  				}
   171  				var errStr string
   172  				err = decoder.Decode(&errStr)
   173  				if err != nil {
   174  					continue
   175  				}
   176  				if len(errStr) > 0 {
   177  					dp.Error = errors.New(errStr)
   178  				}
   179  				err = decoder.Decode(&dp.Complete)
   180  				if err != nil {
   181  					continue
   182  				}
   183  				for cb := range updateProgressCallbacks {
   184  					cb.cb(dp)
   185  				}
   186  			}
   187  		}
   188  	}()
   189  }
   190  
   191  func rpcDecodeError() error {
   192  	var str string
   193  	err := rpcDecoder.Decode(&str)
   194  	if err != nil {
   195  		return err
   196  	}
   197  	if len(str) == 0 {
   198  		return nil
   199  	}
   200  	return errors.New(str)
   201  }
   202  
   203  func (t *Tunnel) StoredConfig() (c conf.Config, err error) {
   204  	rpcMutex.Lock()
   205  	defer rpcMutex.Unlock()
   206  
   207  	err = rpcEncoder.Encode(StoredConfigMethodType)
   208  	if err != nil {
   209  		return
   210  	}
   211  	err = rpcEncoder.Encode(t.Name)
   212  	if err != nil {
   213  		return
   214  	}
   215  	err = rpcDecoder.Decode(&c)
   216  	if err != nil {
   217  		return
   218  	}
   219  	err = rpcDecodeError()
   220  	return
   221  }
   222  
   223  func (t *Tunnel) RuntimeConfig() (c conf.Config, err error) {
   224  	rpcMutex.Lock()
   225  	defer rpcMutex.Unlock()
   226  
   227  	err = rpcEncoder.Encode(RuntimeConfigMethodType)
   228  	if err != nil {
   229  		return
   230  	}
   231  	err = rpcEncoder.Encode(t.Name)
   232  	if err != nil {
   233  		return
   234  	}
   235  	err = rpcDecoder.Decode(&c)
   236  	if err != nil {
   237  		return
   238  	}
   239  	err = rpcDecodeError()
   240  	return
   241  }
   242  
   243  func (t *Tunnel) Start() (err error) {
   244  	rpcMutex.Lock()
   245  	defer rpcMutex.Unlock()
   246  
   247  	err = rpcEncoder.Encode(StartMethodType)
   248  	if err != nil {
   249  		return
   250  	}
   251  	err = rpcEncoder.Encode(t.Name)
   252  	if err != nil {
   253  		return
   254  	}
   255  	err = rpcDecodeError()
   256  	return
   257  }
   258  
   259  func (t *Tunnel) Stop() (err error) {
   260  	rpcMutex.Lock()
   261  	defer rpcMutex.Unlock()
   262  
   263  	err = rpcEncoder.Encode(StopMethodType)
   264  	if err != nil {
   265  		return
   266  	}
   267  	err = rpcEncoder.Encode(t.Name)
   268  	if err != nil {
   269  		return
   270  	}
   271  	err = rpcDecodeError()
   272  	return
   273  }
   274  
   275  func (t *Tunnel) Toggle() (oldState TunnelState, err error) {
   276  	oldState, err = t.State()
   277  	if err != nil {
   278  		oldState = TunnelUnknown
   279  		return
   280  	}
   281  	if oldState == TunnelStarted {
   282  		err = t.Stop()
   283  	} else if oldState == TunnelStopped {
   284  		err = t.Start()
   285  	}
   286  	return
   287  }
   288  
   289  func (t *Tunnel) WaitForStop() (err error) {
   290  	rpcMutex.Lock()
   291  	defer rpcMutex.Unlock()
   292  
   293  	err = rpcEncoder.Encode(WaitForStopMethodType)
   294  	if err != nil {
   295  		return
   296  	}
   297  	err = rpcEncoder.Encode(t.Name)
   298  	if err != nil {
   299  		return
   300  	}
   301  	err = rpcDecodeError()
   302  	return
   303  }
   304  
   305  func (t *Tunnel) Delete() (err error) {
   306  	rpcMutex.Lock()
   307  	defer rpcMutex.Unlock()
   308  
   309  	err = rpcEncoder.Encode(DeleteMethodType)
   310  	if err != nil {
   311  		return
   312  	}
   313  	err = rpcEncoder.Encode(t.Name)
   314  	if err != nil {
   315  		return
   316  	}
   317  	err = rpcDecodeError()
   318  	return
   319  }
   320  
   321  func (t *Tunnel) State() (tunnelState TunnelState, err error) {
   322  	rpcMutex.Lock()
   323  	defer rpcMutex.Unlock()
   324  
   325  	err = rpcEncoder.Encode(StateMethodType)
   326  	if err != nil {
   327  		return
   328  	}
   329  	err = rpcEncoder.Encode(t.Name)
   330  	if err != nil {
   331  		return
   332  	}
   333  	err = rpcDecoder.Decode(&tunnelState)
   334  	if err != nil {
   335  		return
   336  	}
   337  	err = rpcDecodeError()
   338  	return
   339  }
   340  
   341  func IPCClientGlobalState() (tunnelState TunnelState, err error) {
   342  	rpcMutex.Lock()
   343  	defer rpcMutex.Unlock()
   344  
   345  	err = rpcEncoder.Encode(GlobalStateMethodType)
   346  	if err != nil {
   347  		return
   348  	}
   349  	err = rpcDecoder.Decode(&tunnelState)
   350  	if err != nil {
   351  		return
   352  	}
   353  	return
   354  }
   355  
   356  func IPCClientNewTunnel(conf *conf.Config) (tunnel Tunnel, err error) {
   357  	rpcMutex.Lock()
   358  	defer rpcMutex.Unlock()
   359  
   360  	err = rpcEncoder.Encode(CreateMethodType)
   361  	if err != nil {
   362  		return
   363  	}
   364  	err = rpcEncoder.Encode(*conf)
   365  	if err != nil {
   366  		return
   367  	}
   368  	err = rpcDecoder.Decode(&tunnel)
   369  	if err != nil {
   370  		return
   371  	}
   372  	err = rpcDecodeError()
   373  	return
   374  }
   375  
   376  func IPCClientTunnels() (tunnels []Tunnel, err error) {
   377  	rpcMutex.Lock()
   378  	defer rpcMutex.Unlock()
   379  
   380  	err = rpcEncoder.Encode(TunnelsMethodType)
   381  	if err != nil {
   382  		return
   383  	}
   384  	err = rpcDecoder.Decode(&tunnels)
   385  	if err != nil {
   386  		return
   387  	}
   388  	err = rpcDecodeError()
   389  	return
   390  }
   391  
   392  func IPCClientQuit(stopTunnelsOnQuit bool) (alreadyQuit bool, err error) {
   393  	rpcMutex.Lock()
   394  	defer rpcMutex.Unlock()
   395  
   396  	err = rpcEncoder.Encode(QuitMethodType)
   397  	if err != nil {
   398  		return
   399  	}
   400  	err = rpcEncoder.Encode(stopTunnelsOnQuit)
   401  	if err != nil {
   402  		return
   403  	}
   404  	err = rpcDecoder.Decode(&alreadyQuit)
   405  	if err != nil {
   406  		return
   407  	}
   408  	err = rpcDecodeError()
   409  	return
   410  }
   411  
   412  func IPCClientUpdateState() (updateState UpdateState, err error) {
   413  	rpcMutex.Lock()
   414  	defer rpcMutex.Unlock()
   415  
   416  	err = rpcEncoder.Encode(UpdateStateMethodType)
   417  	if err != nil {
   418  		return
   419  	}
   420  	err = rpcDecoder.Decode(&updateState)
   421  	if err != nil {
   422  		return
   423  	}
   424  	return
   425  }
   426  
   427  func IPCClientUpdate() error {
   428  	rpcMutex.Lock()
   429  	defer rpcMutex.Unlock()
   430  
   431  	return rpcEncoder.Encode(UpdateMethodType)
   432  }
   433  
   434  func IPCClientRegisterTunnelChange(cb func(tunnel *Tunnel, state, globalState TunnelState, err error)) *TunnelChangeCallback {
   435  	s := &TunnelChangeCallback{cb}
   436  	tunnelChangeCallbacks[s] = true
   437  	return s
   438  }
   439  
   440  func (cb *TunnelChangeCallback) Unregister() {
   441  	delete(tunnelChangeCallbacks, cb)
   442  }
   443  
   444  func IPCClientRegisterTunnelsChange(cb func()) *TunnelsChangeCallback {
   445  	s := &TunnelsChangeCallback{cb}
   446  	tunnelsChangeCallbacks[s] = true
   447  	return s
   448  }
   449  
   450  func (cb *TunnelsChangeCallback) Unregister() {
   451  	delete(tunnelsChangeCallbacks, cb)
   452  }
   453  
   454  func IPCClientRegisterManagerStopping(cb func()) *ManagerStoppingCallback {
   455  	s := &ManagerStoppingCallback{cb}
   456  	managerStoppingCallbacks[s] = true
   457  	return s
   458  }
   459  
   460  func (cb *ManagerStoppingCallback) Unregister() {
   461  	delete(managerStoppingCallbacks, cb)
   462  }
   463  
   464  func IPCClientRegisterUpdateFound(cb func(updateState UpdateState)) *UpdateFoundCallback {
   465  	s := &UpdateFoundCallback{cb}
   466  	updateFoundCallbacks[s] = true
   467  	return s
   468  }
   469  
   470  func (cb *UpdateFoundCallback) Unregister() {
   471  	delete(updateFoundCallbacks, cb)
   472  }
   473  
   474  func IPCClientRegisterUpdateProgress(cb func(dp updater.DownloadProgress)) *UpdateProgressCallback {
   475  	s := &UpdateProgressCallback{cb}
   476  	updateProgressCallbacks[s] = true
   477  	return s
   478  }
   479  
   480  func (cb *UpdateProgressCallback) Unregister() {
   481  	delete(updateProgressCallbacks, cb)
   482  }