github.com/uber/kraken@v0.1.4/lib/hostlist/list.go (about)

     1  // Copyright (c) 2016-2019 Uber Technologies, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  package hostlist
    15  
    16  import (
    17  	"fmt"
    18  	"net"
    19  	"os"
    20  	"strings"
    21  	"sync"
    22  
    23  	"github.com/uber/kraken/utils/dedup"
    24  	"github.com/uber/kraken/utils/log"
    25  	"github.com/uber/kraken/utils/stringset"
    26  
    27  	"github.com/andres-erbsen/clock"
    28  )
    29  
    30  // List defines a list of addresses which is subject to change.
    31  type List interface {
    32  	Resolve() stringset.Set
    33  }
    34  
    35  type list struct {
    36  	resolver resolver
    37  
    38  	snapshotTrap *dedup.IntervalTrap
    39  
    40  	mu       sync.RWMutex
    41  	snapshot stringset.Set
    42  }
    43  
    44  // New creates a new List.
    45  //
    46  // An error is returned if a DNS record is supplied and resolves to an empty list
    47  // of addresses.
    48  //
    49  // If List is backed by DNS, it will be periodically refreshed (defined by TTL
    50  // in config). If, after construction, there is an error resolving DNS, the
    51  // latest successful snapshot is used. As such, Resolve never returns an empty
    52  // set.
    53  func New(config Config) (List, error) {
    54  	config.applyDefaults()
    55  
    56  	resolver, err := config.getResolver()
    57  	if err != nil {
    58  		return nil, fmt.Errorf("invalid config: %s", err)
    59  	}
    60  
    61  	l := &list{resolver: resolver}
    62  	l.snapshotTrap = dedup.NewIntervalTrap(config.TTL, clock.New(), &snapshotTask{l})
    63  
    64  	if err := l.takeSnapshot(); err != nil {
    65  		// Fail fast if a snapshot cannot be initialized.
    66  		return nil, err
    67  	}
    68  	return l, nil
    69  }
    70  
    71  func (l *list) Resolve() stringset.Set {
    72  	l.snapshotTrap.Trap()
    73  
    74  	l.mu.RLock()
    75  	defer l.mu.RUnlock()
    76  
    77  	return l.snapshot.Copy()
    78  }
    79  
    80  type snapshotTask struct {
    81  	list *list
    82  }
    83  
    84  func (t *snapshotTask) Run() {
    85  	if err := t.list.takeSnapshot(); err != nil {
    86  		log.With("source", t.list.resolver).Errorf("Error taking hostlist snapshot: %s", err)
    87  	}
    88  }
    89  
    90  func (l *list) takeSnapshot() error {
    91  	snapshot, err := l.resolver.resolve()
    92  	if err != nil {
    93  		return err
    94  	}
    95  	l.mu.Lock()
    96  	l.snapshot = snapshot
    97  	l.mu.Unlock()
    98  	return nil
    99  }
   100  
   101  type nonLocalList struct {
   102  	list       List
   103  	localAddrs stringset.Set
   104  }
   105  
   106  // StripLocal wraps a List and filters out the local machine, if present. The
   107  // local machine is identified by both its hostname and ip address, concatenated
   108  // with port.
   109  //
   110  // If the local machine is the only member of list, then Resolve returns an empty
   111  // set.
   112  func StripLocal(list List, port int) (List, error) {
   113  	localNames, err := getLocalNames()
   114  	if err != nil {
   115  		return nil, fmt.Errorf("get local names: %s", err)
   116  	}
   117  	localAddrs, err := attachPortIfMissing(localNames, port)
   118  	if err != nil {
   119  		return nil, fmt.Errorf("attach port to local names: %s", err)
   120  	}
   121  	return &nonLocalList{list, localAddrs}, nil
   122  }
   123  
   124  func (l *nonLocalList) Resolve() stringset.Set {
   125  	return l.list.Resolve().Sub(l.localAddrs)
   126  }
   127  
   128  func getLocalNames() (stringset.Set, error) {
   129  	result := make(stringset.Set)
   130  
   131  	// Add all local non-loopback ips.
   132  	ifaces, err := net.Interfaces()
   133  	if err != nil {
   134  		return nil, fmt.Errorf("interfaces: %s", err)
   135  	}
   136  	for _, i := range ifaces {
   137  		addrs, err := i.Addrs()
   138  		if err != nil {
   139  			return nil, fmt.Errorf("addrs of %v: %s", i, err)
   140  		}
   141  		for _, addr := range addrs {
   142  			ip := net.ParseIP(addr.String()).To4()
   143  			if ip == nil {
   144  				continue
   145  			}
   146  			result.Add(ip.String())
   147  		}
   148  	}
   149  
   150  	// Add local hostname just to be safe.
   151  	hostname, err := os.Hostname()
   152  	if err != nil {
   153  		return nil, fmt.Errorf("hostname: %s", err)
   154  	}
   155  	result.Add(hostname)
   156  
   157  	return result, nil
   158  }
   159  
   160  func attachPortIfMissing(names stringset.Set, port int) (stringset.Set, error) {
   161  	result := make(stringset.Set)
   162  	for name := range names {
   163  		parts := strings.Split(name, ":")
   164  		switch len(parts) {
   165  		case 1:
   166  			// Name is in 'host' format -- attach port.
   167  			name = fmt.Sprintf("%s:%d", parts[0], port)
   168  		case 2:
   169  			// No-op, name is already in "ip:port" format.
   170  		default:
   171  			return nil, fmt.Errorf("invalid name format: %s, expected 'host' or 'ip:port'", name)
   172  		}
   173  		result.Add(name)
   174  	}
   175  	return result, nil
   176  }