golang.zx2c4.com/wireguard/windows@v0.5.4-0.20230123132234-dcc0eb72a04b/tunnel/mtumonitor.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2019-2022 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package tunnel
     7  
import (
	"sync"

	"golang.org/x/sys/windows"
	"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
)
    12  
    13  func findDefaultLUID(family winipcfg.AddressFamily, ourLUID winipcfg.LUID, lastLUID *winipcfg.LUID, lastIndex *uint32) error {
    14  	r, err := winipcfg.GetIPForwardTable2(family)
    15  	if err != nil {
    16  		return err
    17  	}
    18  	lowestMetric := ^uint32(0)
    19  	index := uint32(0)
    20  	luid := winipcfg.LUID(0)
    21  	for i := range r {
    22  		if r[i].DestinationPrefix.PrefixLength != 0 || r[i].InterfaceLUID == ourLUID {
    23  			continue
    24  		}
    25  		ifrow, err := r[i].InterfaceLUID.Interface()
    26  		if err != nil || ifrow.OperStatus != winipcfg.IfOperStatusUp {
    27  			continue
    28  		}
    29  
    30  		iface, err := r[i].InterfaceLUID.IPInterface(family)
    31  		if err != nil {
    32  			continue
    33  		}
    34  
    35  		if r[i].Metric+iface.Metric < lowestMetric {
    36  			lowestMetric = r[i].Metric + iface.Metric
    37  			index = r[i].InterfaceIndex
    38  			luid = r[i].InterfaceLUID
    39  		}
    40  	}
    41  	if luid == *lastLUID && index == *lastIndex {
    42  		return nil
    43  	}
    44  	*lastLUID = luid
    45  	*lastIndex = index
    46  	return nil
    47  }
    48  
    49  func monitorMTU(family winipcfg.AddressFamily, ourLUID winipcfg.LUID) ([]winipcfg.ChangeCallback, error) {
    50  	var minMTU uint32
    51  	if family == windows.AF_INET {
    52  		minMTU = 576
    53  	} else if family == windows.AF_INET6 {
    54  		minMTU = 1280
    55  	}
    56  	lastLUID := winipcfg.LUID(0)
    57  	lastIndex := ^uint32(0)
    58  	lastMTU := uint32(0)
    59  	doIt := func() error {
    60  		err := findDefaultLUID(family, ourLUID, &lastLUID, &lastIndex)
    61  		if err != nil {
    62  			return err
    63  		}
    64  		mtu := uint32(0)
    65  		if lastLUID != 0 {
    66  			iface, err := lastLUID.Interface()
    67  			if err != nil {
    68  				return err
    69  			}
    70  			if iface.MTU > 0 {
    71  				mtu = iface.MTU
    72  			}
    73  		}
    74  		if mtu > 0 && lastMTU != mtu {
    75  			iface, err := ourLUID.IPInterface(family)
    76  			if err != nil {
    77  				return err
    78  			}
    79  			iface.NLMTU = mtu - 80
    80  			if iface.NLMTU < minMTU {
    81  				iface.NLMTU = minMTU
    82  			}
    83  			err = iface.Set()
    84  			if err != nil {
    85  				return err
    86  			}
    87  			lastMTU = mtu
    88  		}
    89  		return nil
    90  	}
    91  	err := doIt()
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  	cbr, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.MibIPforwardRow2) {
    96  		if route != nil && route.DestinationPrefix.PrefixLength == 0 {
    97  			doIt()
    98  		}
    99  	})
   100  	if err != nil {
   101  		return nil, err
   102  	}
   103  	cbi, err := winipcfg.RegisterInterfaceChangeCallback(func(notificationType winipcfg.MibNotificationType, iface *winipcfg.MibIPInterfaceRow) {
   104  		if notificationType == winipcfg.MibParameterNotification {
   105  			doIt()
   106  		}
   107  	})
   108  	if err != nil {
   109  		cbr.Unregister()
   110  		return nil, err
   111  	}
   112  	return []winipcfg.ChangeCallback{cbr, cbi}, nil
   113  }