github.com/koko1123/flow-go-1@v0.29.6/admin/command_runner_test.go (about)

     1  package admin
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/ecdsa"
     7  	"crypto/elliptic"
     8  	"crypto/rand"
     9  	"crypto/tls"
    10  	"crypto/x509"
    11  	"crypto/x509/pkix"
    12  	"encoding/json"
    13  	"encoding/pem"
    14  	"errors"
    15  	"fmt"
    16  	"math/big"
    17  	"net"
    18  	"net/http"
    19  	"os"
    20  	"testing"
    21  	"time"
    22  
    23  	"github.com/dapperlabs/testingdock"
    24  	"github.com/rs/zerolog"
    25  	"github.com/stretchr/testify/require"
    26  	"github.com/stretchr/testify/suite"
    27  	"google.golang.org/grpc"
    28  	"google.golang.org/grpc/codes"
    29  	grpcinsecure "google.golang.org/grpc/credentials/insecure"
    30  	"google.golang.org/grpc/status"
    31  	"google.golang.org/protobuf/types/known/structpb"
    32  
    33  	pb "github.com/koko1123/flow-go-1/admin/admin"
    34  	"github.com/koko1123/flow-go-1/module/irrecoverable"
    35  	"github.com/koko1123/flow-go-1/utils/grpcutils"
    36  	"github.com/koko1123/flow-go-1/utils/unittest"
    37  )
    38  
    39  type CommandRunnerSuite struct {
    40  	suite.Suite
    41  
    42  	runner          *CommandRunner
    43  	bootstrapper    *CommandRunnerBootstrapper
    44  	httpAddress     string
    45  	grpcAddressSock string
    46  
    47  	client pb.AdminClient
    48  	conn   *grpc.ClientConn
    49  
    50  	cancel context.CancelFunc
    51  }
    52  
    53  func TestCommandRunner(t *testing.T) {
    54  	suite.Run(t, new(CommandRunnerSuite))
    55  }
    56  
    57  func (suite *CommandRunnerSuite) SetupTest() {
    58  	suite.httpAddress = unittest.IPPort(testingdock.RandomPort(suite.T()))
    59  	suite.bootstrapper = NewCommandRunnerBootstrapper()
    60  }
    61  
    62  func (suite *CommandRunnerSuite) TearDownTest() {
    63  	if suite.conn != nil {
    64  		err := suite.conn.Close()
    65  		suite.NoError(err)
    66  	}
    67  	if suite.grpcAddressSock != "" {
    68  		err := os.Remove(suite.grpcAddressSock)
    69  		suite.NoError(err)
    70  	}
    71  	suite.cancel()
    72  	<-suite.runner.Done()
    73  }
    74  
    75  func (suite *CommandRunnerSuite) SetupCommandRunner(opts ...CommandRunnerOption) {
    76  	ctx, cancel := context.WithCancel(context.Background())
    77  	suite.cancel = cancel
    78  
    79  	signalerCtx := irrecoverable.NewMockSignalerContext(suite.T(), ctx)
    80  
    81  	suite.grpcAddressSock = fmt.Sprintf("%s/%s-flow-node-admin.sock", os.TempDir(), unittest.GenerateRandomStringWithLen(16))
    82  	opts = append(opts, WithGRPCAddress(suite.grpcAddressSock), WithMaxMsgSize(grpcutils.DefaultMaxMsgSize))
    83  
    84  	logger := zerolog.New(zerolog.NewConsoleWriter())
    85  	suite.runner = suite.bootstrapper.Bootstrap(logger, suite.httpAddress, opts...)
    86  	suite.runner.Start(signalerCtx)
    87  	<-suite.runner.Ready()
    88  
    89  	conn, err := grpc.Dial("unix:///"+suite.runner.grpcAddress, grpc.WithTransportCredentials(grpcinsecure.NewCredentials()))
    90  	suite.NoError(err)
    91  	suite.conn = conn
    92  	suite.client = pb.NewAdminClient(conn)
    93  }
    94  
    95  func (suite *CommandRunnerSuite) TestHandler() {
    96  	called := false
    97  
    98  	suite.bootstrapper.RegisterHandler("foo", func(ctx context.Context, req *CommandRequest) (interface{}, error) {
    99  		select {
   100  		case <-ctx.Done():
   101  			return nil, ctx.Err()
   102  		default:
   103  		}
   104  
   105  		data := req.Data.(map[string]interface{})
   106  
   107  		suite.EqualValues(data["string"], "foo")
   108  		suite.EqualValues(data["number"], 123)
   109  		called = true
   110  
   111  		return "ok", nil
   112  	})
   113  
   114  	suite.SetupCommandRunner()
   115  
   116  	data := make(map[string]interface{})
   117  	data["string"] = "foo"
   118  	data["number"] = 123
   119  	val, err := structpb.NewValue(data)
   120  	suite.NoError(err)
   121  
   122  	ctx, cancel := context.WithCancel(context.Background())
   123  	defer cancel()
   124  	request := &pb.RunCommandRequest{
   125  		CommandName: "foo",
   126  		Data:        val,
   127  	}
   128  
   129  	_, err = suite.client.RunCommand(ctx, request)
   130  	suite.NoError(err)
   131  	suite.True(called)
   132  }
   133  
   134  func (suite *CommandRunnerSuite) TestUnimplementedHandler() {
   135  	suite.SetupCommandRunner()
   136  
   137  	data := make(map[string]interface{})
   138  	data["key"] = "value"
   139  	val, err := structpb.NewValue(data)
   140  	suite.NoError(err)
   141  
   142  	ctx, cancel := context.WithCancel(context.Background())
   143  	defer cancel()
   144  	request := &pb.RunCommandRequest{
   145  		CommandName: "foo",
   146  		Data:        val,
   147  	}
   148  
   149  	_, err = suite.client.RunCommand(ctx, request)
   150  	suite.Equal(codes.Unimplemented, status.Code(err))
   151  }
   152  
   153  func (suite *CommandRunnerSuite) TestValidator() {
   154  	calls := 0
   155  
   156  	suite.bootstrapper.RegisterHandler("foo", func(ctx context.Context, req *CommandRequest) (interface{}, error) {
   157  		select {
   158  		case <-ctx.Done():
   159  			return nil, ctx.Err()
   160  		default:
   161  		}
   162  
   163  		calls += 1
   164  
   165  		return "ok", nil
   166  	})
   167  
   168  	validatorErr := NewInvalidAdminReqErrorf("unexpected value")
   169  	suite.bootstrapper.RegisterValidator("foo", func(req *CommandRequest) error {
   170  		if req.Data.(map[string]interface{})["key"] != "value" {
   171  			return validatorErr
   172  		}
   173  		return nil
   174  	})
   175  
   176  	suite.SetupCommandRunner()
   177  
   178  	data := make(map[string]interface{})
   179  	data["key"] = "value"
   180  	val, err := structpb.NewValue(data)
   181  	suite.NoError(err)
   182  
   183  	ctx, cancel := context.WithCancel(context.Background())
   184  	defer cancel()
   185  	request := &pb.RunCommandRequest{
   186  		CommandName: "foo",
   187  		Data:        val,
   188  	}
   189  
   190  	_, err = suite.client.RunCommand(ctx, request)
   191  	suite.NoError(err)
   192  	suite.Equal(calls, 1)
   193  
   194  	data["key"] = "blah"
   195  	val, err = structpb.NewValue(data)
   196  	suite.NoError(err)
   197  	request.Data = val
   198  	_, err = suite.client.RunCommand(ctx, request)
   199  	suite.Equal(status.Convert(err).Message(), validatorErr.Error())
   200  	suite.Equal(codes.InvalidArgument, status.Code(err))
   201  	suite.Equal(calls, 1)
   202  }
   203  
   204  func (suite *CommandRunnerSuite) TestHandlerError() {
   205  	handlerErr := errors.New("handler error")
   206  	suite.bootstrapper.RegisterHandler("foo", func(ctx context.Context, req *CommandRequest) (interface{}, error) {
   207  		select {
   208  		case <-ctx.Done():
   209  			return nil, ctx.Err()
   210  		default:
   211  		}
   212  
   213  		return nil, handlerErr
   214  	})
   215  
   216  	suite.SetupCommandRunner()
   217  
   218  	data := make(map[string]interface{})
   219  	data["key"] = "value"
   220  	val, err := structpb.NewValue(data)
   221  	suite.NoError(err)
   222  
   223  	ctx, cancel := context.WithCancel(context.Background())
   224  	defer cancel()
   225  	request := &pb.RunCommandRequest{
   226  		CommandName: "foo",
   227  		Data:        val,
   228  	}
   229  
   230  	_, err = suite.client.RunCommand(ctx, request)
   231  	suite.Equal(status.Convert(err).Message(), handlerErr.Error())
   232  	suite.Equal(codes.Unknown, status.Code(err))
   233  }
   234  
   235  func (suite *CommandRunnerSuite) TestTimeout() {
   236  	suite.bootstrapper.RegisterHandler("foo", func(ctx context.Context, req *CommandRequest) (interface{}, error) {
   237  		<-ctx.Done()
   238  		return nil, ctx.Err()
   239  	})
   240  
   241  	suite.SetupCommandRunner()
   242  
   243  	data := make(map[string]interface{})
   244  	data["key"] = "value"
   245  	val, err := structpb.NewValue(data)
   246  	suite.NoError(err)
   247  
   248  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   249  	defer cancel()
   250  	request := &pb.RunCommandRequest{
   251  		CommandName: "foo",
   252  		Data:        val,
   253  	}
   254  
   255  	_, err = suite.client.RunCommand(ctx, request)
   256  	suite.Equal(codes.DeadlineExceeded, status.Code(err))
   257  }
   258  
   259  func (suite *CommandRunnerSuite) TestHTTPServer() {
   260  	called := false
   261  
   262  	suite.bootstrapper.RegisterHandler("foo", func(ctx context.Context, req *CommandRequest) (interface{}, error) {
   263  		select {
   264  		case <-ctx.Done():
   265  			return nil, ctx.Err()
   266  		default:
   267  		}
   268  
   269  		suite.EqualValues(req.Data.(map[string]interface{})["key"], "value")
   270  		called = true
   271  
   272  		return "ok", nil
   273  	})
   274  
   275  	suite.SetupCommandRunner()
   276  
   277  	url := fmt.Sprintf("http://%s/admin/run_command", suite.httpAddress)
   278  	reqBody := bytes.NewBuffer([]byte(`{"commandName": "foo", "data": {"key": "value"}}`))
   279  	resp, err := http.Post(url, "application/json", reqBody)
   280  	require.NoError(suite.T(), err)
   281  	defer func() {
   282  		if resp.Body != nil {
   283  			resp.Body.Close()
   284  		}
   285  	}()
   286  
   287  	suite.True(called)
   288  	suite.Equal("200 OK", resp.Status)
   289  }
   290  
   291  func (suite *CommandRunnerSuite) TestHTTPPProf() {
   292  	suite.SetupCommandRunner()
   293  
   294  	url := fmt.Sprintf("http://%s/debug/pprof/goroutine", suite.httpAddress)
   295  	resp, err := http.Get(url)
   296  	require.NoError(suite.T(), err)
   297  	defer func() {
   298  		if resp.Body != nil {
   299  			resp.Body.Close()
   300  		}
   301  	}()
   302  
   303  	suite.Equal(resp.Status, "200 OK")
   304  	suite.Equal(resp.Header.Get("Content-Type"), "application/octet-stream")
   305  }
   306  
   307  func (suite *CommandRunnerSuite) TestListCommands() {
   308  	suite.bootstrapper.RegisterHandler("foo", func(ctx context.Context, req *CommandRequest) (interface{}, error) {
   309  		return nil, nil
   310  	})
   311  	suite.bootstrapper.RegisterHandler("bar", func(ctx context.Context, req *CommandRequest) (interface{}, error) {
   312  		return nil, nil
   313  	})
   314  	suite.bootstrapper.RegisterHandler("baz", func(ctx context.Context, req *CommandRequest) (interface{}, error) {
   315  		return nil, nil
   316  	})
   317  
   318  	suite.SetupCommandRunner()
   319  
   320  	url := fmt.Sprintf("http://%s/admin/run_command", suite.httpAddress)
   321  	reqBody := bytes.NewBuffer([]byte(`{"commandName": "list-commands"}`))
   322  	resp, err := http.Post(url, "application/json", reqBody)
   323  	require.NoError(suite.T(), err)
   324  	defer func() {
   325  		if resp.Body != nil {
   326  			resp.Body.Close()
   327  		}
   328  	}()
   329  
   330  	suite.Equal("200 OK", resp.Status)
   331  
   332  	var response map[string][]string
   333  	require.NoError(suite.T(), json.NewDecoder(resp.Body).Decode(&response))
   334  	suite.Subset(response["output"], []string{"foo", "bar", "baz"})
   335  }
   336  
   337  func generateCerts(t *testing.T) (tls.Certificate, *x509.CertPool, tls.Certificate, *x509.CertPool) {
   338  	ca := &x509.Certificate{
   339  		SerialNumber: big.NewInt(1),
   340  		Subject: pkix.Name{
   341  			Organization: []string{"Dapper Labs, Inc."},
   342  		},
   343  		NotBefore:             time.Now(),
   344  		NotAfter:              time.Now().Add(time.Hour * 24 * 180),
   345  		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
   346  		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
   347  		BasicConstraintsValid: true,
   348  		IsCA:                  true,
   349  	}
   350  	caPrivKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   351  	require.NoError(t, err)
   352  	caPrivKeyBytes, err := x509.MarshalECPrivateKey(caPrivKey)
   353  	require.NoError(t, err)
   354  	caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey)
   355  	require.NoError(t, err)
   356  	caPEM := new(bytes.Buffer)
   357  	err = pem.Encode(caPEM, &pem.Block{
   358  		Type:  "CERTIFICATE",
   359  		Bytes: caBytes,
   360  	})
   361  	require.NoError(t, err)
   362  	caPrivKeyPem := new(bytes.Buffer)
   363  	err = pem.Encode(caPrivKeyPem, &pem.Block{
   364  		Type:  "PRIVATE KEY",
   365  		Bytes: caPrivKeyBytes,
   366  	})
   367  	require.NoError(t, err)
   368  
   369  	serverTemplate := &x509.Certificate{
   370  		SerialNumber: big.NewInt(2),
   371  		Subject: pkix.Name{
   372  			Organization: []string{"Dapper Labs, Inc."},
   373  		},
   374  		IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
   375  		DNSNames:    []string{"localhost"},
   376  		NotBefore:   time.Now(),
   377  		NotAfter:    time.Now().Add(time.Hour * 24 * 180),
   378  		KeyUsage:    x509.KeyUsageDigitalSignature,
   379  		ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
   380  	}
   381  	serverPrivKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   382  	require.NoError(t, err)
   383  	serverPrivKeyBytes, err := x509.MarshalECPrivateKey(serverPrivKey)
   384  	require.NoError(t, err)
   385  	serverCertBytes, err := x509.CreateCertificate(rand.Reader, serverTemplate, ca, &serverPrivKey.PublicKey, caPrivKey)
   386  	require.NoError(t, err)
   387  	serverCertPEM := new(bytes.Buffer)
   388  	err = pem.Encode(serverCertPEM, &pem.Block{
   389  		Type:  "CERTIFICATE",
   390  		Bytes: serverCertBytes,
   391  	})
   392  	require.NoError(t, err)
   393  	serverPrivKeyPem := new(bytes.Buffer)
   394  	err = pem.Encode(serverPrivKeyPem, &pem.Block{
   395  		Type:  "PRIVATE KEY",
   396  		Bytes: serverPrivKeyBytes,
   397  	})
   398  	require.NoError(t, err)
   399  	serverCert, err := tls.X509KeyPair(serverCertPEM.Bytes(), serverPrivKeyPem.Bytes())
   400  	require.NoError(t, err)
   401  	serverCert.Leaf, err = x509.ParseCertificate(serverCert.Certificate[0])
   402  	require.NoError(t, err)
   403  	serverCertPool := x509.NewCertPool()
   404  	serverCertPool.AddCert(serverCert.Leaf)
   405  
   406  	clientTemplate := &x509.Certificate{
   407  		SerialNumber: big.NewInt(3),
   408  		Subject: pkix.Name{
   409  			Organization: []string{"Dapper Labs, Inc."},
   410  		},
   411  		NotBefore:   time.Now(),
   412  		NotAfter:    time.Now().Add(time.Hour * 24 * 180),
   413  		KeyUsage:    x509.KeyUsageDigitalSignature,
   414  		ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
   415  	}
   416  	clientPrivKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   417  	require.NoError(t, err)
   418  	clientPrivKeyBytes, err := x509.MarshalECPrivateKey(clientPrivKey)
   419  	require.NoError(t, err)
   420  	clientCertBytes, err := x509.CreateCertificate(rand.Reader, clientTemplate, ca, &clientPrivKey.PublicKey, caPrivKey)
   421  	require.NoError(t, err)
   422  	clientCertPEM := new(bytes.Buffer)
   423  	err = pem.Encode(clientCertPEM, &pem.Block{
   424  		Type:  "CERTIFICATE",
   425  		Bytes: clientCertBytes,
   426  	})
   427  	require.NoError(t, err)
   428  	clientPrivKeyPem := new(bytes.Buffer)
   429  	err = pem.Encode(clientPrivKeyPem, &pem.Block{
   430  		Type:  "PRIVATE KEY",
   431  		Bytes: clientPrivKeyBytes,
   432  	})
   433  	require.NoError(t, err)
   434  	clientCert, err := tls.X509KeyPair(clientCertPEM.Bytes(), clientPrivKeyPem.Bytes())
   435  	require.NoError(t, err)
   436  	clientCert.Leaf, err = x509.ParseCertificate(clientCert.Certificate[0])
   437  	require.NoError(t, err)
   438  	clientCertPool := x509.NewCertPool()
   439  	clientCertPool.AddCert(clientCert.Leaf)
   440  
   441  	return serverCert, serverCertPool, clientCert, clientCertPool
   442  }
   443  
   444  func (suite *CommandRunnerSuite) TestTLS() {
   445  	called := false
   446  
   447  	suite.bootstrapper.RegisterHandler("foo", func(ctx context.Context, req *CommandRequest) (interface{}, error) {
   448  		select {
   449  		case <-ctx.Done():
   450  			return nil, ctx.Err()
   451  		default:
   452  		}
   453  
   454  		suite.EqualValues(req.Data.(map[string]interface{})["key"], "value")
   455  		called = true
   456  
   457  		return "ok", nil
   458  	})
   459  
   460  	serverCert, serverCertPool, clientCert, clientCertPool := generateCerts(suite.T())
   461  	serverConfig := &tls.Config{
   462  		MinVersion:   tls.VersionTLS13,
   463  		ClientAuth:   tls.RequireAndVerifyClientCert,
   464  		Certificates: []tls.Certificate{serverCert},
   465  		ClientCAs:    clientCertPool,
   466  	}
   467  	clientConfig := &tls.Config{
   468  		MinVersion:   tls.VersionTLS13,
   469  		Certificates: []tls.Certificate{clientCert},
   470  		RootCAs:      serverCertPool,
   471  	}
   472  
   473  	suite.SetupCommandRunner(WithTLS(serverConfig))
   474  
   475  	client := &http.Client{
   476  		Transport: &http.Transport{
   477  			TLSClientConfig: clientConfig,
   478  		},
   479  	}
   480  	url := fmt.Sprintf("https://%s/admin/run_command", suite.httpAddress)
   481  	reqBody := bytes.NewBuffer([]byte(`{"commandName": "foo", "data": {"key": "value"}}`))
   482  	resp, err := client.Post(url, "application/json", reqBody)
   483  	require.NoError(suite.T(), err)
   484  	defer resp.Body.Close()
   485  
   486  	suite.True(called)
   487  	suite.Equal("200 OK", resp.Status)
   488  }