github.com/metacubex/sing-tun@v0.2.7-0.20240512075008-89e7c6208eec/internal/winfw/winfw.go (about)

     1  // Copyright (c) 2018 Samuel Melrose
     2  // SPDX-License-Identifier: MIT
     3  // https://github.com/iamacarpet/go-win64api/blob/ef6dbdd6db97301ae08a55eedea773476985a602/firewall.go
     4  
     5  //go:build windows
     6  
     7  package winfw
     8  
     9  import (
    10  	"fmt"
    11  	"runtime"
    12  
    13  	"github.com/go-ole/go-ole"
    14  	"github.com/go-ole/go-ole/oleutil"
    15  )
    16  
    17  // Firewall related API constants.
    18  const (
    19  	NET_FW_IP_PROTOCOL_TCP    = 6
    20  	NET_FW_IP_PROTOCOL_UDP    = 17
    21  	NET_FW_IP_PROTOCOL_ICMPv4 = 1
    22  	NET_FW_IP_PROTOCOL_ICMPv6 = 58
    23  	NET_FW_IP_PROTOCOL_ANY    = 256
    24  
    25  	NET_FW_RULE_DIR_IN  = 1
    26  	NET_FW_RULE_DIR_OUT = 2
    27  
    28  	NET_FW_ACTION_BLOCK = 0
    29  	NET_FW_ACTION_ALLOW = 1
    30  
    31  	// NET_FW_PROFILE2_CURRENT is not real API constant, just helper used in FW functions.
    32  	// It can mean one profile or multiple (even all) profiles. It depends on which profiles
    33  	// are currently in use. Every active interface can have it's own profile. F.e.: Public for Wifi,
    34  	// Domain for VPN, and Private for LAN. All at the same time.
    35  	NET_FW_PROFILE2_CURRENT = 0
    36  	NET_FW_PROFILE2_DOMAIN  = 1
    37  	NET_FW_PROFILE2_PRIVATE = 2
    38  	NET_FW_PROFILE2_PUBLIC  = 4
    39  	NET_FW_PROFILE2_ALL     = 2147483647
    40  )
    41  
    42  // Firewall Rule Groups
    43  // Use this magical strings instead of group names. It will work on all language Windows versions.
    44  // You can find more string locations here:
    45  // https://windows10dll.nirsoft.net/firewallapi_dll.html
    46  const (
    47  	NET_FW_FILE_AND_PRINTER_SHARING = "@FirewallAPI.dll,-28502"
    48  	NET_FW_REMOTE_DESKTOP           = "@FirewallAPI.dll,-28752"
    49  )
    50  
    51  // FWRule represents Firewall Rule.
    52  type FWRule struct {
    53  	Name, Description, ApplicationName, ServiceName string
    54  	LocalPorts, RemotePorts                         string
    55  	// LocalAddresses, RemoteAddresses are always returned with netmask, f.e.:
    56  	//   `10.10.1.1/255.255.255.0`
    57  	LocalAddresses, RemoteAddresses string
    58  	// ICMPTypesAndCodes is string. You can find define multiple codes separated by ":" (colon).
    59  	// Types are listed here:
    60  	// https://www.iana.org/assignments/icmp-parameters/icmp-parameters.xhtml
    61  	// So to allow ping set it to:
    62  	//   "0"
    63  	ICMPTypesAndCodes string
    64  	Grouping          string
    65  	// InterfaceTypes can be:
    66  	//   "LAN", "Wireless", "RemoteAccess", "All"
    67  	// You can add multiple deviding with comma:
    68  	//   "LAN, Wireless"
    69  	InterfaceTypes                        string
    70  	Protocol, Direction, Action, Profiles int32
    71  	Enabled, EdgeTraversal                bool
    72  }
    73  
    74  // FirewallRuleAddAdvanced allows to modify almost all available FW Rule parameters.
    75  // You probably do not want to use this, as function allows to create any rule, even opening all ports
    76  // in given profile. So use with caution.
    77  func FirewallRuleAddAdvanced(rule FWRule) (bool, error) {
    78  	return firewallRuleAdd(rule.Name, rule.Description, rule.Grouping, rule.ApplicationName, rule.ServiceName,
    79  		rule.LocalPorts, rule.RemotePorts, rule.LocalAddresses, rule.RemoteAddresses, rule.ICMPTypesAndCodes,
    80  		rule.Protocol, rule.Direction, rule.Action, rule.Profiles, rule.Enabled, rule.EdgeTraversal)
    81  }
    82  
    83  // firewallRuleAdd is universal function to add all kinds of rules.
    84  func firewallRuleAdd(name, description, group, appPath, serviceName, ports, remotePorts, localAddresses, remoteAddresses, icmpTypes string, protocol, direction, action, profile int32, enabled, edgeTraversal bool) (bool, error) {
    85  	if name == "" {
    86  		return false, fmt.Errorf("empty FW Rule name, name is mandatory")
    87  	}
    88  
    89  	runtime.LockOSThread()
    90  	defer runtime.UnlockOSThread()
    91  
    92  	u, fwPolicy, err := firewallAPIInit()
    93  	if err != nil {
    94  		return false, err
    95  	}
    96  	defer firewallAPIRelease(u, fwPolicy)
    97  
    98  	if profile == NET_FW_PROFILE2_CURRENT {
    99  		currentProfiles, err := oleutil.GetProperty(fwPolicy, "CurrentProfileTypes")
   100  		if err != nil {
   101  			return false, fmt.Errorf("Failed to get CurrentProfiles: %s", err)
   102  		}
   103  		profile = currentProfiles.Value().(int32)
   104  	}
   105  	unknownRules, err := oleutil.GetProperty(fwPolicy, "Rules")
   106  	if err != nil {
   107  		return false, fmt.Errorf("Failed to get Rules: %s", err)
   108  	}
   109  	rules := unknownRules.ToIDispatch()
   110  
   111  	if ok, err := FirewallRuleExistsByName(rules, name); err != nil {
   112  		return false, fmt.Errorf("Error while checking rules for duplicate: %s", err)
   113  	} else if ok {
   114  		return false, nil
   115  	}
   116  
   117  	unknown2, err := oleutil.CreateObject("HNetCfg.FWRule")
   118  	if err != nil {
   119  		return false, fmt.Errorf("Error creating Rule object: %s", err)
   120  	}
   121  	defer unknown2.Release()
   122  
   123  	fwRule, err := unknown2.QueryInterface(ole.IID_IDispatch)
   124  	if err != nil {
   125  		return false, fmt.Errorf("Error creating Rule object (2): %s", err)
   126  	}
   127  	defer fwRule.Release()
   128  
   129  	if _, err := oleutil.PutProperty(fwRule, "Name", name); err != nil {
   130  		return false, fmt.Errorf("Error setting property (Name) of Rule: %s", err)
   131  	}
   132  	if _, err := oleutil.PutProperty(fwRule, "Description", description); err != nil {
   133  		return false, fmt.Errorf("Error setting property (Description) of Rule: %s", err)
   134  	}
   135  	if appPath != "" {
   136  		if _, err := oleutil.PutProperty(fwRule, "Applicationname", appPath); err != nil {
   137  			return false, fmt.Errorf("Error setting property (Applicationname) of Rule: %s", err)
   138  		}
   139  	}
   140  	if serviceName != "" {
   141  		if _, err := oleutil.PutProperty(fwRule, "ServiceName", serviceName); err != nil {
   142  			return false, fmt.Errorf("Error setting property (ServiceName) of Rule: %s", err)
   143  		}
   144  	}
   145  	if protocol != 0 {
   146  		if _, err := oleutil.PutProperty(fwRule, "Protocol", protocol); err != nil {
   147  			return false, fmt.Errorf("Error setting property (Protocol) of Rule: %s", err)
   148  		}
   149  	}
   150  	if icmpTypes != "" {
   151  		if _, err := oleutil.PutProperty(fwRule, "IcmpTypesAndCodes", icmpTypes); err != nil {
   152  			return false, fmt.Errorf("Error setting property (IcmpTypesAndCodes) of Rule: %s", err)
   153  		}
   154  	}
   155  	if ports != "" {
   156  		if _, err := oleutil.PutProperty(fwRule, "LocalPorts", ports); err != nil {
   157  			return false, fmt.Errorf("Error setting property (LocalPorts) of Rule: %s", err)
   158  		}
   159  	}
   160  	if remotePorts != "" {
   161  		if _, err := oleutil.PutProperty(fwRule, "RemotePorts", remotePorts); err != nil {
   162  			return false, fmt.Errorf("Error setting property (RemotePorts) of Rule: %s", err)
   163  		}
   164  	}
   165  	if localAddresses != "" {
   166  		if _, err := oleutil.PutProperty(fwRule, "LocalAddresses", localAddresses); err != nil {
   167  			return false, fmt.Errorf("Error setting property (LocalAddresses) of Rule: %s", err)
   168  		}
   169  	}
   170  	if remoteAddresses != "" {
   171  		if _, err := oleutil.PutProperty(fwRule, "RemoteAddresses", remoteAddresses); err != nil {
   172  			return false, fmt.Errorf("Error setting property (RemoteAddresses) of Rule: %s", err)
   173  		}
   174  	}
   175  	if direction != 0 {
   176  		if _, err := oleutil.PutProperty(fwRule, "Direction", direction); err != nil {
   177  			return false, fmt.Errorf("Error setting property (Direction) of Rule: %s", err)
   178  		}
   179  	}
   180  	if _, err := oleutil.PutProperty(fwRule, "Enabled", enabled); err != nil {
   181  		return false, fmt.Errorf("Error setting property (Enabled) of Rule: %s", err)
   182  	}
   183  	if _, err := oleutil.PutProperty(fwRule, "Grouping", group); err != nil {
   184  		return false, fmt.Errorf("Error setting property (Grouping) of Rule: %s", err)
   185  	}
   186  	if _, err := oleutil.PutProperty(fwRule, "Profiles", profile); err != nil {
   187  		return false, fmt.Errorf("Error setting property (Profiles) of Rule: %s", err)
   188  	}
   189  	if _, err := oleutil.PutProperty(fwRule, "Action", action); err != nil {
   190  		return false, fmt.Errorf("Error setting property (Action) of Rule: %s", err)
   191  	}
   192  	if edgeTraversal {
   193  		if _, err := oleutil.PutProperty(fwRule, "EdgeTraversal", edgeTraversal); err != nil {
   194  			return false, fmt.Errorf("Error setting property (EdgeTraversal) of Rule: %s", err)
   195  		}
   196  	}
   197  
   198  	if _, err := oleutil.CallMethod(rules, "Add", fwRule); err != nil {
   199  		return false, fmt.Errorf("Error adding Rule: %s", err)
   200  	}
   201  
   202  	return true, nil
   203  }
   204  
   205  func FirewallRuleExistsByName(rules *ole.IDispatch, name string) (bool, error) {
   206  	enumProperty, err := rules.GetProperty("_NewEnum")
   207  	if err != nil {
   208  		return false, fmt.Errorf("Failed to get enumeration property on Rules: %s", err)
   209  	}
   210  	defer enumProperty.Clear()
   211  
   212  	enum, err := enumProperty.ToIUnknown().IEnumVARIANT(ole.IID_IEnumVariant)
   213  	if err != nil {
   214  		return false, fmt.Errorf("Failed to cast enum to correct type: %s", err)
   215  	}
   216  	if enum == nil {
   217  		return false, fmt.Errorf("can't get IEnumVARIANT, enum is nil")
   218  	}
   219  
   220  	for itemRaw, length, err := enum.Next(1); length > 0; itemRaw, length, err = enum.Next(1) {
   221  		if err != nil {
   222  			return false, fmt.Errorf("Failed to seek next Rule item: %s", err)
   223  		}
   224  
   225  		t, err := func() (bool, error) {
   226  			item := itemRaw.ToIDispatch()
   227  			defer item.Release()
   228  
   229  			if item, err := oleutil.GetProperty(item, "Name"); err != nil {
   230  				return false, fmt.Errorf("Failed to get Property (Name) of Rule")
   231  			} else if item.ToString() == name {
   232  				return true, nil
   233  			}
   234  
   235  			return false, nil
   236  		}()
   237  
   238  		if err != nil {
   239  			return false, err
   240  		} else if t {
   241  			return true, nil
   242  		}
   243  	}
   244  
   245  	return false, nil
   246  }
   247  
   248  // firewallAPIInit initialize common fw api.
   249  // then:
   250  // dispatch firewallAPIRelease(u, fwp)
   251  func firewallAPIInit() (*ole.IUnknown, *ole.IDispatch, error) {
   252  	err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED)
   253  	if err != nil {
   254  		return nil, nil, fmt.Errorf("Failed to initialize COM: %s", err)
   255  	}
   256  
   257  	unknown, err := oleutil.CreateObject("HNetCfg.FwPolicy2")
   258  	if err != nil {
   259  		return nil, nil, fmt.Errorf("Failed to create FwPolicy Object: %s", err)
   260  	}
   261  
   262  	fwPolicy, err := unknown.QueryInterface(ole.IID_IDispatch)
   263  	if err != nil {
   264  		unknown.Release()
   265  		return nil, nil, fmt.Errorf("Failed to create FwPolicy Object (2): %s", err)
   266  	}
   267  
   268  	return unknown, fwPolicy, nil
   269  }
   270  
   271  // firewallAPIRelease cleans memory.
   272  func firewallAPIRelease(u *ole.IUnknown, fwp *ole.IDispatch) {
   273  	fwp.Release()
   274  	u.Release()
   275  	ole.CoUninitialize()
   276  }