github.com/sunrise-zone/sunrise-node@v0.13.1-sr2/share/getters/cascade_test.go (about) 1 package getters 2 3 import ( 4 "context" 5 "errors" 6 "strings" 7 "testing" 8 9 "github.com/golang/mock/gomock" 10 "github.com/stretchr/testify/assert" 11 12 "github.com/celestiaorg/rsmt2d" 13 14 "github.com/sunrise-zone/sunrise-node/header" 15 "github.com/sunrise-zone/sunrise-node/share" 16 "github.com/sunrise-zone/sunrise-node/share/mocks" 17 ) 18 19 func TestCascadeGetter(t *testing.T) { 20 ctx, cancel := context.WithCancel(context.Background()) 21 t.Cleanup(cancel) 22 23 const gettersN = 3 24 headers := make([]*header.ExtendedHeader, gettersN) 25 getters := make([]share.Getter, gettersN) 26 for i := range headers { 27 getters[i], headers[i] = TestGetter(t) 28 } 29 30 getter := NewCascadeGetter(getters) 31 t.Run("GetShare", func(t *testing.T) { 32 for _, eh := range headers { 33 sh, err := getter.GetShare(ctx, eh, 0, 0) 34 assert.NoError(t, err) 35 assert.NotEmpty(t, sh) 36 } 37 }) 38 39 t.Run("GetEDS", func(t *testing.T) { 40 for _, eh := range headers { 41 sh, err := getter.GetEDS(ctx, eh) 42 assert.NoError(t, err) 43 assert.NotEmpty(t, sh) 44 } 45 }) 46 } 47 48 func TestCascade(t *testing.T) { 49 ctrl := gomock.NewController(t) 50 ctx, cancel := context.WithCancel(context.Background()) 51 t.Cleanup(cancel) 52 53 timeoutGetter := mocks.NewMockGetter(ctrl) 54 immediateFailGetter := mocks.NewMockGetter(ctrl) 55 successGetter := mocks.NewMockGetter(ctrl) 56 ctxGetter := mocks.NewMockGetter(ctrl) 57 timeoutGetter.EXPECT().GetEDS(gomock.Any(), gomock.Any()). 58 DoAndReturn(func(ctx context.Context, _ *header.ExtendedHeader) (*rsmt2d.ExtendedDataSquare, error) { 59 return nil, context.DeadlineExceeded 60 }).AnyTimes() 61 immediateFailGetter.EXPECT().GetEDS(gomock.Any(), gomock.Any()). 62 Return(nil, errors.New("second getter fails immediately")).AnyTimes() 63 successGetter.EXPECT().GetEDS(gomock.Any(), gomock.Any()). 64 Return(nil, nil).AnyTimes() 65 ctxGetter.EXPECT().GetEDS(gomock.Any(), gomock.Any()). 66 DoAndReturn(func(ctx context.Context, _ *header.ExtendedHeader) (*rsmt2d.ExtendedDataSquare, error) { 67 return nil, ctx.Err() 68 }).AnyTimes() 69 70 get := func(ctx context.Context, get share.Getter) (*rsmt2d.ExtendedDataSquare, error) { 71 return get.GetEDS(ctx, nil) 72 } 73 74 t.Run("SuccessFirst", func(t *testing.T) { 75 getters := []share.Getter{successGetter, timeoutGetter, immediateFailGetter} 76 _, err := cascadeGetters(ctx, getters, get) 77 assert.NoError(t, err) 78 }) 79 80 t.Run("SuccessSecond", func(t *testing.T) { 81 getters := []share.Getter{immediateFailGetter, successGetter} 82 _, err := cascadeGetters(ctx, getters, get) 83 assert.NoError(t, err) 84 }) 85 86 t.Run("SuccessSecondAfterFirst", func(t *testing.T) { 87 getters := []share.Getter{timeoutGetter, successGetter} 88 _, err := cascadeGetters(ctx, getters, get) 89 assert.NoError(t, err) 90 }) 91 92 t.Run("SuccessAfterMultipleTimeouts", func(t *testing.T) { 93 getters := []share.Getter{timeoutGetter, immediateFailGetter, timeoutGetter, timeoutGetter, successGetter} 94 _, err := cascadeGetters(ctx, getters, get) 95 assert.NoError(t, err) 96 }) 97 98 t.Run("Error", func(t *testing.T) { 99 getters := []share.Getter{immediateFailGetter, timeoutGetter, immediateFailGetter} 100 _, err := cascadeGetters(ctx, getters, get) 101 assert.Error(t, err) 102 assert.Equal(t, strings.Count(err.Error(), "\n"), 2) 103 }) 104 105 t.Run("Context Canceled", func(t *testing.T) { 106 ctx, cancel := context.WithCancel(ctx) 107 cancel() 108 getters := []share.Getter{ctxGetter, ctxGetter, ctxGetter} 109 _, err := cascadeGetters(ctx, getters, get) 110 assert.Error(t, err) 111 assert.Equal(t, strings.Count(err.Error(), "\n"), 0) 112 }) 113 114 t.Run("Single", func(t *testing.T) { 115 getters := []share.Getter{successGetter} 116 _, err := cascadeGetters(ctx, getters, get) 117 assert.NoError(t, err) 118 }) 119 }