vitess.io/vitess@v0.16.2/go/vt/vtadmin/cluster/discovery/fakediscovery/discovery.go (about)

     1  /*
     2  Copyright 2020 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  // Package fakediscovery provides a fake, in-memory discovery implementation.
    18  package fakediscovery
    19  
    20  import (
    21  	"context"
    22  	"math/rand"
    23  	"sync"
    24  
    25  	"github.com/stretchr/testify/assert"
    26  
    27  	"vitess.io/vitess/go/vt/vtadmin/cluster/discovery"
    28  
    29  	vtadminpb "vitess.io/vitess/go/vt/proto/vtadmin"
    30  )
    31  
    32  type vtctlds struct {
    33  	byTag     map[string][]*vtadminpb.Vtctld
    34  	byName    map[string]*vtadminpb.Vtctld
    35  	shouldErr bool
    36  }
    37  
    38  type gates struct {
    39  	byTag     map[string][]*vtadminpb.VTGate
    40  	byName    map[string]*vtadminpb.VTGate
    41  	shouldErr bool
    42  }
    43  
    44  // Fake is a fake discovery implementation for use in testing.
    45  type Fake struct {
    46  	gates   *gates
    47  	vtctlds *vtctlds
    48  
    49  	m sync.Mutex
    50  }
    51  
    52  // New returns a new fake.
    53  func New() *Fake {
    54  	return &Fake{
    55  		gates: &gates{
    56  			byTag:  map[string][]*vtadminpb.VTGate{},
    57  			byName: map[string]*vtadminpb.VTGate{},
    58  		},
    59  		vtctlds: &vtctlds{
    60  			byTag:  map[string][]*vtadminpb.Vtctld{},
    61  			byName: map[string]*vtadminpb.Vtctld{},
    62  		},
    63  	}
    64  }
    65  
    66  func (d *Fake) Clear() {
    67  	d.m.Lock()
    68  	defer d.m.Unlock()
    69  
    70  	d.gates = &gates{
    71  		byTag:  map[string][]*vtadminpb.VTGate{},
    72  		byName: map[string]*vtadminpb.VTGate{},
    73  	}
    74  
    75  	d.vtctlds = &vtctlds{
    76  		byTag:  map[string][]*vtadminpb.Vtctld{},
    77  		byName: map[string]*vtadminpb.Vtctld{},
    78  	}
    79  }
    80  
    81  // AddTaggedGates adds the given gates to the discovery fake, associating each
    82  // gate with each tag. To tag different gates with multiple tags, call multiple
    83  // times with the same gates but different tag slices. Gates are uniquely
    84  // identified by hostname.
    85  func (d *Fake) AddTaggedGates(tags []string, gates ...*vtadminpb.VTGate) {
    86  	d.m.Lock()
    87  	defer d.m.Unlock()
    88  
    89  	for _, tag := range tags {
    90  		d.gates.byTag[tag] = append(d.gates.byTag[tag], gates...)
    91  	}
    92  
    93  	for _, g := range gates {
    94  		d.gates.byName[g.Hostname] = g
    95  	}
    96  }
    97  
    98  // AddTaggedVtctlds adds the given vtctlds to the discovery fake, associating
    99  // each vtctld with each tag. To tag different vtctlds with multiple tags, call
   100  // multiple times with the same vtctlds but different tag slices. Vtctlds are
   101  // uniquely identified by hostname.
   102  func (d *Fake) AddTaggedVtctlds(tags []string, vtctlds ...*vtadminpb.Vtctld) {
   103  	d.m.Lock()
   104  	defer d.m.Unlock()
   105  
   106  	for _, tag := range tags {
   107  		d.vtctlds.byTag[tag] = append(d.vtctlds.byTag[tag], vtctlds...)
   108  	}
   109  
   110  	for _, vtctld := range vtctlds {
   111  		d.vtctlds.byName[vtctld.Hostname] = vtctld
   112  	}
   113  }
   114  
   115  // SetGatesError instructs whether the fake should return an error on gate
   116  // discovery functions.
   117  func (d *Fake) SetGatesError(shouldErr bool) {
   118  	d.m.Lock()
   119  	defer d.m.Unlock()
   120  
   121  	d.gates.shouldErr = shouldErr
   122  }
   123  
   124  // SetVtctldsError instructs whether the fake should return an error on vtctld
   125  // discovery functions.
   126  func (d *Fake) SetVtctldsError(shouldErr bool) {
   127  	d.m.Lock()
   128  	defer d.m.Unlock()
   129  
   130  	d.vtctlds.shouldErr = shouldErr
   131  }
   132  
// Compile-time assertion that *Fake satisfies discovery.Discovery: the build
// fails if any interface method is missing or has the wrong signature.
var _ discovery.Discovery = (*Fake)(nil)
   134  
   135  // DiscoverVTGates is part of the discovery.Discovery interface.
   136  func (d *Fake) DiscoverVTGates(ctx context.Context, tags []string) ([]*vtadminpb.VTGate, error) {
   137  	d.m.Lock()
   138  	defer d.m.Unlock()
   139  
   140  	if d.gates.shouldErr {
   141  		return nil, assert.AnError
   142  	}
   143  
   144  	if len(tags) == 0 {
   145  		results := make([]*vtadminpb.VTGate, 0, len(d.gates.byName))
   146  		for _, gate := range d.gates.byName {
   147  			results = append(results, gate)
   148  		}
   149  
   150  		return results, nil
   151  	}
   152  
   153  	set := d.gates.byName
   154  
   155  	for _, tag := range tags {
   156  		intermediate := map[string]*vtadminpb.VTGate{}
   157  
   158  		gates, ok := d.gates.byTag[tag]
   159  		if !ok {
   160  			return []*vtadminpb.VTGate{}, nil
   161  		}
   162  
   163  		for _, g := range gates {
   164  			if _, ok := set[g.Hostname]; ok {
   165  				intermediate[g.Hostname] = g
   166  			}
   167  		}
   168  
   169  		set = intermediate
   170  	}
   171  
   172  	results := make([]*vtadminpb.VTGate, 0, len(set))
   173  
   174  	for _, gate := range set {
   175  		results = append(results, gate)
   176  	}
   177  
   178  	return results, nil
   179  }
   180  
   181  // DiscoverVTGate is part of the discovery.Discovery interface.
   182  func (d *Fake) DiscoverVTGate(ctx context.Context, tags []string) (*vtadminpb.VTGate, error) {
   183  	gates, err := d.DiscoverVTGates(ctx, tags)
   184  	if err != nil {
   185  		return nil, err
   186  	}
   187  
   188  	if len(gates) == 0 {
   189  		return nil, assert.AnError
   190  	}
   191  
   192  	return gates[rand.Intn(len(gates))], nil
   193  }
   194  
   195  // DiscoverVTGateAddr is part of the discovery.Discovery interface.
   196  func (d *Fake) DiscoverVTGateAddr(ctx context.Context, tags []string) (string, error) {
   197  	gate, err := d.DiscoverVTGate(ctx, tags)
   198  	if err != nil {
   199  		return "", err
   200  	}
   201  
   202  	return gate.Hostname, nil
   203  }
   204  
   205  // DiscoverVTGateAddrs is part of the discovery.Discovery interface.
   206  func (d *Fake) DiscoverVTGateAddrs(ctx context.Context, tags []string) ([]string, error) {
   207  	gates, err := d.DiscoverVTGates(ctx, tags)
   208  	if err != nil {
   209  		return nil, err
   210  	}
   211  
   212  	addrs := make([]string, len(gates))
   213  	for i, gate := range gates {
   214  		addrs[i] = gate.Hostname
   215  	}
   216  
   217  	return addrs, nil
   218  }
   219  
   220  // DiscoverVtctlds is part of the discover.Discovery interface.
   221  func (d *Fake) DiscoverVtctlds(ctx context.Context, tags []string) ([]*vtadminpb.Vtctld, error) {
   222  	d.m.Lock()
   223  	defer d.m.Unlock()
   224  
   225  	if d.vtctlds.shouldErr {
   226  		return nil, assert.AnError
   227  	}
   228  
   229  	if len(tags) == 0 {
   230  		results := make([]*vtadminpb.Vtctld, 0, len(d.vtctlds.byName))
   231  		for _, vtctld := range d.vtctlds.byName {
   232  			results = append(results, vtctld)
   233  		}
   234  
   235  		return results, nil
   236  	}
   237  
   238  	set := d.vtctlds.byName
   239  
   240  	for _, tag := range tags {
   241  		intermediate := map[string]*vtadminpb.Vtctld{}
   242  
   243  		vtctlds, ok := d.vtctlds.byTag[tag]
   244  		if !ok {
   245  			return []*vtadminpb.Vtctld{}, nil
   246  		}
   247  
   248  		for _, v := range vtctlds {
   249  			if _, ok := set[v.Hostname]; ok {
   250  				intermediate[v.Hostname] = v
   251  			}
   252  		}
   253  
   254  		set = intermediate
   255  	}
   256  
   257  	results := make([]*vtadminpb.Vtctld, 0, len(set))
   258  
   259  	for _, vtctld := range set {
   260  		results = append(results, vtctld)
   261  	}
   262  
   263  	return results, nil
   264  }
   265  
   266  // DiscoverVtctldAddr is part of the discover.Discovery interface.
   267  func (d *Fake) DiscoverVtctldAddr(ctx context.Context, tags []string) (string, error) {
   268  	vtctld, err := d.DiscoverVtctld(ctx, tags)
   269  	if err != nil {
   270  		return "", err
   271  	}
   272  
   273  	return vtctld.Hostname, nil
   274  }
   275  
   276  // DiscoverVtctldAddrs is part of the discovery.Discovery interface.
   277  func (d *Fake) DiscoverVtctldAddrs(ctx context.Context, tags []string) ([]string, error) {
   278  	vtctlds, err := d.DiscoverVtctlds(ctx, tags)
   279  	if err != nil {
   280  		return nil, err
   281  	}
   282  
   283  	addrs := make([]string, len(vtctlds))
   284  	for i, vtctld := range vtctlds {
   285  		addrs[i] = vtctld.Hostname
   286  	}
   287  
   288  	return addrs, nil
   289  }
   290  
   291  // DiscoverVtctld is part of the discover.Discovery interface.
   292  func (d *Fake) DiscoverVtctld(ctx context.Context, tags []string) (*vtadminpb.Vtctld, error) {
   293  	vtctlds, err := d.DiscoverVtctlds(ctx, tags)
   294  	if err != nil {
   295  		return nil, err
   296  	}
   297  
   298  	if len(vtctlds) == 0 {
   299  		return nil, assert.AnError
   300  	}
   301  
   302  	return vtctlds[rand.Intn(len(vtctlds))], nil
   303  }