github.com/Finschia/ostracon@v1.1.5/statesync/stateprovider_test.go (about)

     1  package statesync
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net"
     7  	"net/http"
     8  	"os"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/stretchr/testify/assert"
    13  	"github.com/stretchr/testify/require"
    14  
    15  	"github.com/tendermint/tendermint/proto/tendermint/state"
    16  	tmproto "github.com/tendermint/tendermint/proto/tendermint/types"
    17  	tmversion "github.com/tendermint/tendermint/proto/tendermint/version"
    18  
    19  	"github.com/Finschia/ostracon/config"
    20  	"github.com/Finschia/ostracon/libs/log"
    21  	tmrand "github.com/Finschia/ostracon/libs/rand"
    22  	"github.com/Finschia/ostracon/light"
    23  	ctypes "github.com/Finschia/ostracon/rpc/core/types"
    24  	rpcserver "github.com/Finschia/ostracon/rpc/jsonrpc/server"
    25  	rpctypes "github.com/Finschia/ostracon/rpc/jsonrpc/types"
    26  	"github.com/Finschia/ostracon/types"
    27  	tmtime "github.com/Finschia/ostracon/types/time"
    28  	"github.com/Finschia/ostracon/version"
    29  )
    30  
    31  func TestNewLightClientStateProvider(t *testing.T) {
    32  	setupVars(t)
    33  	cfg.SetRoot(os.TempDir())
    34  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
    35  	defer cancel()
    36  	listeners, servers, closeListenersFunc := serveTestRPCServers(t, cfg, 2)
    37  	defer closeListenersFunc(listeners)
    38  	type args struct {
    39  		ctx           context.Context
    40  		chainID       string
    41  		version       state.Version
    42  		initialHeight int64
    43  		servers       []string
    44  		trustOptions  light.TrustOptions
    45  		logger        log.Logger
    46  	}
    47  	successFunc := func(t assert.TestingT, err error, i ...interface{}) bool {
    48  		return assert.NoError(t, err)
    49  	}
    50  	serversErrorFunc := func(t assert.TestingT, err error, i ...interface{}) bool {
    51  		return assert.Error(t, err) &&
    52  			assert.Contains(t, err.Error(), "at least 2 RPC servers are required, got ")
    53  	}
    54  	lightErrorFunc := func(t assert.TestingT, err error, i ...interface{}) bool {
    55  		return assert.Error(t, err) &&
    56  			assert.Contains(t, err.Error(), "invalid TrustOptions: negative or zero period")
    57  	}
    58  	tests := []struct {
    59  		name    string
    60  		args    args
    61  		want    StateProvider
    62  		wantErr assert.ErrorAssertionFunc
    63  	}{
    64  		{
    65  			name: "success",
    66  			args: args{
    67  				ctx:     ctx,
    68  				chainID: chainId,
    69  				servers: servers,
    70  				logger:  log.NewNopLogger(),
    71  				trustOptions: light.TrustOptions{
    72  					Period: cfg.StateSync.TrustPeriod,
    73  					Height: 1,
    74  					Hash:   header.Hash(),
    75  				}},
    76  			want:    &lightClientStateProvider{},
    77  			wantErr: successFunc,
    78  		},
    79  		{
    80  			name:    "empty servers",
    81  			args:    args{},
    82  			want:    nil,
    83  			wantErr: serversErrorFunc,
    84  		},
    85  		{
    86  			name:    "duplicated servers",
    87  			args:    args{servers: []string{"a", "a"}},
    88  			want:    nil,
    89  			wantErr: serversErrorFunc,
    90  		},
    91  		{
    92  			name:    "fail light client",
    93  			args:    args{ctx: ctx, servers: servers},
    94  			want:    nil,
    95  			wantErr: lightErrorFunc,
    96  		},
    97  	}
    98  	for _, tt := range tests {
    99  		t.Run(tt.name, func(t *testing.T) {
   100  			got, err := NewLightClientStateProvider(
   101  				tt.args.ctx,
   102  				tt.args.chainID,
   103  				tt.args.version,
   104  				tt.args.initialHeight,
   105  				tt.args.servers,
   106  				tt.args.trustOptions,
   107  				tt.args.logger)
   108  			if !tt.wantErr(t, err) {
   109  				return
   110  			}
   111  			assert.IsType(t, tt.want, got)
   112  		})
   113  	}
   114  }
   115  
   116  const (
   117  	height = int64(1)
   118  	round  = int32(0)
   119  	size   = 1
   120  	index  = int32(0)
   121  )
   122  
   123  var (
   124  	chainId  string
   125  	cfg      *config.Config
   126  	genDoc   *types.GenesisDoc
   127  	privVals []*types.PrivValidator
   128  	vals     []*types.Validator
   129  	header   *types.Header
   130  	commit   *types.Commit
   131  )
   132  
   133  func setupVars(t *testing.T) {
   134  	// config
   135  	chainId = fmt.Sprintf("test-chain-%v", tmrand.Str(6))
   136  	cfg = config.TestConfig()
   137  	// getDoc
   138  	genDoc = &types.GenesisDoc{
   139  		ChainID:         chainId,
   140  		GenesisTime:     tmtime.Now(),
   141  		ConsensusParams: types.DefaultConsensusParams(),
   142  	}
   143  	// validators
   144  	privVals = make([]*types.PrivValidator, size)
   145  	vals = make([]*types.Validator, size)
   146  	for i := 0; i < size; i++ {
   147  		val, privVal := types.RandValidator(true, 1)
   148  		privVals[i] = &privVal
   149  		vals[i] = val
   150  	}
   151  	// header
   152  	valSet, err := types.ValidatorSetFromExistingValidators(vals)
   153  	require.NoError(t, err)
   154  	header = &types.Header{
   155  		Version: tmversion.Consensus{
   156  			Block: version.BlockProtocol,
   157  		},
   158  		ChainID:         chainId,
   159  		Height:          height,
   160  		ValidatorsHash:  valSet.Hash(),
   161  		ProposerAddress: vals[index].Address,
   162  	}
   163  	// block id
   164  	hash := tmrand.Bytes(32)
   165  	blockId := types.BlockID{
   166  		Hash: header.Hash(),
   167  		PartSetHeader: types.PartSetHeader{
   168  			Total: 1,
   169  			Hash:  hash,
   170  		},
   171  	}
   172  	// vote
   173  	vote := &types.Vote{
   174  		ValidatorAddress: vals[index].Address,
   175  		ValidatorIndex:   index,
   176  		Height:           height,
   177  		Round:            round,
   178  		Timestamp:        tmtime.Now(),
   179  		Type:             tmproto.PrecommitType,
   180  		BlockID:          blockId,
   181  	}
   182  	v := vote.ToProto()
   183  	require.NoError(t, (*privVals[index]).SignVote(chainId, v))
   184  	vote.Signature = v.Signature
   185  	vote.Timestamp = v.Timestamp
   186  	// commit
   187  	commit = &types.Commit{
   188  		Height:  height,
   189  		Round:   round,
   190  		BlockID: blockId,
   191  		Signatures: []types.CommitSig{
   192  			{
   193  				BlockIDFlag:      types.BlockIDFlagCommit,
   194  				ValidatorAddress: vote.ValidatorAddress,
   195  				Timestamp:        vote.Timestamp,
   196  				Signature:        vote.Signature,
   197  			},
   198  		},
   199  	}
   200  }
   201  
   202  func serveTestRPCServers(t *testing.T, config *config.Config, num int,
   203  ) (listeners []*net.Listener, servers []string, closeListenersFunc func(listeners []*net.Listener)) {
   204  	// Start the RPC server
   205  	mux := http.NewServeMux()
   206  	rpcserver.RegisterRPCFuncs(mux, routes, log.TestingLogger())
   207  	wm := rpcserver.NewWebsocketManager(routes)
   208  	mux.HandleFunc("/websocket", wm.WebsocketHandler)
   209  	rpcConfig := rpcserver.DefaultConfig()
   210  	listeners = make([]*net.Listener, num)
   211  	servers = make([]string, num)
   212  	for i := 0; i < num; i++ {
   213  		listener, err := rpcserver.Listen("tcp://127.0.0.1:0", rpcConfig)
   214  		require.NoError(t, err)
   215  		listeners[i] = &listener
   216  		servers[i] = listener.Addr().String()
   217  		go func() {
   218  			_ = rpcserver.Serve(listener, mux, log.NewNopLogger(), rpcConfig)
   219  		}()
   220  	}
   221  	closeListenersFunc = func(listeners []*net.Listener) {
   222  		for _, listener := range listeners {
   223  			require.NoError(t, (*listener).Close())
   224  		}
   225  	}
   226  	return listeners, servers, closeListenersFunc
   227  }
   228  
   229  var routes = map[string]*rpcserver.RPCFunc{
   230  	"genesis":    rpcserver.NewRPCFunc(genesisFunc, ""),
   231  	"commit":     rpcserver.NewRPCFunc(commitFunc, "height"),
   232  	"validators": rpcserver.NewRPCFunc(validatorsFunc, "height,page,per_page"),
   233  }
   234  
   235  func genesisFunc(ctx *rpctypes.Context) (*ctypes.ResultGenesis, error) {
   236  	return &ctypes.ResultGenesis{Genesis: genDoc}, nil
   237  }
   238  
   239  func commitFunc(ctx *rpctypes.Context, heightPtr *int64) (*ctypes.ResultCommit, error) {
   240  	return ctypes.NewResultCommit(header, commit, true), nil
   241  }
   242  
   243  func validatorsFunc(ctx *rpctypes.Context, heightPtr *int64, pagePtr, perPagePtr *int,
   244  ) (*ctypes.ResultValidators, error) {
   245  	return &ctypes.ResultValidators{
   246  		BlockHeight: height,
   247  		Validators:  vals,
   248  		Count:       size,
   249  		Total:       size,
   250  	}, nil
   251  }