github.com/matrixorigin/matrixone@v1.2.0/pkg/proxy/plugin_test.go (about)

     1  // Copyright 2021 - 2023 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package proxy
    16  
    17  import (
    18  	"context"
    19  	"testing"
    20  	"time"
    21  
    22  	"github.com/lni/goutils/leaktest"
    23  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    24  	"github.com/matrixorigin/matrixone/pkg/common/morpc"
    25  	"github.com/matrixorigin/matrixone/pkg/common/runtime"
    26  	"github.com/matrixorigin/matrixone/pkg/frontend"
    27  	"github.com/matrixorigin/matrixone/pkg/pb/metadata"
    28  	"github.com/matrixorigin/matrixone/pkg/pb/plugin"
    29  	"github.com/stretchr/testify/require"
    30  )
    31  
// Compile-time assertion that *pluginRouter implements the Router interface;
// the typed nil pointer costs nothing at runtime.
var _ Router = (*pluginRouter)(nil)
    33  
    34  type mockPlugin struct {
    35  	mockRecommendCNFn func(ctx context.Context, clientInfo clientInfo) (*plugin.Recommendation, error)
    36  }
    37  
    38  func (p *mockPlugin) RecommendCN(ctx context.Context, clientInfo clientInfo) (*plugin.Recommendation, error) {
    39  	if p.mockRecommendCNFn != nil {
    40  		return p.mockRecommendCNFn(ctx, clientInfo)
    41  	}
    42  	return &plugin.Recommendation{
    43  		Action: plugin.Bypass,
    44  	}, nil
    45  }
    46  
    47  type mockRouter struct {
    48  	mockRouteFn func(ctx context.Context, ci clientInfo) (*CNServer, error)
    49  
    50  	refreshCount int
    51  }
    52  
    53  func (r *mockRouter) Route(ctx context.Context, ci clientInfo, f func(string) bool) (*CNServer, error) {
    54  	if r.mockRouteFn != nil {
    55  		return r.mockRouteFn(ctx, ci)
    56  	}
    57  	return nil, nil
    58  }
    59  
    60  func (r *mockRouter) SelectByConnID(connID uint32) (*CNServer, error) {
    61  	return nil, nil
    62  }
    63  
    64  func (r *mockRouter) Connect(c *CNServer, handshakeResp *frontend.Packet, t *tunnel) (ServerConn, []byte, error) {
    65  	return nil, nil, nil
    66  }
    67  
    68  func (r *mockRouter) Refresh(sync bool) {
    69  	r.refreshCount++
    70  }
    71  
    72  func TestPluginRouter_Route(t *testing.T) {
    73  	defer leaktest.AfterTest(t)()
    74  
    75  	runtime.SetupProcessLevelRuntime(runtime.DefaultRuntime())
    76  	tests := []struct {
    77  		name              string
    78  		mockRouteFn       func(ctx context.Context, ci clientInfo) (*CNServer, error)
    79  		mockRecommendCNFn func(ctx context.Context, ci clientInfo) (*plugin.Recommendation, error)
    80  		expectErr         bool
    81  		expectUUID        string
    82  		expectRefresh     int
    83  	}{{
    84  		name: "recommend select CN",
    85  		mockRecommendCNFn: func(ctx context.Context, ci clientInfo) (*plugin.Recommendation, error) {
    86  			return &plugin.Recommendation{
    87  				Action: plugin.Select,
    88  				CN: &metadata.CNService{
    89  					ServiceID: "cn0",
    90  				},
    91  			}, nil
    92  		},
    93  		expectUUID: "cn0",
    94  	}, {
    95  		name: "recommend bypass",
    96  		mockRecommendCNFn: func(ctx context.Context, ci clientInfo) (*plugin.Recommendation, error) {
    97  			return &plugin.Recommendation{
    98  				Action: plugin.Bypass,
    99  			}, nil
   100  		},
   101  		mockRouteFn: func(ctx context.Context, ci clientInfo) (*CNServer, error) {
   102  			return &CNServer{uuid: "cn1"}, nil
   103  		},
   104  		expectUUID: "cn1",
   105  	}, {
   106  		name: "recommend reject",
   107  		mockRecommendCNFn: func(ctx context.Context, ci clientInfo) (*plugin.Recommendation, error) {
   108  			return &plugin.Recommendation{
   109  				Action:  plugin.Reject,
   110  				Message: "IP not in whitelist",
   111  			}, nil
   112  		},
   113  		expectErr: true,
   114  	}, {
   115  		name: "error after bypass",
   116  		mockRecommendCNFn: func(ctx context.Context, ci clientInfo) (*plugin.Recommendation, error) {
   117  			return &plugin.Recommendation{
   118  				Action: plugin.Bypass,
   119  			}, nil
   120  		},
   121  		mockRouteFn: func(ctx context.Context, ci clientInfo) (*CNServer, error) {
   122  			return nil, moerr.NewInternalErrorNoCtx("boom")
   123  		},
   124  		expectErr: true,
   125  	}, {
   126  		name: "unknown action",
   127  		mockRecommendCNFn: func(ctx context.Context, ci clientInfo) (*plugin.Recommendation, error) {
   128  			return &plugin.Recommendation{
   129  				Action: -1,
   130  			}, nil
   131  		},
   132  		expectErr: true,
   133  	}, {
   134  		name: "error recommend",
   135  		mockRecommendCNFn: func(ctx context.Context, ci clientInfo) (*plugin.Recommendation, error) {
   136  			return nil, moerr.NewInternalErrorNoCtx("boom")
   137  		},
   138  		expectErr: true,
   139  	}, {
   140  		name: "refresh",
   141  		mockRecommendCNFn: func(ctx context.Context, ci clientInfo) (*plugin.Recommendation, error) {
   142  			return &plugin.Recommendation{
   143  				Action: plugin.Select,
   144  				CN: &metadata.CNService{
   145  					ServiceID: "cn0",
   146  				},
   147  				Updated: true,
   148  			}, nil
   149  		},
   150  		expectUUID:    "cn0",
   151  		expectRefresh: 1,
   152  	}}
   153  
   154  	for _, tt := range tests {
   155  		t.Run(tt.name, func(t *testing.T) {
   156  			p := &mockPlugin{mockRecommendCNFn: tt.mockRecommendCNFn}
   157  			r := &mockRouter{mockRouteFn: tt.mockRouteFn}
   158  			pr := newPluginRouter(r, p)
   159  			cn, err := pr.Route(context.TODO(), clientInfo{}, nil)
   160  			if tt.expectErr {
   161  				require.Error(t, err)
   162  				require.Nil(t, cn)
   163  			} else {
   164  				require.NotNil(t, cn)
   165  				require.Equal(t, cn.uuid, tt.expectUUID)
   166  			}
   167  			require.Equal(t, r.refreshCount, tt.expectRefresh)
   168  		})
   169  	}
   170  }
   171  
   172  func TestRPCPlugin(t *testing.T) {
   173  	defer leaktest.AfterTest(t)()
   174  
   175  	runtime.SetupProcessLevelRuntime(runtime.DefaultRuntime())
   176  	tests := []struct {
   177  		name       string
   178  		response   *plugin.Recommendation
   179  		expectErr  bool
   180  		expectUUID string
   181  	}{{
   182  		name:     "plugin bypass",
   183  		response: &plugin.Recommendation{Action: plugin.Bypass},
   184  	}, {
   185  		name: "plugin select",
   186  		response: &plugin.Recommendation{
   187  			Action: plugin.Select,
   188  			CN: &metadata.CNService{
   189  				ServiceID: "cn0",
   190  			},
   191  		},
   192  	}, {
   193  		name: "plugin reject",
   194  		response: &plugin.Recommendation{
   195  			Action:  plugin.Reject,
   196  			Message: "boom",
   197  		},
   198  	},
   199  	}
   200  
   201  	for _, tt := range tests {
   202  		t.Run(tt.name, func(t *testing.T) {
   203  			ctx := context.Background()
   204  			addr := "unix:///tmp/plugin.sock"
   205  			s, err := morpc.NewRPCServer("test-plugin-server",
   206  				addr,
   207  				morpc.NewMessageCodec(func() morpc.Message {
   208  					return &plugin.Request{}
   209  				}),
   210  			)
   211  			require.NoError(t, err)
   212  			s.RegisterRequestHandler(func(ctx context.Context, msg morpc.RPCMessage, sequence uint64, cs morpc.ClientSession) error {
   213  				request := msg.Message
   214  				r, ok := request.(*plugin.Request)
   215  				require.True(t, ok)
   216  				return cs.Write(ctx, &plugin.Response{
   217  					RequestID:      r.RequestID,
   218  					Recommendation: tt.response,
   219  				})
   220  			})
   221  			require.NoError(t, s.Start())
   222  			defer func() {
   223  				require.NoError(t, s.Close())
   224  			}()
   225  			p, err := newRPCPlugin(addr, time.Second)
   226  			defer func() {
   227  				require.NoError(t, p.Close())
   228  			}()
   229  			require.NoError(t, err)
   230  			rec, err := p.RecommendCN(ctx, clientInfo{})
   231  			require.NoError(t, err)
   232  			require.Equal(t, tt.response.Action, rec.Action)
   233  		})
   234  	}
   235  }