github.com/ipfans/trojan-go@v0.11.0/common/geodata/cache.go (about)

     1  package geodata
     2  
     3  import (
     4  	"io/ioutil"
     5  	"strings"
     6  
     7  	v2router "github.com/v2fly/v2ray-core/v4/app/router"
     8  	"google.golang.org/protobuf/proto"
     9  
    10  	"github.com/ipfans/trojan-go/common"
    11  	"github.com/ipfans/trojan-go/log"
    12  )
    13  
    14  type geoipCache map[string]*v2router.GeoIP
    15  
    16  func (g geoipCache) Has(key string) bool {
    17  	return !(g.Get(key) == nil)
    18  }
    19  
    20  func (g geoipCache) Get(key string) *v2router.GeoIP {
    21  	if g == nil {
    22  		return nil
    23  	}
    24  	return g[key]
    25  }
    26  
    27  func (g geoipCache) Set(key string, value *v2router.GeoIP) {
    28  	if g == nil {
    29  		g = make(map[string]*v2router.GeoIP)
    30  	}
    31  	g[key] = value
    32  }
    33  
    34  func (g geoipCache) Unmarshal(filename, code string) (*v2router.GeoIP, error) {
    35  	asset := common.GetAssetLocation(filename)
    36  	idx := strings.ToLower(asset + ":" + code)
    37  	if g.Has(idx) {
    38  		log.Debugf("geoip cache HIT: %s -> %s", code, idx)
    39  		return g.Get(idx), nil
    40  	}
    41  
    42  	geoipBytes, err := Decode(asset, code)
    43  	switch err {
    44  	case nil:
    45  		var geoip v2router.GeoIP
    46  		if err := proto.Unmarshal(geoipBytes, &geoip); err != nil {
    47  			return nil, err
    48  		}
    49  		g.Set(idx, &geoip)
    50  		return &geoip, nil
    51  
    52  	case ErrCodeNotFound:
    53  		return nil, common.NewError("country code " + code + " not found in " + filename)
    54  
    55  	case ErrFailedToReadBytes, ErrFailedToReadExpectedLenBytes,
    56  		ErrInvalidGeodataFile, ErrInvalidGeodataVarintLength:
    57  		log.Warnf("failed to decode geoip file: %s, fallback to the original ReadFile method", filename)
    58  		geoipBytes, err = ioutil.ReadFile(asset)
    59  		if err != nil {
    60  			return nil, err
    61  		}
    62  		var geoipList v2router.GeoIPList
    63  		if err := proto.Unmarshal(geoipBytes, &geoipList); err != nil {
    64  			return nil, err
    65  		}
    66  		for _, geoip := range geoipList.GetEntry() {
    67  			if strings.EqualFold(code, geoip.GetCountryCode()) {
    68  				g.Set(idx, geoip)
    69  				return geoip, nil
    70  			}
    71  		}
    72  
    73  	default:
    74  		return nil, err
    75  	}
    76  
    77  	return nil, common.NewError("country code " + code + " not found in " + filename)
    78  }
    79  
    80  type geositeCache map[string]*v2router.GeoSite
    81  
    82  func (g geositeCache) Has(key string) bool {
    83  	return !(g.Get(key) == nil)
    84  }
    85  
    86  func (g geositeCache) Get(key string) *v2router.GeoSite {
    87  	if g == nil {
    88  		return nil
    89  	}
    90  	return g[key]
    91  }
    92  
    93  func (g geositeCache) Set(key string, value *v2router.GeoSite) {
    94  	if g == nil {
    95  		g = make(map[string]*v2router.GeoSite)
    96  	}
    97  	g[key] = value
    98  }
    99  
   100  func (g geositeCache) Unmarshal(filename, code string) (*v2router.GeoSite, error) {
   101  	asset := common.GetAssetLocation(filename)
   102  	idx := strings.ToLower(asset + ":" + code)
   103  	if g.Has(idx) {
   104  		log.Debugf("geosite cache HIT: %s -> %s", code, idx)
   105  		return g.Get(idx), nil
   106  	}
   107  
   108  	geositeBytes, err := Decode(asset, code)
   109  	switch err {
   110  	case nil:
   111  		var geosite v2router.GeoSite
   112  		if err := proto.Unmarshal(geositeBytes, &geosite); err != nil {
   113  			return nil, err
   114  		}
   115  		g.Set(idx, &geosite)
   116  		return &geosite, nil
   117  
   118  	case ErrCodeNotFound:
   119  		return nil, common.NewError("list " + code + " not found in " + filename)
   120  
   121  	case ErrFailedToReadBytes, ErrFailedToReadExpectedLenBytes,
   122  		ErrInvalidGeodataFile, ErrInvalidGeodataVarintLength:
   123  		log.Warnf("failed to decode geoip file: %s, fallback to the original ReadFile method", filename)
   124  		geositeBytes, err = ioutil.ReadFile(asset)
   125  		if err != nil {
   126  			return nil, err
   127  		}
   128  		var geositeList v2router.GeoSiteList
   129  		if err := proto.Unmarshal(geositeBytes, &geositeList); err != nil {
   130  			return nil, err
   131  		}
   132  		for _, geosite := range geositeList.GetEntry() {
   133  			if strings.EqualFold(code, geosite.GetCountryCode()) {
   134  				g.Set(idx, geosite)
   135  				return geosite, nil
   136  			}
   137  		}
   138  
   139  	default:
   140  		return nil, err
   141  	}
   142  
   143  	return nil, common.NewError("list " + code + " not found in " + filename)
   144  }