github.com/uber/kraken@v0.1.4/lib/hostlist/config.go (about)

     1  // Copyright (c) 2016-2019 Uber Technologies, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  package hostlist
    15  
import (
	"context"
	"errors"
	"fmt"
	"net"
	"sort"
	"strconv"
	"strings"
	"time"

	"github.com/uber/kraken/utils/stringset"
)
    27  
// Config defines a list of hosts using either a DNS record or a static list of
// addresses. Exactly one of DNS or Static must be set: getResolver rejects a
// Config that supplies both, or neither.
type Config struct {
	// DNS record from which to resolve host names. Must be in 'host:port'
	// format; the port is attached to the addresses the record resolves to
	// (see dnsResolver.resolve).
	DNS string `yaml:"dns"`

	// Statically configured addresses. Must be in 'host:port' format.
	Static []string `yaml:"static"`

	// TTL defines how long resolved host lists are cached for. A zero value
	// is replaced with 5s by applyDefaults.
	TTL time.Duration `yaml:"ttl"`
}
    42  
    43  func (c *Config) applyDefaults() {
    44  	if c.TTL == 0 {
    45  		c.TTL = 5 * time.Second
    46  	}
    47  }
    48  
    49  // getResolver parses the configuration for which resolver to use.
    50  func (c *Config) getResolver() (resolver, error) {
    51  	if c.DNS == "" && len(c.Static) == 0 {
    52  		return nil, errors.New("no dns record or static list supplied")
    53  	}
    54  	if c.DNS != "" && len(c.Static) > 0 {
    55  		return nil, errors.New("both dns record and static list supplied")
    56  	}
    57  
    58  	if len(c.Static) > 0 {
    59  		for _, addr := range c.Static {
    60  			if _, _, err := net.SplitHostPort(addr); err != nil {
    61  				return nil, fmt.Errorf("invalid static addr: %s", err)
    62  			}
    63  		}
    64  		return &staticResolver{stringset.FromSlice(c.Static)}, nil
    65  	}
    66  
    67  	dns, rawport, err := net.SplitHostPort(c.DNS)
    68  	if err != nil {
    69  		return nil, fmt.Errorf("invalid dns: %s", err)
    70  	}
    71  	port, err := strconv.Atoi(rawport)
    72  	if err != nil {
    73  		return nil, fmt.Errorf("invalid dns port: %s", err)
    74  	}
    75  	return &dnsResolver{dns, port}, nil
    76  }
    77  
// resolver resolves parsed configuration into a set of addresses. The
// implementations in this file are staticResolver and dnsResolver.
type resolver interface {
	// resolve returns the current set of addresses, or an error if they
	// cannot be determined.
	resolve() (stringset.Set, error)
}
    82  
// staticResolver is the resolver built from Config.Static. set holds the
// configured addresses, each already validated as 'host:port' by getResolver.
type staticResolver struct {
	set stringset.Set
}
    86  
    87  func (r *staticResolver) resolve() (stringset.Set, error) {
    88  	return r.set, nil
    89  }
    90  
    91  func (r *staticResolver) String() string {
    92  	return strings.Join(r.set.ToSlice(), ",")
    93  }
    94  
// dnsResolver is the resolver built from Config.DNS: dns is the host part of
// the record and port is the numeric port parsed from it by getResolver.
type dnsResolver struct {
	dns  string
	port int
}
    99  
   100  func (r *dnsResolver) resolve() (stringset.Set, error) {
   101  	var nr net.Resolver
   102  	names, err := nr.LookupHost(context.Background(), r.dns)
   103  	if err != nil {
   104  		return nil, fmt.Errorf("resolve dns: %s", err)
   105  	}
   106  	if len(names) == 0 {
   107  		return nil, errors.New("dns record empty")
   108  	}
   109  	addrs, err := attachPortIfMissing(stringset.FromSlice(names), r.port)
   110  	if err != nil {
   111  		return nil, fmt.Errorf("attach port to dns contents: %s", err)
   112  	}
   113  	return addrs, nil
   114  }
   115  
   116  func (r *dnsResolver) String() string {
   117  	return fmt.Sprintf("%s:%d", r.dns, r.port)
   118  }