golang.zx2c4.com/wireguard/windows@v0.5.4-0.20230123132234-dcc0eb72a04b/ui/listview.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package ui
     7  
     8  import (
     9  	"sort"
    10  	"sync/atomic"
    11  
    12  	"github.com/lxn/win"
    13  
    14  	"golang.zx2c4.com/wireguard/windows/conf"
    15  	"golang.zx2c4.com/wireguard/windows/manager"
    16  
    17  	"github.com/lxn/walk"
    18  )
    19  
    20  // ListModel is a struct to store the currently known tunnels to the GUI, suitable as a model for a walk.TableView.
    21  type ListModel struct {
    22  	walk.TableModelBase
    23  	walk.SorterBase
    24  
    25  	tunnels           []manager.Tunnel
    26  	lastObservedState map[manager.Tunnel]manager.TunnelState
    27  }
    28  
    29  var cachedListViewIconsForWidthAndState = make(map[widthAndState]*walk.Bitmap)
    30  
    31  func (t *ListModel) RowCount() int {
    32  	return len(t.tunnels)
    33  }
    34  
    35  func (t *ListModel) Value(row, col int) any {
    36  	if col != 0 || row < 0 || row >= len(t.tunnels) {
    37  		return ""
    38  	}
    39  	return t.tunnels[row].Name
    40  }
    41  
    42  func (t *ListModel) Sort(col int, order walk.SortOrder) error {
    43  	sort.SliceStable(t.tunnels, func(i, j int) bool {
    44  		return conf.TunnelNameIsLess(t.tunnels[i].Name, t.tunnels[j].Name)
    45  	})
    46  
    47  	return t.SorterBase.Sort(col, order)
    48  }
    49  
    50  type ListView struct {
    51  	*walk.TableView
    52  
    53  	model *ListModel
    54  
    55  	tunnelChangedCB        *manager.TunnelChangeCallback
    56  	tunnelsChangedCB       *manager.TunnelsChangeCallback
    57  	tunnelsUpdateSuspended int32
    58  }
    59  
    60  func NewListView(parent walk.Container) (*ListView, error) {
    61  	var disposables walk.Disposables
    62  	defer disposables.Treat()
    63  
    64  	tv, err := walk.NewTableView(parent)
    65  	if err != nil {
    66  		return nil, err
    67  	}
    68  	disposables.Add(tv)
    69  
    70  	tv.SetDoubleBuffering(true)
    71  
    72  	model := new(ListModel)
    73  	model.lastObservedState = make(map[manager.Tunnel]manager.TunnelState)
    74  	tv.SetModel(model)
    75  	tv.SetLastColumnStretched(true)
    76  	tv.SetHeaderHidden(true)
    77  	tv.SetIgnoreNowhere(true)
    78  	tv.SetScrollbarOrientation(walk.Vertical)
    79  	tv.Columns().Add(walk.NewTableViewColumn())
    80  
    81  	tunnelsView := &ListView{
    82  		TableView: tv,
    83  		model:     model,
    84  	}
    85  	tv.SetCellStyler(tunnelsView)
    86  
    87  	disposables.Spare()
    88  
    89  	tunnelsView.tunnelChangedCB = manager.IPCClientRegisterTunnelChange(tunnelsView.onTunnelChange)
    90  	tunnelsView.tunnelsChangedCB = manager.IPCClientRegisterTunnelsChange(tunnelsView.onTunnelsChange)
    91  
    92  	return tunnelsView, nil
    93  }
    94  
    95  func (tv *ListView) Dispose() {
    96  	if tv.tunnelChangedCB != nil {
    97  		tv.tunnelChangedCB.Unregister()
    98  		tv.tunnelChangedCB = nil
    99  	}
   100  	if tv.tunnelsChangedCB != nil {
   101  		tv.tunnelsChangedCB.Unregister()
   102  		tv.tunnelsChangedCB = nil
   103  	}
   104  	tv.TableView.Dispose()
   105  }
   106  
   107  func (tv *ListView) CurrentTunnel() *manager.Tunnel {
   108  	idx := tv.CurrentIndex()
   109  	if idx == -1 {
   110  		return nil
   111  	}
   112  
   113  	return &tv.model.tunnels[idx]
   114  }
   115  
   116  var dummyBitmap *walk.Bitmap
   117  
   118  func (tv *ListView) StyleCell(style *walk.CellStyle) {
   119  	row := style.Row()
   120  	if row < 0 || row >= len(tv.model.tunnels) {
   121  		return
   122  	}
   123  	tunnel := &tv.model.tunnels[row]
   124  
   125  	var state manager.TunnelState
   126  	var ok bool
   127  	state, ok = tv.model.lastObservedState[tv.model.tunnels[row]]
   128  	if !ok {
   129  		var err error
   130  		state, err = tunnel.State()
   131  		if err != nil {
   132  			return
   133  		}
   134  		tv.model.lastObservedState[tv.model.tunnels[row]] = state
   135  	}
   136  
   137  	icon, err := iconForState(state, 14)
   138  	if err != nil {
   139  		return
   140  	}
   141  	margin := tv.IntFrom96DPI(1)
   142  	bitmapWidth := tv.IntFrom96DPI(16)
   143  
   144  	if win.IsAppThemed() {
   145  		cacheKey := widthAndState{bitmapWidth, state}
   146  		if cacheValue, ok := cachedListViewIconsForWidthAndState[cacheKey]; ok {
   147  			style.Image = cacheValue
   148  			return
   149  		}
   150  		bitmap, err := walk.NewBitmapWithTransparentPixelsForDPI(walk.Size{bitmapWidth, bitmapWidth}, tv.DPI())
   151  		if err != nil {
   152  			return
   153  		}
   154  		canvas, err := walk.NewCanvasFromImage(bitmap)
   155  		if err != nil {
   156  			return
   157  		}
   158  		bounds := walk.Rectangle{X: margin, Y: margin, Height: bitmapWidth - 2*margin, Width: bitmapWidth - 2*margin}
   159  		err = canvas.DrawImageStretchedPixels(icon, bounds)
   160  		canvas.Dispose()
   161  		if err != nil {
   162  			return
   163  		}
   164  		cachedListViewIconsForWidthAndState[cacheKey] = bitmap
   165  		style.Image = bitmap
   166  	} else {
   167  		if dummyBitmap == nil {
   168  			dummyBitmap, _ = walk.NewBitmapForDPI(tv.SizeFrom96DPI(walk.Size{}), 96)
   169  		}
   170  		style.Image = dummyBitmap
   171  		canvas := style.Canvas()
   172  		if canvas == nil {
   173  			return
   174  		}
   175  		bounds := style.BoundsPixels()
   176  		bounds.Width = bitmapWidth - 2*margin
   177  		bounds.X = (bounds.Height - bounds.Width) / 2
   178  		bounds.Height = bounds.Width
   179  		bounds.Y += bounds.X
   180  		canvas.DrawImageStretchedPixels(icon, bounds)
   181  	}
   182  }
   183  
   184  func (tv *ListView) onTunnelChange(tunnel *manager.Tunnel, state, globalState manager.TunnelState, err error) {
   185  	tv.Synchronize(func() {
   186  		idx := -1
   187  		for i := range tv.model.tunnels {
   188  			if tv.model.tunnels[i].Name == tunnel.Name {
   189  				idx = i
   190  				break
   191  			}
   192  		}
   193  
   194  		if idx != -1 {
   195  			tv.model.lastObservedState[tv.model.tunnels[idx]] = state
   196  			tv.model.PublishRowChanged(idx)
   197  			return
   198  		}
   199  	})
   200  }
   201  
   202  func (tv *ListView) onTunnelsChange() {
   203  	if atomic.LoadInt32(&tv.tunnelsUpdateSuspended) == 0 {
   204  		tv.Load(true)
   205  	}
   206  }
   207  
   208  func (tv *ListView) SetSuspendTunnelsUpdate(suspend bool) {
   209  	if suspend {
   210  		atomic.AddInt32(&tv.tunnelsUpdateSuspended, 1)
   211  	} else {
   212  		atomic.AddInt32(&tv.tunnelsUpdateSuspended, -1)
   213  	}
   214  	tv.Load(true)
   215  }
   216  
   217  func (tv *ListView) Load(asyncUI bool) {
   218  	tunnels, err := manager.IPCClientTunnels()
   219  	if err != nil {
   220  		return
   221  	}
   222  	doUI := func() {
   223  		newTunnels := make(map[manager.Tunnel]bool, len(tunnels))
   224  		oldTunnels := make(map[manager.Tunnel]bool, len(tv.model.tunnels))
   225  		for _, tunnel := range tunnels {
   226  			newTunnels[tunnel] = true
   227  		}
   228  		for i := len(tv.model.tunnels); i > 0; {
   229  			i--
   230  			tunnel := tv.model.tunnels[i]
   231  			oldTunnels[tunnel] = true
   232  			if !newTunnels[tunnel] {
   233  				tv.model.tunnels = append(tv.model.tunnels[:i], tv.model.tunnels[i+1:]...)
   234  				tv.model.PublishRowsRemoved(i, i) // TODO: Do we have to call that everytime or can we pass a range?
   235  				delete(tv.model.lastObservedState, tunnel)
   236  			}
   237  		}
   238  		didAdd := false
   239  		firstTunnelName := ""
   240  		for tunnel := range newTunnels {
   241  			if !oldTunnels[tunnel] {
   242  				if len(firstTunnelName) == 0 || !conf.TunnelNameIsLess(firstTunnelName, tunnel.Name) {
   243  					firstTunnelName = tunnel.Name
   244  				}
   245  				tv.model.tunnels = append(tv.model.tunnels, tunnel)
   246  				didAdd = true
   247  			}
   248  		}
   249  		if didAdd {
   250  			tv.model.PublishRowsReset()
   251  			tv.model.Sort(tv.model.SortedColumn(), tv.model.SortOrder())
   252  			if len(tv.SelectedIndexes()) == 0 {
   253  				tv.selectTunnel(firstTunnelName)
   254  			}
   255  		}
   256  	}
   257  	if asyncUI {
   258  		tv.Synchronize(doUI)
   259  	} else {
   260  		doUI()
   261  	}
   262  }
   263  
   264  func (tv *ListView) selectTunnel(tunnelName string) {
   265  	for i, tunnel := range tv.model.tunnels {
   266  		if tunnel.Name == tunnelName {
   267  			tv.SetCurrentIndex(i)
   268  			break
   269  		}
   270  	}
   271  }
   272  
   273  func (tv *ListView) SelectFirstActiveTunnel() {
   274  	tunnels := make([]manager.Tunnel, len(tv.model.tunnels))
   275  	copy(tunnels, tv.model.tunnels)
   276  	go func() {
   277  		for _, tunnel := range tunnels {
   278  			state, err := tunnel.State()
   279  			if err != nil {
   280  				continue
   281  			}
   282  			if state == manager.TunnelStarting || state == manager.TunnelStarted {
   283  				tv.Synchronize(func() {
   284  					tv.selectTunnel(tunnel.Name)
   285  				})
   286  				return
   287  			}
   288  		}
   289  	}()
   290  }