github.com/matrixorigin/matrixone@v1.2.0/pkg/proxy/conn_manager_test.go (about)

     1  // Copyright 2021 - 2023 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package proxy
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"github.com/matrixorigin/matrixone/pkg/common/runtime"
    21  	"sync"
    22  	"testing"
    23  
    24  	"github.com/lni/goutils/leaktest"
    25  	"github.com/stretchr/testify/require"
    26  )
    27  
    28  func TestTunnelSet(t *testing.T) {
    29  	ts := make(tunnelSet)
    30  	tu := &tunnel{}
    31  
    32  	ts.add(tu)
    33  	require.Equal(t, 1, ts.count())
    34  	ts.add(tu)
    35  	require.Equal(t, 1, ts.count())
    36  
    37  	ts.add(&tunnel{})
    38  	require.Equal(t, 2, ts.count())
    39  
    40  	require.True(t, ts.exists(tu))
    41  	t1 := &tunnel{}
    42  	require.False(t, ts.exists(t1))
    43  
    44  	ts.del(tu)
    45  	require.Equal(t, 1, ts.count())
    46  	require.False(t, ts.exists(tu))
    47  }
    48  
    49  func TestCNTunnels(t *testing.T) {
    50  	defer leaktest.AfterTest(t)()
    51  
    52  	ct := newCNTunnels()
    53  	require.NotNil(t, ct)
    54  
    55  	t1 := &tunnel{}
    56  	ct.add("cn1", t1)
    57  	require.Equal(t, 1, ct.count())
    58  
    59  	t2 := &tunnel{}
    60  	ct.add("cn1", t2)
    61  	require.Equal(t, 2, ct.count())
    62  
    63  	// same tunnel
    64  	ct.add("cn1", t2)
    65  	require.Equal(t, 2, ct.count())
    66  
    67  	ct.add("cn1", nil)
    68  	require.Equal(t, 2, ct.count())
    69  
    70  	t3 := &tunnel{}
    71  	ct.add("cn2", t3)
    72  	require.Equal(t, 3, ct.count())
    73  
    74  	// no this cn.
    75  	ct.del("no-this-cn", t1)
    76  	require.Equal(t, 3, ct.count())
    77  
    78  	// tunnel is not on this cn.
    79  	ct.del("cn2", t1)
    80  	require.Equal(t, 3, ct.count())
    81  
    82  	ct.del("cn1", t1)
    83  	require.Equal(t, 2, ct.count())
    84  
    85  	ct.del("cn1", t1)
    86  	require.Equal(t, 2, ct.count())
    87  	ct.del("cn1", t2)
    88  	require.Equal(t, 1, ct.count())
    89  	ct.del("cn2", t3)
    90  	require.Equal(t, 0, ct.count())
    91  
    92  	ct.del("cn2", t3)
    93  	require.Equal(t, 0, ct.count())
    94  }
    95  
    96  func TestConnManagerConnection(t *testing.T) {
    97  	defer leaktest.AfterTest(t)()
    98  
    99  	cm := newConnManager()
   100  	require.NotNil(t, cm)
   101  
   102  	cn11 := testMakeCNServer("cn11", "", 0, "hash1",
   103  		newLabelInfo("t1", map[string]string{
   104  			"k1": "v1",
   105  		}),
   106  	)
   107  	cn12 := testMakeCNServer("cn12", "", 0, "hash1",
   108  		newLabelInfo("t1", map[string]string{
   109  			"k1": "v1",
   110  		}),
   111  	)
   112  	cn21 := testMakeCNServer("cn21", "", 0, "hash2",
   113  		newLabelInfo("t1", map[string]string{
   114  			"k2": "v2",
   115  		}),
   116  	)
   117  
   118  	rt := runtime.DefaultRuntime()
   119  	tu0 := newTunnel(context.TODO(), rt.Logger(), nil)
   120  
   121  	tu11 := newTunnel(context.TODO(), rt.Logger(), nil)
   122  	cm.connect(cn11, tu11)
   123  	require.Equal(t, 1, cm.count())
   124  	require.Equal(t, 1, len(cm.getLabelHashes()))
   125  	require.Equal(t, 1, cm.getCNTunnels("hash1").count())
   126  	require.Equal(t, 0, cm.getCNTunnels("hash2").count())
   127  
   128  	tu12 := newTunnel(context.TODO(), rt.Logger(), nil)
   129  	cm.connect(cn12, tu12)
   130  	require.Equal(t, 2, cm.count())
   131  	require.Equal(t, 1, len(cm.getLabelHashes()))
   132  	require.Equal(t, 2, cm.getCNTunnels("hash1").count())
   133  	require.Equal(t, 0, cm.getCNTunnels("hash2").count())
   134  
   135  	tu21 := newTunnel(context.TODO(), rt.Logger(), nil)
   136  	cm.connect(cn21, tu21)
   137  	require.Equal(t, 3, cm.count())
   138  	require.Equal(t, 2, len(cm.getLabelHashes()))
   139  	require.Equal(t, 2, cm.getCNTunnels("hash1").count())
   140  	require.Equal(t, 1, cm.getCNTunnels("hash2").count())
   141  
   142  	cm.disconnect(cn12, tu11)
   143  	require.Equal(t, 3, cm.count())
   144  	require.Equal(t, 2, len(cm.getLabelHashes()))
   145  	require.Equal(t, 2, cm.getCNTunnels("hash1").count())
   146  	require.Equal(t, 1, cm.getCNTunnels("hash2").count())
   147  
   148  	cm.disconnect(cn11, tu0)
   149  	require.Equal(t, 3, cm.count())
   150  	require.Equal(t, 2, len(cm.getLabelHashes()))
   151  	require.Equal(t, 2, cm.getCNTunnels("hash1").count())
   152  	require.Equal(t, 1, cm.getCNTunnels("hash2").count())
   153  
   154  	cm.disconnect(cn12, tu12)
   155  	require.Equal(t, 2, cm.count())
   156  	require.Equal(t, 2, len(cm.getLabelHashes()))
   157  	require.Equal(t, 1, cm.getCNTunnels("hash1").count())
   158  	require.Equal(t, 1, cm.getCNTunnels("hash2").count())
   159  
   160  	cm.disconnect(cn12, tu0)
   161  	require.Equal(t, 2, cm.count())
   162  	require.Equal(t, 2, len(cm.getLabelHashes()))
   163  	require.Equal(t, 1, cm.getCNTunnels("hash1").count())
   164  	require.Equal(t, 1, cm.getCNTunnels("hash2").count())
   165  
   166  	cm.disconnect(cn11, tu11)
   167  	require.Equal(t, 1, cm.count())
   168  	require.Equal(t, 1, len(cm.getLabelHashes()))
   169  	require.Equal(t, 0, cm.getCNTunnels("hash1").count())
   170  	require.Equal(t, 1, cm.getCNTunnels("hash2").count())
   171  
   172  	cm.disconnect(cn21, tu21)
   173  	require.Equal(t, 0, cm.count())
   174  	require.Equal(t, 0, len(cm.getLabelHashes()))
   175  	require.Equal(t, 0, cm.getCNTunnels("hash1").count())
   176  	require.Equal(t, 0, cm.getCNTunnels("hash2").count())
   177  
   178  	cm.disconnect(cn21, tu0)
   179  	require.Equal(t, 0, cm.count())
   180  	require.Equal(t, 0, len(cm.getLabelHashes()))
   181  	require.Equal(t, 0, cm.getCNTunnels("hash1").count())
   182  	require.Equal(t, 0, cm.getCNTunnels("hash2").count())
   183  }
   184  
   185  func TestConnManagerConnectionConcurrency(t *testing.T) {
   186  	defer leaktest.AfterTest(t)()
   187  
   188  	rt := runtime.DefaultRuntime()
   189  	cm := newConnManager()
   190  	require.NotNil(t, cm)
   191  
   192  	var wg sync.WaitGroup
   193  	for i := 0; i < 100; i++ {
   194  		wg.Add(2)
   195  		go func(j int) {
   196  			cn11 := testMakeCNServer(fmt.Sprintf("cn1-%d", j), "", 0, "hash1",
   197  				newLabelInfo("t1", map[string]string{
   198  					"k1": "v1",
   199  				}),
   200  			)
   201  			tu11 := newTunnel(context.TODO(), rt.Logger(), nil)
   202  			cm.connect(cn11, tu11)
   203  			wg.Done()
   204  		}(i)
   205  		go func(j int) {
   206  			cn11 := testMakeCNServer(fmt.Sprintf("cn2-%d", j), "", 0, "hash2",
   207  				newLabelInfo("t1", map[string]string{
   208  					"k2": "v2",
   209  				}),
   210  			)
   211  			tu11 := newTunnel(context.TODO(), rt.Logger(), nil)
   212  			cm.connect(cn11, tu11)
   213  			wg.Done()
   214  		}(i)
   215  	}
   216  	wg.Wait()
   217  
   218  	require.Equal(t, 200, cm.count())
   219  	require.Equal(t, 2, len(cm.getLabelHashes()))
   220  	require.Equal(t, 100, cm.getCNTunnels("hash1").count())
   221  	require.Equal(t, 100, cm.getCNTunnels("hash2").count())
   222  }
   223  
   224  func TestConnManagerLabelInfo(t *testing.T) {
   225  	rt := runtime.DefaultRuntime()
   226  	cm := newConnManager()
   227  	require.NotNil(t, cm)
   228  
   229  	cn11 := testMakeCNServer("cn11", "", 0, "hash1",
   230  		newLabelInfo("t1", map[string]string{
   231  			"k1": "v1",
   232  		}),
   233  	)
   234  
   235  	tu11 := newTunnel(context.TODO(), rt.Logger(), nil)
   236  	cm.connect(cn11, tu11)
   237  	require.Equal(t, 1, cm.count())
   238  	require.Equal(t, 1, len(cm.getLabelHashes()))
   239  	require.Equal(t, 1, cm.getCNTunnels("hash1").count())
   240  	require.Equal(t, 0, cm.getCNTunnels("hash2").count())
   241  
   242  	li := cm.getLabelInfo("hash1")
   243  	require.Equal(t, labelInfo{
   244  		Tenant: "t1",
   245  		Labels: map[string]string{
   246  			"k1": "v1",
   247  		},
   248  	}, li)
   249  }