// Options configures a Filter built by New.
type Options struct {
	// AllowedIPs lists addresses or CIDR ranges that are allowed.
	AllowedIPs []string
	// BlockedIPs lists addresses or CIDR ranges that are blocked.
	BlockedIPs []string

	// BlockByDefault blocks addresses that match no rule. When false
	// (the default) unmatched addresses are allowed.
	BlockByDefault bool
}

// subnet is a single CIDR rule.
type subnet struct {
	// str is the canonical CIDR form of ipNet; it is the key used to
	// detect an already-registered range.
	str     string
	ipNet   *net.IPNet
	allowed bool
}

// Filter decides whether an IP address may pass, based on per-address
// rules, per-subnet rules and a default.
type Filter struct {
	opts Options
	// mut protects the fields below. It is a RWMutex because lookups take
	// only the read lock while rule changes take the write lock.
	mut            sync.RWMutex
	defaultAllowed bool
	ips            map[string]bool
	// codes is populated by New but not otherwise used in this file.
	codes   map[string]bool
	subnets []*subnet
}

// AllowIP marks the address or CIDR range ip as allowed. It reports
// whether ip could be parsed.
func (f *Filter) AllowIP(ip string) bool {
	return f.ToggleIP(ip, true)
}

// BlockIP marks the address or CIDR range ip as blocked. It reports
// whether ip could be parsed.
func (f *Filter) BlockIP(ip string) bool {
	return f.ToggleIP(ip, false)
}

// ToggleIP sets the rule for str, which may be a plain IP address or a
// CIDR range, to allowed. A CIDR whose mask covers every bit (for example
// /32 for IPv4) is stored as a single-address rule. It returns false, and
// changes nothing, when str is neither a valid CIDR nor a valid address.
func (f *Filter) ToggleIP(str string, allowed bool) bool {
	ip, ipNet, err := net.ParseCIDR(str)
	if err != nil {
		// Not CIDR notation: fall back to a plain address. ipNet is nil here.
		if ip = net.ParseIP(str); ip == nil {
			return false
		}
	}

	f.mut.Lock()
	defer f.mut.Unlock()

	if ipNet != nil {
		// ones == bits means no host bits remain, i.e. a single address.
		if ones, bits := ipNet.Mask.Size(); ones != bits {
			f.setSubnetLocked(ipNet, allowed)
			return true
		}
	}
	f.setIPLocked(ip, allowed)
	return true
}

// setIPLocked records a single-address rule. The caller must hold f.mut
// for writing.
func (f *Filter) setIPLocked(ip net.IP, allowed bool) {
	if f.ips == nil {
		// Keep a zero-value Filter usable instead of panicking on a
		// nil-map write.
		f.ips = make(map[string]bool)
	}
	f.ips[ip.String()] = allowed
}

// setSubnetLocked updates the rule for ipNet, adding it when it is not yet
// registered. The caller must hold f.mut for writing.
func (f *Filter) setSubnetLocked(ipNet *net.IPNet, allowed bool) {
	// Compare canonical forms so that different spellings of the same
	// network (e.g. "10.0.0.0/8" and "10.1.2.3/8") update one shared rule
	// rather than creating conflicting duplicates.
	key := ipNet.String()
	for _, sn := range f.subnets {
		if sn.str == key {
			sn.allowed = allowed
			return
		}
	}
	f.subnets = append(f.subnets, &subnet{
		str:     key,
		ipNet:   ipNet,
		allowed: allowed,
	})
}

// ToggleDefault sets whether addresses that match no rule are allowed.
func (f *Filter) ToggleDefault(allowed bool) {
	f.mut.Lock()
	f.defaultAllowed = allowed
	f.mut.Unlock()
}

// Allowed reports whether the address ipstr can pass through the filter.
// A string that does not parse as an IP address is never allowed.
func (f *Filter) Allowed(ipstr string) bool {
	return f.NetAllowed(net.ParseIP(ipstr))
}

// NetAllowed reports whether ip can pass through the filter.
//
// A nil ip is rejected. Otherwise a single-address rule for ip decides the
// result; failing that, ip is allowed if any subnet containing it is
// allowed and blocked if only blocked subnets contain it; failing that,
// the default applies.
func (f *Filter) NetAllowed(ip net.IP) bool {
	if ip == nil {
		return false
	}

	f.mut.RLock()
	defer f.mut.RUnlock()

	// Single-address rules take precedence over subnets.
	if allowed, ok := f.ips[ip.String()]; ok {
		return allowed
	}

	// An allowing subnet wins immediately; a blocking one is only
	// remembered in case a later subnet allows the address.
	blocked := false
	for _, sn := range f.subnets {
		if !sn.ipNet.Contains(ip) {
			continue
		}
		if sn.allowed {
			return true
		}
		blocked = true
	}
	if blocked {
		return false
	}

	return f.defaultAllowed
}

// Blocked reports whether the address ip can NOT pass through the filter.
func (f *Filter) Blocked(ip string) bool {
	return !f.Allowed(ip)
}

// NetBlocked reports whether ip can NOT pass through the filter.
func (f *Filter) NetBlocked(ip net.IP) bool {
	return !f.NetAllowed(ip)
}

// New constructs a Filter from opts. Blocked entries are applied before
// allowed ones, so an entry present in both lists ends up allowed.
// Entries that fail to parse are skipped.
func New(opts Options) *Filter {
	f := &Filter{
		opts:           opts,
		ips:            map[string]bool{},
		codes:          map[string]bool{},
		defaultAllowed: !opts.BlockByDefault,
	}
	for _, ip := range opts.BlockedIPs {
		f.BlockIP(ip)
	}
	for _, ip := range opts.AllowedIPs {
		f.AllowIP(ip)
	}
	return f
}