github.com/MetalBlockchain/metalgo@v1.11.9/api/admin/client_test.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package admin
     5  
     6  import (
     7  	"context"
     8  	"errors"
     9  	"testing"
    10  
    11  	"github.com/stretchr/testify/require"
    12  
    13  	"github.com/MetalBlockchain/metalgo/api"
    14  	"github.com/MetalBlockchain/metalgo/ids"
    15  	"github.com/MetalBlockchain/metalgo/utils/logging"
    16  	"github.com/MetalBlockchain/metalgo/utils/rpc"
    17  )
    18  
    19  var (
    20  	errTest = errors.New("non-nil error")
    21  
    22  	SuccessResponseTests = []struct {
    23  		name        string
    24  		expectedErr error
    25  	}{
    26  		{
    27  			name:        "no error",
    28  			expectedErr: nil,
    29  		},
    30  		{
    31  			name:        "error",
    32  			expectedErr: errTest,
    33  		},
    34  	}
    35  )
    36  
    37  type mockClient struct {
    38  	response interface{}
    39  	err      error
    40  }
    41  
    42  // NewMockClient returns a mock client for testing
    43  func NewMockClient(response interface{}, err error) rpc.EndpointRequester {
    44  	return &mockClient{
    45  		response: response,
    46  		err:      err,
    47  	}
    48  }
    49  
    50  func (mc *mockClient) SendRequest(_ context.Context, _ string, _ interface{}, reply interface{}, _ ...rpc.Option) error {
    51  	if mc.err != nil {
    52  		return mc.err
    53  	}
    54  
    55  	switch p := reply.(type) {
    56  	case *api.EmptyReply:
    57  		response := mc.response.(*api.EmptyReply)
    58  		*p = *response
    59  	case *GetChainAliasesReply:
    60  		response := mc.response.(*GetChainAliasesReply)
    61  		*p = *response
    62  	case *LoadVMsReply:
    63  		response := mc.response.(*LoadVMsReply)
    64  		*p = *response
    65  	case *LoggerLevelReply:
    66  		response := mc.response.(*LoggerLevelReply)
    67  		*p = *response
    68  	case *interface{}:
    69  		response := mc.response.(*interface{})
    70  		*p = *response
    71  	default:
    72  		panic("illegal type")
    73  	}
    74  	return nil
    75  }
    76  
    77  func TestStartCPUProfiler(t *testing.T) {
    78  	for _, test := range SuccessResponseTests {
    79  		t.Run(test.name, func(t *testing.T) {
    80  			mockClient := client{requester: NewMockClient(&api.EmptyReply{}, test.expectedErr)}
    81  			err := mockClient.StartCPUProfiler(context.Background())
    82  			require.ErrorIs(t, err, test.expectedErr)
    83  		})
    84  	}
    85  }
    86  
    87  func TestStopCPUProfiler(t *testing.T) {
    88  	for _, test := range SuccessResponseTests {
    89  		t.Run(test.name, func(t *testing.T) {
    90  			mockClient := client{requester: NewMockClient(&api.EmptyReply{}, test.expectedErr)}
    91  			err := mockClient.StopCPUProfiler(context.Background())
    92  			require.ErrorIs(t, err, test.expectedErr)
    93  		})
    94  	}
    95  }
    96  
    97  func TestMemoryProfile(t *testing.T) {
    98  	for _, test := range SuccessResponseTests {
    99  		t.Run(test.name, func(t *testing.T) {
   100  			mockClient := client{requester: NewMockClient(&api.EmptyReply{}, test.expectedErr)}
   101  			err := mockClient.MemoryProfile(context.Background())
   102  			require.ErrorIs(t, err, test.expectedErr)
   103  		})
   104  	}
   105  }
   106  
   107  func TestLockProfile(t *testing.T) {
   108  	for _, test := range SuccessResponseTests {
   109  		t.Run(test.name, func(t *testing.T) {
   110  			mockClient := client{requester: NewMockClient(&api.EmptyReply{}, test.expectedErr)}
   111  			err := mockClient.LockProfile(context.Background())
   112  			require.ErrorIs(t, err, test.expectedErr)
   113  		})
   114  	}
   115  }
   116  
   117  func TestAlias(t *testing.T) {
   118  	for _, test := range SuccessResponseTests {
   119  		t.Run(test.name, func(t *testing.T) {
   120  			mockClient := client{requester: NewMockClient(&api.EmptyReply{}, test.expectedErr)}
   121  			err := mockClient.Alias(context.Background(), "alias", "alias2")
   122  			require.ErrorIs(t, err, test.expectedErr)
   123  		})
   124  	}
   125  }
   126  
   127  func TestAliasChain(t *testing.T) {
   128  	for _, test := range SuccessResponseTests {
   129  		t.Run(test.name, func(t *testing.T) {
   130  			mockClient := client{requester: NewMockClient(&api.EmptyReply{}, test.expectedErr)}
   131  			err := mockClient.AliasChain(context.Background(), "chain", "chain-alias")
   132  			require.ErrorIs(t, err, test.expectedErr)
   133  		})
   134  	}
   135  }
   136  
   137  func TestGetChainAliases(t *testing.T) {
   138  	t.Run("successful", func(t *testing.T) {
   139  		require := require.New(t)
   140  
   141  		expectedReply := []string{"alias1", "alias2"}
   142  		mockClient := client{requester: NewMockClient(&GetChainAliasesReply{
   143  			Aliases: expectedReply,
   144  		}, nil)}
   145  
   146  		reply, err := mockClient.GetChainAliases(context.Background(), "chain")
   147  		require.NoError(err)
   148  		require.Equal(expectedReply, reply)
   149  	})
   150  
   151  	t.Run("failure", func(t *testing.T) {
   152  		mockClient := client{requester: NewMockClient(&GetChainAliasesReply{}, errTest)}
   153  		_, err := mockClient.GetChainAliases(context.Background(), "chain")
   154  		require.ErrorIs(t, err, errTest)
   155  	})
   156  }
   157  
   158  func TestStacktrace(t *testing.T) {
   159  	for _, test := range SuccessResponseTests {
   160  		t.Run(test.name, func(t *testing.T) {
   161  			mockClient := client{requester: NewMockClient(&api.EmptyReply{}, test.expectedErr)}
   162  			err := mockClient.Stacktrace(context.Background())
   163  			require.ErrorIs(t, err, test.expectedErr)
   164  		})
   165  	}
   166  }
   167  
   168  func TestReloadInstalledVMs(t *testing.T) {
   169  	t.Run("successful", func(t *testing.T) {
   170  		require := require.New(t)
   171  
   172  		expectedNewVMs := map[ids.ID][]string{
   173  			ids.GenerateTestID(): {"foo"},
   174  			ids.GenerateTestID(): {"bar"},
   175  		}
   176  		expectedFailedVMs := map[ids.ID]string{
   177  			ids.GenerateTestID(): "oops",
   178  			ids.GenerateTestID(): "uh-oh",
   179  		}
   180  		mockClient := client{requester: NewMockClient(&LoadVMsReply{
   181  			NewVMs:    expectedNewVMs,
   182  			FailedVMs: expectedFailedVMs,
   183  		}, nil)}
   184  
   185  		loadedVMs, failedVMs, err := mockClient.LoadVMs(context.Background())
   186  		require.NoError(err)
   187  		require.Equal(expectedNewVMs, loadedVMs)
   188  		require.Equal(expectedFailedVMs, failedVMs)
   189  	})
   190  
   191  	t.Run("failure", func(t *testing.T) {
   192  		mockClient := client{requester: NewMockClient(&LoadVMsReply{}, errTest)}
   193  		_, _, err := mockClient.LoadVMs(context.Background())
   194  		require.ErrorIs(t, err, errTest)
   195  	})
   196  }
   197  
   198  func TestSetLoggerLevel(t *testing.T) {
   199  	type test struct {
   200  		name            string
   201  		logLevel        string
   202  		displayLevel    string
   203  		serviceResponse map[string]LogAndDisplayLevels
   204  		serviceErr      error
   205  		clientErr       error
   206  	}
   207  	tests := []test{
   208  		{
   209  			name:         "Happy path",
   210  			logLevel:     "INFO",
   211  			displayLevel: "INFO",
   212  			serviceResponse: map[string]LogAndDisplayLevels{
   213  				"Happy path": {LogLevel: logging.Info, DisplayLevel: logging.Info},
   214  			},
   215  			serviceErr: nil,
   216  			clientErr:  nil,
   217  		},
   218  		{
   219  			name:            "Service errors",
   220  			logLevel:        "INFO",
   221  			displayLevel:    "INFO",
   222  			serviceResponse: nil,
   223  			serviceErr:      errTest,
   224  			clientErr:       errTest,
   225  		},
   226  		{
   227  			name:            "Invalid log level",
   228  			logLevel:        "invalid",
   229  			displayLevel:    "INFO",
   230  			serviceResponse: nil,
   231  			serviceErr:      nil,
   232  			clientErr:       logging.ErrUnknownLevel,
   233  		},
   234  		{
   235  			name:            "Invalid display level",
   236  			logLevel:        "INFO",
   237  			displayLevel:    "invalid",
   238  			serviceResponse: nil,
   239  			serviceErr:      nil,
   240  			clientErr:       logging.ErrUnknownLevel,
   241  		},
   242  	}
   243  	for _, tt := range tests {
   244  		t.Run(tt.name, func(t *testing.T) {
   245  			require := require.New(t)
   246  
   247  			c := client{
   248  				requester: NewMockClient(
   249  					&LoggerLevelReply{
   250  						LoggerLevels: tt.serviceResponse,
   251  					},
   252  					tt.serviceErr,
   253  				),
   254  			}
   255  			res, err := c.SetLoggerLevel(
   256  				context.Background(),
   257  				"",
   258  				tt.logLevel,
   259  				tt.displayLevel,
   260  			)
   261  			require.ErrorIs(err, tt.clientErr)
   262  			if tt.clientErr != nil {
   263  				return
   264  			}
   265  			require.Equal(tt.serviceResponse, res)
   266  		})
   267  	}
   268  }
   269  
   270  func TestGetLoggerLevel(t *testing.T) {
   271  	type test struct {
   272  		name            string
   273  		loggerName      string
   274  		serviceResponse map[string]LogAndDisplayLevels
   275  		serviceErr      error
   276  		clientErr       error
   277  	}
   278  	tests := []test{
   279  		{
   280  			name:       "Happy Path",
   281  			loggerName: "foo",
   282  			serviceResponse: map[string]LogAndDisplayLevels{
   283  				"foo": {LogLevel: logging.Info, DisplayLevel: logging.Info},
   284  			},
   285  			serviceErr: nil,
   286  			clientErr:  nil,
   287  		},
   288  		{
   289  			name:            "service errors",
   290  			loggerName:      "foo",
   291  			serviceResponse: nil,
   292  			serviceErr:      errTest,
   293  			clientErr:       errTest,
   294  		},
   295  	}
   296  	for _, tt := range tests {
   297  		t.Run(tt.name, func(t *testing.T) {
   298  			require := require.New(t)
   299  
   300  			c := client{
   301  				requester: NewMockClient(
   302  					&LoggerLevelReply{
   303  						LoggerLevels: tt.serviceResponse,
   304  					},
   305  					tt.serviceErr,
   306  				),
   307  			}
   308  			res, err := c.GetLoggerLevel(
   309  				context.Background(),
   310  				tt.loggerName,
   311  			)
   312  			require.ErrorIs(err, tt.clientErr)
   313  			if tt.clientErr != nil {
   314  				return
   315  			}
   316  			require.Equal(tt.serviceResponse, res)
   317  		})
   318  	}
   319  }
   320  
   321  func TestGetConfig(t *testing.T) {
   322  	type test struct {
   323  		name             string
   324  		serviceErr       error
   325  		clientErr        error
   326  		expectedResponse interface{}
   327  	}
   328  	var resp interface{} = "response"
   329  	tests := []test{
   330  		{
   331  			name:             "Happy path",
   332  			serviceErr:       nil,
   333  			clientErr:        nil,
   334  			expectedResponse: &resp,
   335  		},
   336  		{
   337  			name:             "service errors",
   338  			serviceErr:       errTest,
   339  			clientErr:        errTest,
   340  			expectedResponse: nil,
   341  		},
   342  	}
   343  	for _, tt := range tests {
   344  		t.Run(tt.name, func(t *testing.T) {
   345  			require := require.New(t)
   346  
   347  			c := client{
   348  				requester: NewMockClient(tt.expectedResponse, tt.serviceErr),
   349  			}
   350  			res, err := c.GetConfig(context.Background())
   351  			require.ErrorIs(err, tt.clientErr)
   352  			if tt.clientErr != nil {
   353  				return
   354  			}
   355  			require.Equal(resp, res)
   356  		})
   357  	}
   358  }