github.com/MerlinKodo/sing-tun@v0.1.15/internal/winfw/winfw.go (about)

     1  // Copyright (c) 2018 Samuel Melrose
     2  // SPDX-License-Identifier: MIT
     3  // https://github.com/iamacarpet/go-win64api/blob/ef6dbdd6db97301ae08a55eedea773476985a602/firewall.go
     4  
     5  //go:build windows
     6  
     7  package winfw
     8  
     9  import (
    10  	"fmt"
    11  	"runtime"
    12  
    13  	"github.com/go-ole/go-ole"
    14  	"github.com/go-ole/go-ole/oleutil"
    15  	"github.com/scjalliance/comshim"
    16  )
    17  
    18  // Firewall related API constants.
    19  const (
    20  	NET_FW_IP_PROTOCOL_TCP    = 6
    21  	NET_FW_IP_PROTOCOL_UDP    = 17
    22  	NET_FW_IP_PROTOCOL_ICMPv4 = 1
    23  	NET_FW_IP_PROTOCOL_ICMPv6 = 58
    24  	NET_FW_IP_PROTOCOL_ANY    = 256
    25  
    26  	NET_FW_RULE_DIR_IN  = 1
    27  	NET_FW_RULE_DIR_OUT = 2
    28  
    29  	NET_FW_ACTION_BLOCK = 0
    30  	NET_FW_ACTION_ALLOW = 1
    31  
    32  	// NET_FW_PROFILE2_CURRENT is not real API constant, just helper used in FW functions.
    33  	// It can mean one profile or multiple (even all) profiles. It depends on which profiles
    34  	// are currently in use. Every active interface can have it's own profile. F.e.: Public for Wifi,
    35  	// Domain for VPN, and Private for LAN. All at the same time.
    36  	NET_FW_PROFILE2_CURRENT = 0
    37  	NET_FW_PROFILE2_DOMAIN  = 1
    38  	NET_FW_PROFILE2_PRIVATE = 2
    39  	NET_FW_PROFILE2_PUBLIC  = 4
    40  	NET_FW_PROFILE2_ALL     = 2147483647
    41  )
    42  
    43  // Firewall Rule Groups
    44  // Use this magical strings instead of group names. It will work on all language Windows versions.
    45  // You can find more string locations here:
    46  // https://windows10dll.nirsoft.net/firewallapi_dll.html
    47  const (
    48  	NET_FW_FILE_AND_PRINTER_SHARING = "@FirewallAPI.dll,-28502"
    49  	NET_FW_REMOTE_DESKTOP           = "@FirewallAPI.dll,-28752"
    50  )
    51  
    52  // FWRule represents Firewall Rule.
    53  type FWRule struct {
    54  	Name, Description, ApplicationName, ServiceName string
    55  	LocalPorts, RemotePorts                         string
    56  	// LocalAddresses, RemoteAddresses are always returned with netmask, f.e.:
    57  	//   `10.10.1.1/255.255.255.0`
    58  	LocalAddresses, RemoteAddresses string
    59  	// ICMPTypesAndCodes is string. You can find define multiple codes separated by ":" (colon).
    60  	// Types are listed here:
    61  	// https://www.iana.org/assignments/icmp-parameters/icmp-parameters.xhtml
    62  	// So to allow ping set it to:
    63  	//   "0"
    64  	ICMPTypesAndCodes string
    65  	Grouping          string
    66  	// InterfaceTypes can be:
    67  	//   "LAN", "Wireless", "RemoteAccess", "All"
    68  	// You can add multiple deviding with comma:
    69  	//   "LAN, Wireless"
    70  	InterfaceTypes                        string
    71  	Protocol, Direction, Action, Profiles int32
    72  	Enabled, EdgeTraversal                bool
    73  }
    74  
    75  // FirewallRuleAddAdvanced allows to modify almost all available FW Rule parameters.
    76  // You probably do not want to use this, as function allows to create any rule, even opening all ports
    77  // in given profile. So use with caution.
    78  func FirewallRuleAddAdvanced(rule FWRule) (bool, error) {
    79  	return firewallRuleAdd(rule.Name, rule.Description, rule.Grouping, rule.ApplicationName, rule.ServiceName,
    80  		rule.LocalPorts, rule.RemotePorts, rule.LocalAddresses, rule.RemoteAddresses, rule.ICMPTypesAndCodes,
    81  		rule.Protocol, rule.Direction, rule.Action, rule.Profiles, rule.Enabled, rule.EdgeTraversal)
    82  }
    83  
    84  // firewallRuleAdd is universal function to add all kinds of rules.
    85  func firewallRuleAdd(name, description, group, appPath, serviceName, ports, remotePorts, localAddresses, remoteAddresses, icmpTypes string, protocol, direction, action, profile int32, enabled, edgeTraversal bool) (bool, error) {
    86  	if name == "" {
    87  		return false, fmt.Errorf("empty FW Rule name, name is mandatory")
    88  	}
    89  
    90  	runtime.LockOSThread()
    91  	defer runtime.UnlockOSThread()
    92  
    93  	u, fwPolicy, err := firewallAPIInit()
    94  	if err != nil {
    95  		return false, err
    96  	}
    97  	defer firewallAPIRelease(u, fwPolicy)
    98  
    99  	if profile == NET_FW_PROFILE2_CURRENT {
   100  		currentProfiles, err := oleutil.GetProperty(fwPolicy, "CurrentProfileTypes")
   101  		if err != nil {
   102  			return false, fmt.Errorf("Failed to get CurrentProfiles: %s", err)
   103  		}
   104  		profile = currentProfiles.Value().(int32)
   105  	}
   106  	unknownRules, err := oleutil.GetProperty(fwPolicy, "Rules")
   107  	if err != nil {
   108  		return false, fmt.Errorf("Failed to get Rules: %s", err)
   109  	}
   110  	rules := unknownRules.ToIDispatch()
   111  
   112  	if ok, err := FirewallRuleExistsByName(rules, name); err != nil {
   113  		return false, fmt.Errorf("Error while checking rules for duplicate: %s", err)
   114  	} else if ok {
   115  		return false, nil
   116  	}
   117  
   118  	unknown2, err := oleutil.CreateObject("HNetCfg.FWRule")
   119  	if err != nil {
   120  		return false, fmt.Errorf("Error creating Rule object: %s", err)
   121  	}
   122  	defer unknown2.Release()
   123  
   124  	fwRule, err := unknown2.QueryInterface(ole.IID_IDispatch)
   125  	if err != nil {
   126  		return false, fmt.Errorf("Error creating Rule object (2): %s", err)
   127  	}
   128  	defer fwRule.Release()
   129  
   130  	if _, err := oleutil.PutProperty(fwRule, "Name", name); err != nil {
   131  		return false, fmt.Errorf("Error setting property (Name) of Rule: %s", err)
   132  	}
   133  	if _, err := oleutil.PutProperty(fwRule, "Description", description); err != nil {
   134  		return false, fmt.Errorf("Error setting property (Description) of Rule: %s", err)
   135  	}
   136  	if appPath != "" {
   137  		if _, err := oleutil.PutProperty(fwRule, "Applicationname", appPath); err != nil {
   138  			return false, fmt.Errorf("Error setting property (Applicationname) of Rule: %s", err)
   139  		}
   140  	}
   141  	if serviceName != "" {
   142  		if _, err := oleutil.PutProperty(fwRule, "ServiceName", serviceName); err != nil {
   143  			return false, fmt.Errorf("Error setting property (ServiceName) of Rule: %s", err)
   144  		}
   145  	}
   146  	if protocol != 0 {
   147  		if _, err := oleutil.PutProperty(fwRule, "Protocol", protocol); err != nil {
   148  			return false, fmt.Errorf("Error setting property (Protocol) of Rule: %s", err)
   149  		}
   150  	}
   151  	if icmpTypes != "" {
   152  		if _, err := oleutil.PutProperty(fwRule, "IcmpTypesAndCodes", icmpTypes); err != nil {
   153  			return false, fmt.Errorf("Error setting property (IcmpTypesAndCodes) of Rule: %s", err)
   154  		}
   155  	}
   156  	if ports != "" {
   157  		if _, err := oleutil.PutProperty(fwRule, "LocalPorts", ports); err != nil {
   158  			return false, fmt.Errorf("Error setting property (LocalPorts) of Rule: %s", err)
   159  		}
   160  	}
   161  	if remotePorts != "" {
   162  		if _, err := oleutil.PutProperty(fwRule, "RemotePorts", remotePorts); err != nil {
   163  			return false, fmt.Errorf("Error setting property (RemotePorts) of Rule: %s", err)
   164  		}
   165  	}
   166  	if localAddresses != "" {
   167  		if _, err := oleutil.PutProperty(fwRule, "LocalAddresses", localAddresses); err != nil {
   168  			return false, fmt.Errorf("Error setting property (LocalAddresses) of Rule: %s", err)
   169  		}
   170  	}
   171  	if remoteAddresses != "" {
   172  		if _, err := oleutil.PutProperty(fwRule, "RemoteAddresses", remoteAddresses); err != nil {
   173  			return false, fmt.Errorf("Error setting property (RemoteAddresses) of Rule: %s", err)
   174  		}
   175  	}
   176  	if direction != 0 {
   177  		if _, err := oleutil.PutProperty(fwRule, "Direction", direction); err != nil {
   178  			return false, fmt.Errorf("Error setting property (Direction) of Rule: %s", err)
   179  		}
   180  	}
   181  	if _, err := oleutil.PutProperty(fwRule, "Enabled", enabled); err != nil {
   182  		return false, fmt.Errorf("Error setting property (Enabled) of Rule: %s", err)
   183  	}
   184  	if _, err := oleutil.PutProperty(fwRule, "Grouping", group); err != nil {
   185  		return false, fmt.Errorf("Error setting property (Grouping) of Rule: %s", err)
   186  	}
   187  	if _, err := oleutil.PutProperty(fwRule, "Profiles", profile); err != nil {
   188  		return false, fmt.Errorf("Error setting property (Profiles) of Rule: %s", err)
   189  	}
   190  	if _, err := oleutil.PutProperty(fwRule, "Action", action); err != nil {
   191  		return false, fmt.Errorf("Error setting property (Action) of Rule: %s", err)
   192  	}
   193  	if edgeTraversal {
   194  		if _, err := oleutil.PutProperty(fwRule, "EdgeTraversal", edgeTraversal); err != nil {
   195  			return false, fmt.Errorf("Error setting property (EdgeTraversal) of Rule: %s", err)
   196  		}
   197  	}
   198  
   199  	if _, err := oleutil.CallMethod(rules, "Add", fwRule); err != nil {
   200  		return false, fmt.Errorf("Error adding Rule: %s", err)
   201  	}
   202  
   203  	return true, nil
   204  }
   205  
   206  func FirewallRuleExistsByName(rules *ole.IDispatch, name string) (bool, error) {
   207  	enumProperty, err := rules.GetProperty("_NewEnum")
   208  	if err != nil {
   209  		return false, fmt.Errorf("Failed to get enumeration property on Rules: %s", err)
   210  	}
   211  	defer enumProperty.Clear()
   212  
   213  	enum, err := enumProperty.ToIUnknown().IEnumVARIANT(ole.IID_IEnumVariant)
   214  	if err != nil {
   215  		return false, fmt.Errorf("Failed to cast enum to correct type: %s", err)
   216  	}
   217  	if enum == nil {
   218  		return false, fmt.Errorf("can't get IEnumVARIANT, enum is nil")
   219  	}
   220  
   221  	for itemRaw, length, err := enum.Next(1); length > 0; itemRaw, length, err = enum.Next(1) {
   222  		if err != nil {
   223  			return false, fmt.Errorf("Failed to seek next Rule item: %s", err)
   224  		}
   225  
   226  		t, err := func() (bool, error) {
   227  			item := itemRaw.ToIDispatch()
   228  			defer item.Release()
   229  
   230  			if item, err := oleutil.GetProperty(item, "Name"); err != nil {
   231  				return false, fmt.Errorf("Failed to get Property (Name) of Rule")
   232  			} else if item.ToString() == name {
   233  				return true, nil
   234  			}
   235  
   236  			return false, nil
   237  		}()
   238  
   239  		if err != nil {
   240  			return false, err
   241  		} else if t {
   242  			return true, nil
   243  		}
   244  	}
   245  
   246  	return false, nil
   247  }
   248  
   249  // firewallAPIInit initialize common fw api.
   250  // then:
   251  // dispatch firewallAPIRelease(u, fwp)
   252  func firewallAPIInit() (*ole.IUnknown, *ole.IDispatch, error) {
   253  	comshim.Add(1)
   254  
   255  	unknown, err := oleutil.CreateObject("HNetCfg.FwPolicy2")
   256  	if err != nil {
   257  		return nil, nil, fmt.Errorf("Failed to create FwPolicy Object: %s", err)
   258  	}
   259  
   260  	fwPolicy, err := unknown.QueryInterface(ole.IID_IDispatch)
   261  	if err != nil {
   262  		unknown.Release()
   263  		return nil, nil, fmt.Errorf("Failed to create FwPolicy Object (2): %s", err)
   264  	}
   265  
   266  	return unknown, fwPolicy, nil
   267  }
   268  
   269  // firewallAPIRelease cleans memory.
   270  func firewallAPIRelease(u *ole.IUnknown, fwp *ole.IDispatch) {
   271  	fwp.Release()
   272  	u.Release()
   273  	comshim.Done()
   274  }