github.com/songzhibin97/gkit@v1.2.13/net/ip/filter.go (about)

     1  package ip
     2  
     3  import (
     4  	"net"
     5  	"sync"
     6  )
     7  
     8  type Options struct {
     9  	//AllowedIPs allowed IPs
    10  	AllowedIPs []string
    11  	//BlockedIPs blocked IPs
    12  	BlockedIPs []string
    13  
    14  	//block by default (defaults to allow)
    15  	BlockByDefault bool
    16  }
    17  
    18  type subnet struct {
    19  	str     string
    20  	ipNet   *net.IPNet
    21  	allowed bool
    22  }
    23  
    24  type Filter struct {
    25  	opts Options
    26  	//mut protects the below
    27  	//rw since writes are rare
    28  	mut            sync.RWMutex
    29  	defaultAllowed bool
    30  	ips            map[string]bool
    31  	codes          map[string]bool
    32  	subnets        []*subnet
    33  }
    34  
    35  func (f *Filter) AllowIP(ip string) bool {
    36  	return f.ToggleIP(ip, true)
    37  }
    38  
    39  func (f *Filter) BlockIP(ip string) bool {
    40  	return f.ToggleIP(ip, false)
    41  }
    42  
    43  func (f *Filter) ToggleIP(str string, allowed bool) bool {
    44  	//check if has subnet
    45  	if ip, net, err := net.ParseCIDR(str); err == nil {
    46  		// containing only one ip? (no bits masked)
    47  		if n, total := net.Mask.Size(); n == total {
    48  			f.mut.Lock()
    49  			f.ips[ip.String()] = allowed
    50  			f.mut.Unlock()
    51  			return true
    52  		}
    53  		//check for existing
    54  		f.mut.Lock()
    55  		found := false
    56  		for _, subnet := range f.subnets {
    57  			if subnet.str == str {
    58  				found = true
    59  				subnet.allowed = allowed
    60  				break
    61  			}
    62  		}
    63  		if !found {
    64  			f.subnets = append(f.subnets, &subnet{
    65  				str:     str,
    66  				ipNet:   net,
    67  				allowed: allowed,
    68  			})
    69  		}
    70  		f.mut.Unlock()
    71  		return true
    72  	}
    73  	//check if plain ip (/32)
    74  	if ip := net.ParseIP(str); ip != nil {
    75  		f.mut.Lock()
    76  		f.ips[ip.String()] = allowed
    77  		f.mut.Unlock()
    78  		return true
    79  	}
    80  	return false
    81  }
    82  
    83  // ToggleDefault alters the default setting
    84  func (f *Filter) ToggleDefault(allowed bool) {
    85  	f.mut.Lock()
    86  	f.defaultAllowed = allowed
    87  	f.mut.Unlock()
    88  }
    89  
    90  // Allowed returns if a given IP can pass through the filter
    91  func (f *Filter) Allowed(ipstr string) bool {
    92  	return f.NetAllowed(net.ParseIP(ipstr))
    93  }
    94  
    95  // NetAllowed returns if a given net.IP can pass through the filter
    96  func (f *Filter) NetAllowed(ip net.IP) bool {
    97  	//invalid ip
    98  	if ip == nil {
    99  		return false
   100  	}
   101  	//read lock entire function
   102  	//except for db access
   103  	f.mut.RLock()
   104  	defer f.mut.RUnlock()
   105  	//check single ips
   106  	allowed, ok := f.ips[ip.String()]
   107  	if ok {
   108  		return allowed
   109  	}
   110  	//scan subnets for any allow/block
   111  	blocked := false
   112  	for _, subnet := range f.subnets {
   113  		if subnet.ipNet.Contains(ip) {
   114  			if subnet.allowed {
   115  				return true
   116  			}
   117  			blocked = true
   118  		}
   119  	}
   120  	if blocked {
   121  		return false
   122  	}
   123  
   124  	return f.defaultAllowed
   125  }
   126  
   127  // Blocked returns if a given IP can NOT pass through the filter
   128  func (f *Filter) Blocked(ip string) bool {
   129  	return !f.Allowed(ip)
   130  }
   131  
   132  // NetBlocked returns if a given net.IP can NOT pass through the filter
   133  func (f *Filter) NetBlocked(ip net.IP) bool {
   134  	return !f.NetAllowed(ip)
   135  }
   136  
   137  // New constructs IPFilter instance without downloading DB.
   138  func New(opts Options) *Filter {
   139  	f := &Filter{
   140  		opts:           opts,
   141  		ips:            map[string]bool{},
   142  		codes:          map[string]bool{},
   143  		defaultAllowed: !opts.BlockByDefault,
   144  	}
   145  	for _, ip := range opts.BlockedIPs {
   146  		f.BlockIP(ip)
   147  	}
   148  	for _, ip := range opts.AllowedIPs {
   149  		f.AllowIP(ip)
   150  	}
   151  	return f
   152  }