github.com/StackExchange/DNSControl@v0.2.8/pkg/spflib/resolver.go (about)

     1  package spflib
     2  
import (
	"encoding/json"
	"io/ioutil"
	"net"
	"os"
	"sort"
	"strings"

	"github.com/pkg/errors"
)
    12  
    13  // Resolver looks up spf txt records associated with a FQDN.
    14  type Resolver interface {
    15  	GetSPF(string) (string, error)
    16  }
    17  
    18  // LiveResolver simply queries DNS to resolve SPF records.
    19  type LiveResolver struct{}
    20  
    21  // GetSPF looks up the SPF record named "name".
    22  func (l LiveResolver) GetSPF(name string) (string, error) {
    23  	vals, err := net.LookupTXT(name)
    24  	if err != nil {
    25  		return "", err
    26  	}
    27  	spf := ""
    28  	for _, v := range vals {
    29  		if strings.HasPrefix(v, "v=spf1") {
    30  			if spf != "" {
    31  				return "", errors.Errorf("%s has multiple SPF records", name)
    32  			}
    33  			spf = v
    34  		}
    35  	}
    36  	if spf == "" {
    37  		return "", errors.Errorf("%s has no SPF record", name)
    38  	}
    39  	return spf, nil
    40  }
    41  
    42  // CachingResolver wraps a live resolver and adds caching to it.
    43  // GetSPF will always return the cached value, if present.
    44  // It will also query the inner resolver and compare results.
    45  // If a given lookup has inconsistencies between cache and live,
    46  // GetSPF will return the cached result.
    47  // All records queries will be stored for the lifetime of the resolver,
    48  // and can be flushed to disk at the end.
    49  // All resolution errors from the inner resolver will be saved and can be retreived later.
    50  type CachingResolver interface {
    51  	Resolver
    52  	ChangedRecords() []string
    53  	ResolveErrors() []error
    54  	Save(filename string) error
    55  }
    56  
    57  type cacheEntry struct {
    58  	SPF string
    59  
    60  	// value we have looked up this run
    61  	resolvedSPF  string
    62  	resolveError error
    63  }
    64  
    65  type cache struct {
    66  	records map[string]*cacheEntry
    67  
    68  	inner Resolver
    69  }
    70  
    71  // NewCache creates a new cache file named filename.
    72  func NewCache(filename string) (CachingResolver, error) {
    73  	f, err := os.Open(filename)
    74  	if err != nil {
    75  		if os.IsNotExist(err) {
    76  			// doesn't exist, just make a new one
    77  			return &cache{
    78  				records: map[string]*cacheEntry{},
    79  				inner:   LiveResolver{},
    80  			}, nil
    81  		}
    82  		return nil, err
    83  	}
    84  	dec := json.NewDecoder(f)
    85  	recs := map[string]*cacheEntry{}
    86  	if err := dec.Decode(&recs); err != nil {
    87  		return nil, err
    88  	}
    89  	return &cache{
    90  		records: recs,
    91  		inner:   LiveResolver{},
    92  	}, nil
    93  }
    94  
    95  func (c *cache) GetSPF(name string) (string, error) {
    96  	entry, ok := c.records[name]
    97  	if !ok {
    98  		entry = &cacheEntry{}
    99  		c.records[name] = entry
   100  	}
   101  	if entry.resolvedSPF == "" && entry.resolveError == nil {
   102  		entry.resolvedSPF, entry.resolveError = c.inner.GetSPF(name)
   103  	}
   104  	// return cached value
   105  	if entry.SPF != "" {
   106  		return entry.SPF, nil
   107  	}
   108  	// if not cached, return results of inner resolver
   109  	return entry.resolvedSPF, entry.resolveError
   110  }
   111  
   112  func (c *cache) ChangedRecords() []string {
   113  	names := []string{}
   114  	for name, entry := range c.records {
   115  		if entry.resolvedSPF != entry.SPF {
   116  			names = append(names, name)
   117  		}
   118  	}
   119  	return names
   120  }
   121  
   122  func (c *cache) ResolveErrors() (errs []error) {
   123  	for _, entry := range c.records {
   124  		if entry.resolveError != nil {
   125  			errs = append(errs, entry.resolveError)
   126  		}
   127  	}
   128  	return
   129  }
   130  func (c *cache) Save(filename string) error {
   131  	outRecs := make(map[string]*cacheEntry, len(c.records))
   132  	for k, entry := range c.records {
   133  		// move resolved data into cached field
   134  		// only take those we actually resolved
   135  		if entry.resolvedSPF != "" {
   136  			entry.SPF = entry.resolvedSPF
   137  			outRecs[k] = entry
   138  		}
   139  	}
   140  	dat, _ := json.MarshalIndent(outRecs, "", "  ")
   141  	return ioutil.WriteFile(filename, dat, 0644)
   142  }