github.com/google/cloudprober@v0.11.3/rds/client/srvlist.go (about)

     1  // Copyright 2020 The Cloudprober Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  //
    15  // This file implements a client-side load balancing resolver for gRPC clients.
    16  // This resolver takes a comma separated list of addresses and sets client
    17  // connection to use those addresses in a round-robin manner. It implements
    18  // the APIs defined in google.golang.org/grpc/resolver.
    19  
    20  package client
    21  
    22  import (
    23  	"math/rand"
    24  	"net"
    25  	"strings"
    26  
    27  	cpRes "github.com/google/cloudprober/targets/resolver"
    28  	"google.golang.org/grpc/resolver"
    29  )
    30  
// srvListResolver implements the resolver.Resolver interface for a fixed
// list of server addresses. hostList and portList are parallel slices: the
// entry at index i of each describes the same server.
type srvListResolver struct {
	// hostList holds the host part of every address parsed from the target.
	hostList    []string
	// portList holds the port for the host at the same index in hostList.
	portList    []string
	// r resolves hosts that are not IP literals.
	r           *cpRes.Resolver
	// cc is the gRPC client connection that receives state updates and
	// resolution errors; it is set by the builder after construction.
	cc          resolver.ClientConn
	// defaultPort is the default port the resolver was created with.
	defaultPort string
}
    39  
    40  func parseAddr(addr, defaultPort string) (host, port string, err error) {
    41  	if ipStr, ok := formatIP(addr); ok {
    42  		return ipStr, defaultPort, nil
    43  	}
    44  
    45  	host, port, err = net.SplitHostPort(addr)
    46  	if err != nil {
    47  		return "", "", err
    48  	}
    49  
    50  	if port == "" {
    51  		port = defaultPort
    52  	}
    53  
    54  	// target has port, i.e ipv4-host:port, [ipv6-host]:port, host-name:port
    55  	if host == "" {
    56  		// Keep consistent with net.Dial(): If the host is empty, as in ":80", the local system is assumed.
    57  		host = "localhost"
    58  	}
    59  
    60  	return
    61  }
    62  
    63  // formatIP returns ok = false if addr is not a valid textual representation of an IP address.
    64  // If addr is an IPv4 address, return the addr and ok = true.
    65  // If addr is an IPv6 address, return the addr enclosed in square brackets and ok = true.
    66  func formatIP(addr string) (addrIP string, ok bool) {
    67  	ip := net.ParseIP(addr)
    68  	if ip == nil {
    69  		return "", false
    70  	}
    71  	if ip.To4() != nil {
    72  		return addr, true
    73  	}
    74  	return "[" + addr + "]", true
    75  }
    76  
    77  func (res *srvListResolver) resolve() (*resolver.State, error) {
    78  	state := &resolver.State{}
    79  
    80  	for i, host := range res.hostList {
    81  		if ipStr, ok := formatIP(host); ok {
    82  			state.Addresses = append(state.Addresses, resolver.Address{
    83  				Addr: ipStr + ":" + res.portList[i],
    84  			})
    85  			continue
    86  		}
    87  
    88  		ip, err := res.r.Resolve(host, 0)
    89  		if err != nil {
    90  			return nil, err
    91  		}
    92  		state.Addresses = append(state.Addresses, resolver.Address{
    93  			Addr: ip.String() + ":" + res.portList[i],
    94  		})
    95  	}
    96  
    97  	// Set round robin policy.
    98  	state.ServiceConfig = res.cc.ParseServiceConfig("{\"loadBalancingPolicy\": \"round_robin\"}")
    99  	return state, nil
   100  }
   101  
   102  func (res *srvListResolver) ResolveNow(_ resolver.ResolveNowOptions) {
   103  	state, err := res.resolve()
   104  	if err != nil {
   105  		res.cc.ReportError(err)
   106  		return
   107  	}
   108  
   109  	res.cc.UpdateState(*state)
   110  }
   111  
// Close is part of the resolver.Resolver interface. It is a no-op here: the
// method body is empty, so nothing is released or stopped when it is called.
func (res *srvListResolver) Close() {
}
   114  
   115  func newSrvListResolver(target, defaultPort string) (*srvListResolver, error) {
   116  	res := &srvListResolver{
   117  		r:           cpRes.New(),
   118  		defaultPort: defaultPort,
   119  	}
   120  
   121  	addrs := strings.Split(target, ",")
   122  
   123  	// Shuffle addresses to create variance in what order different clients start
   124  	// connecting to these addresses. Note that round-robin load balancing policy
   125  	// takes care of distributing load evenly over time.
   126  	rand.Shuffle(len(addrs), func(i, j int) {
   127  		addrs[i], addrs[j] = addrs[j], addrs[i]
   128  	})
   129  
   130  	for _, addr := range addrs {
   131  		host, port, err := parseAddr(addr, defaultPort)
   132  		if err != nil {
   133  			return nil, err
   134  		}
   135  
   136  		res.hostList = append(res.hostList, host)
   137  		res.portList = append(res.portList, port)
   138  	}
   139  
   140  	return res, nil
   141  }
   142  
   143  type srvListBuilder struct {
   144  	defaultPort string
   145  }
   146  
   147  // Scheme returns the naming scheme of this resolver builder, which is "srvlist".
   148  func (slb *srvListBuilder) Scheme() string {
   149  	return "srvlist"
   150  }
   151  
   152  func (slb *srvListBuilder) Build(target resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) {
   153  	res, err := newSrvListResolver(target.Endpoint, slb.defaultPort)
   154  	if err != nil {
   155  		return nil, err
   156  	}
   157  
   158  	res.cc = cc
   159  
   160  	state, err := res.resolve()
   161  	if err != nil {
   162  		res.cc.ReportError(err)
   163  	} else {
   164  		res.cc.UpdateState(*state)
   165  	}
   166  
   167  	return res, nil
   168  }