github.com/nemunaire/dnscontrol@v0.2.8/pkg/spflib/flatten.go (about)

     1  package spflib
     2  
     3  import (
     4  	"fmt"
     5  	"strings"
     6  )
     7  
     8  // TXT outputs s as a TXT record.
     9  func (s *SPFRecord) TXT() string {
    10  	text := "v=spf1"
    11  	for _, p := range s.Parts {
    12  		text += " " + p.Text
    13  	}
    14  	return text
    15  }
    16  
    17  const maxLen = 255
    18  
    19  // TXTSplit returns a set of txt records to use for SPF.
    20  // pattern given is used to name all chained spf records.
    21  // patern should include %d, which will be replaced by a counter.
    22  // should result in fqdn after replacement
    23  // returned map will have keys with fqdn of resulting records.
    24  // root record will be under key "@"
    25  func (s *SPFRecord) TXTSplit(pattern string) map[string]string {
    26  	m := map[string]string{}
    27  	s.split("@", pattern, 1, m)
    28  	return m
    29  
    30  }
    31  
    32  func (s *SPFRecord) split(thisfqdn string, pattern string, nextIdx int, m map[string]string) {
    33  	base := s.TXT()
    34  	// simple case. it fits
    35  	if len(base) <= maxLen {
    36  		m[thisfqdn] = base
    37  		return
    38  	}
    39  
    40  	// we need to trim.
    41  	// take parts while we fit
    42  	nextFQDN := fmt.Sprintf(pattern, nextIdx)
    43  	lastPart := s.Parts[len(s.Parts)-1]
    44  	tail := " include:" + nextFQDN + " " + lastPart.Text
    45  	thisText := "v=spf1"
    46  
    47  	newRec := &SPFRecord{}
    48  	over := false
    49  	addedCount := 0
    50  	for _, part := range s.Parts {
    51  		if !over {
    52  			if len(thisText)+1+len(part.Text)+len(tail) <= maxLen {
    53  				thisText += " " + part.Text
    54  				addedCount++
    55  			} else {
    56  				over = true
    57  				if addedCount == 0 {
    58  					// the first part is too big to include. We kinda have to give up here.
    59  					m[thisfqdn] = base
    60  					return
    61  				}
    62  			}
    63  		}
    64  		if over {
    65  			newRec.Parts = append(newRec.Parts, part)
    66  		}
    67  	}
    68  	m[thisfqdn] = thisText + tail
    69  	newRec.split(nextFQDN, pattern, nextIdx+1, m)
    70  }
    71  
    72  // Flatten optimizes s.
    73  func (s *SPFRecord) Flatten(spec string) *SPFRecord {
    74  	newRec := &SPFRecord{}
    75  	for _, p := range s.Parts {
    76  		if p.IncludeRecord == nil {
    77  			// non-includes copy straight over
    78  			newRec.Parts = append(newRec.Parts, p)
    79  		} else if !matchesFlatSpec(spec, p.IncludeDomain) {
    80  			// includes that don't match get copied straight across
    81  			newRec.Parts = append(newRec.Parts, p)
    82  		} else {
    83  			// flatten child recursively
    84  			flattenedChild := p.IncludeRecord.Flatten(spec)
    85  			// include their parts (skipping final all term)
    86  			for _, childPart := range flattenedChild.Parts[:len(flattenedChild.Parts)-1] {
    87  				newRec.Parts = append(newRec.Parts, childPart)
    88  			}
    89  		}
    90  	}
    91  	return newRec
    92  }
    93  
    // matchesFlatSpec reports whether the include domain fqdn is selected for
// flattening by spec.
//
// spec is either "*" (flatten every include) or a comma-separated list of
// domains. Surrounding whitespace is trimmed from each entry, so
// "a.example.com, b.example.com" works as expected. An entry of "*" matches
// any domain, and empty entries are ignored. Domains are compared
// case-insensitively because DNS names are case-insensitive.
func matchesFlatSpec(spec, fqdn string) bool {
	for _, entry := range strings.Split(spec, ",") {
		entry = strings.TrimSpace(entry)
		if entry == "*" || (entry != "" && strings.EqualFold(entry, fqdn)) {
			return true
		}
	}
	return false
}