github.com/sunrise-zone/sunrise-node@v0.13.1-sr2/share/getters/cascade_test.go (about)

     1  package getters
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"strings"
     7  	"testing"
     8  
     9  	"github.com/golang/mock/gomock"
    10  	"github.com/stretchr/testify/assert"
    11  
    12  	"github.com/celestiaorg/rsmt2d"
    13  
    14  	"github.com/sunrise-zone/sunrise-node/header"
    15  	"github.com/sunrise-zone/sunrise-node/share"
    16  	"github.com/sunrise-zone/sunrise-node/share/mocks"
    17  )
    18  
    19  func TestCascadeGetter(t *testing.T) {
    20  	ctx, cancel := context.WithCancel(context.Background())
    21  	t.Cleanup(cancel)
    22  
    23  	const gettersN = 3
    24  	headers := make([]*header.ExtendedHeader, gettersN)
    25  	getters := make([]share.Getter, gettersN)
    26  	for i := range headers {
    27  		getters[i], headers[i] = TestGetter(t)
    28  	}
    29  
    30  	getter := NewCascadeGetter(getters)
    31  	t.Run("GetShare", func(t *testing.T) {
    32  		for _, eh := range headers {
    33  			sh, err := getter.GetShare(ctx, eh, 0, 0)
    34  			assert.NoError(t, err)
    35  			assert.NotEmpty(t, sh)
    36  		}
    37  	})
    38  
    39  	t.Run("GetEDS", func(t *testing.T) {
    40  		for _, eh := range headers {
    41  			sh, err := getter.GetEDS(ctx, eh)
    42  			assert.NoError(t, err)
    43  			assert.NotEmpty(t, sh)
    44  		}
    45  	})
    46  }
    47  
    48  func TestCascade(t *testing.T) {
    49  	ctrl := gomock.NewController(t)
    50  	ctx, cancel := context.WithCancel(context.Background())
    51  	t.Cleanup(cancel)
    52  
    53  	timeoutGetter := mocks.NewMockGetter(ctrl)
    54  	immediateFailGetter := mocks.NewMockGetter(ctrl)
    55  	successGetter := mocks.NewMockGetter(ctrl)
    56  	ctxGetter := mocks.NewMockGetter(ctrl)
    57  	timeoutGetter.EXPECT().GetEDS(gomock.Any(), gomock.Any()).
    58  		DoAndReturn(func(ctx context.Context, _ *header.ExtendedHeader) (*rsmt2d.ExtendedDataSquare, error) {
    59  			return nil, context.DeadlineExceeded
    60  		}).AnyTimes()
    61  	immediateFailGetter.EXPECT().GetEDS(gomock.Any(), gomock.Any()).
    62  		Return(nil, errors.New("second getter fails immediately")).AnyTimes()
    63  	successGetter.EXPECT().GetEDS(gomock.Any(), gomock.Any()).
    64  		Return(nil, nil).AnyTimes()
    65  	ctxGetter.EXPECT().GetEDS(gomock.Any(), gomock.Any()).
    66  		DoAndReturn(func(ctx context.Context, _ *header.ExtendedHeader) (*rsmt2d.ExtendedDataSquare, error) {
    67  			return nil, ctx.Err()
    68  		}).AnyTimes()
    69  
    70  	get := func(ctx context.Context, get share.Getter) (*rsmt2d.ExtendedDataSquare, error) {
    71  		return get.GetEDS(ctx, nil)
    72  	}
    73  
    74  	t.Run("SuccessFirst", func(t *testing.T) {
    75  		getters := []share.Getter{successGetter, timeoutGetter, immediateFailGetter}
    76  		_, err := cascadeGetters(ctx, getters, get)
    77  		assert.NoError(t, err)
    78  	})
    79  
    80  	t.Run("SuccessSecond", func(t *testing.T) {
    81  		getters := []share.Getter{immediateFailGetter, successGetter}
    82  		_, err := cascadeGetters(ctx, getters, get)
    83  		assert.NoError(t, err)
    84  	})
    85  
    86  	t.Run("SuccessSecondAfterFirst", func(t *testing.T) {
    87  		getters := []share.Getter{timeoutGetter, successGetter}
    88  		_, err := cascadeGetters(ctx, getters, get)
    89  		assert.NoError(t, err)
    90  	})
    91  
    92  	t.Run("SuccessAfterMultipleTimeouts", func(t *testing.T) {
    93  		getters := []share.Getter{timeoutGetter, immediateFailGetter, timeoutGetter, timeoutGetter, successGetter}
    94  		_, err := cascadeGetters(ctx, getters, get)
    95  		assert.NoError(t, err)
    96  	})
    97  
    98  	t.Run("Error", func(t *testing.T) {
    99  		getters := []share.Getter{immediateFailGetter, timeoutGetter, immediateFailGetter}
   100  		_, err := cascadeGetters(ctx, getters, get)
   101  		assert.Error(t, err)
   102  		assert.Equal(t, strings.Count(err.Error(), "\n"), 2)
   103  	})
   104  
   105  	t.Run("Context Canceled", func(t *testing.T) {
   106  		ctx, cancel := context.WithCancel(ctx)
   107  		cancel()
   108  		getters := []share.Getter{ctxGetter, ctxGetter, ctxGetter}
   109  		_, err := cascadeGetters(ctx, getters, get)
   110  		assert.Error(t, err)
   111  		assert.Equal(t, strings.Count(err.Error(), "\n"), 0)
   112  	})
   113  
   114  	t.Run("Single", func(t *testing.T) {
   115  		getters := []share.Getter{successGetter}
   116  		_, err := cascadeGetters(ctx, getters, get)
   117  		assert.NoError(t, err)
   118  	})
   119  }