git.frostfs.info/TrueCloudLab/frostfs-sdk-go@v0.0.0-20241022124111-5361f0ecebd3/object/erasurecode/reconstruct_test.go (about)

     1  package erasurecode_test
     2  
     3  import (
     4  	"context"
     5  	"crypto/rand"
     6  	"math"
     7  	"testing"
     8  
     9  	cidtest "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/container/id/test"
    10  	objectSDK "git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object"
    11  	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/erasurecode"
    12  	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/object/transformer"
    13  	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/user"
    14  	"git.frostfs.info/TrueCloudLab/frostfs-sdk-go/version"
    15  	"github.com/nspcc-dev/neo-go/pkg/crypto/keys"
    16  	"github.com/stretchr/testify/require"
    17  )
    18  
    19  func TestErasureCodeReconstruct(t *testing.T) {
    20  	const payloadSize = 99
    21  	const dataCount = 3
    22  	const parityCount = 2
    23  
    24  	// We would also like to test padding behaviour,
    25  	// so ensure padding is done.
    26  	require.NotZero(t, payloadSize%(dataCount+parityCount))
    27  
    28  	pk, err := keys.NewPrivateKey()
    29  	require.NoError(t, err)
    30  
    31  	original := newObject(t, payloadSize, pk)
    32  
    33  	c, err := erasurecode.NewConstructor(dataCount, parityCount)
    34  	require.NoError(t, err)
    35  
    36  	parts, err := c.Split(original, &pk.PrivateKey)
    37  	require.NoError(t, err)
    38  
    39  	t.Run("reconstruct header", func(t *testing.T) {
    40  		original := original.CutPayload()
    41  		parts := cloneSlice(parts)
    42  		for i := range parts {
    43  			parts[i] = parts[i].CutPayload()
    44  		}
    45  		t.Run("from data", func(t *testing.T) {
    46  			parts := cloneSlice(parts)
    47  			for i := dataCount; i < dataCount+parityCount; i++ {
    48  				parts[i] = nil
    49  			}
    50  			reconstructed, err := c.ReconstructHeader(parts)
    51  			require.NoError(t, err)
    52  			verifyReconstruction(t, original, reconstructed)
    53  		})
    54  		t.Run("from parity", func(t *testing.T) {
    55  			parts := cloneSlice(parts)
    56  			for i := range parityCount {
    57  				parts[i] = nil
    58  			}
    59  			reconstructed, err := c.ReconstructHeader(parts)
    60  			require.NoError(t, err)
    61  			verifyReconstruction(t, original, reconstructed)
    62  
    63  			t.Run("not enough shards", func(t *testing.T) {
    64  				parts[parityCount] = nil
    65  				_, err := c.ReconstructHeader(parts)
    66  				require.ErrorIs(t, err, erasurecode.ErrMalformedSlice)
    67  			})
    68  		})
    69  		t.Run("only nil parts", func(t *testing.T) {
    70  			parts := make([]*objectSDK.Object, len(parts))
    71  			_, err := c.ReconstructHeader(parts)
    72  			require.ErrorIs(t, err, erasurecode.ErrMalformedSlice)
    73  		})
    74  		t.Run("missing EC header", func(t *testing.T) {
    75  			parts := cloneSlice(parts)
    76  			parts[0] = deepCopy(t, parts[0])
    77  			parts[0].SetECHeader(nil)
    78  
    79  			_, err := c.ReconstructHeader(parts)
    80  			require.ErrorIs(t, err, erasurecode.ErrMalformedSlice)
    81  		})
    82  		t.Run("invalid index", func(t *testing.T) {
    83  			parts := cloneSlice(parts)
    84  			parts[0] = deepCopy(t, parts[0])
    85  
    86  			ec := parts[0].GetECHeader()
    87  			ec.SetIndex(1)
    88  			parts[0].SetECHeader(ec)
    89  
    90  			_, err := c.ReconstructHeader(parts)
    91  			require.ErrorIs(t, err, erasurecode.ErrMalformedSlice)
    92  		})
    93  		t.Run("invalid total", func(t *testing.T) {
    94  			parts := cloneSlice(parts)
    95  			parts[0] = deepCopy(t, parts[0])
    96  
    97  			ec := parts[0].GetECHeader()
    98  			ec.SetTotal(uint32(len(parts) + 1))
    99  			parts[0].SetECHeader(ec)
   100  
   101  			_, err := c.ReconstructHeader(parts)
   102  			require.ErrorIs(t, err, erasurecode.ErrMalformedSlice)
   103  		})
   104  		t.Run("inconsistent header length", func(t *testing.T) {
   105  			parts := cloneSlice(parts)
   106  			parts[0] = deepCopy(t, parts[0])
   107  
   108  			ec := parts[0].GetECHeader()
   109  			ec.SetHeaderLength(ec.HeaderLength() - 1)
   110  			parts[0].SetECHeader(ec)
   111  
   112  			_, err := c.ReconstructHeader(parts)
   113  			require.ErrorIs(t, err, erasurecode.ErrMalformedSlice)
   114  		})
   115  		t.Run("invalid header length", func(t *testing.T) {
   116  			parts := cloneSlice(parts)
   117  			for i := range parts {
   118  				parts[i] = deepCopy(t, parts[i])
   119  
   120  				ec := parts[0].GetECHeader()
   121  				ec.SetHeaderLength(math.MaxUint32)
   122  				parts[0].SetECHeader(ec)
   123  			}
   124  
   125  			_, err := c.ReconstructHeader(parts)
   126  			require.ErrorIs(t, err, erasurecode.ErrMalformedSlice)
   127  		})
   128  	})
   129  	t.Run("reconstruct data", func(t *testing.T) {
   130  		t.Run("from data", func(t *testing.T) {
   131  			parts := cloneSlice(parts)
   132  			for i := dataCount; i < dataCount+parityCount; i++ {
   133  				parts[i] = nil
   134  			}
   135  			reconstructed, err := c.Reconstruct(parts)
   136  			require.NoError(t, err)
   137  			verifyReconstruction(t, original, reconstructed)
   138  		})
   139  		t.Run("from parity", func(t *testing.T) {
   140  			parts := cloneSlice(parts)
   141  			for i := range parityCount {
   142  				parts[i] = nil
   143  			}
   144  			reconstructed, err := c.Reconstruct(parts)
   145  			require.NoError(t, err)
   146  			verifyReconstruction(t, original, reconstructed)
   147  
   148  			t.Run("not enough shards", func(t *testing.T) {
   149  				parts[parityCount] = nil
   150  				_, err := c.Reconstruct(parts)
   151  				require.ErrorIs(t, err, erasurecode.ErrMalformedSlice)
   152  			})
   153  		})
   154  	})
   155  	t.Run("reconstruct parts", func(t *testing.T) {
   156  		// We would like to also test that ReconstructParts doesn't perform
   157  		// excessive work, so ensure this test makes sense.
   158  		require.GreaterOrEqual(t, parityCount, 2)
   159  
   160  		t.Run("from data", func(t *testing.T) {
   161  			oldParts := parts
   162  			parts := cloneSlice(parts)
   163  			for i := dataCount; i < dataCount+parityCount; i++ {
   164  				parts[i] = nil
   165  			}
   166  
   167  			required := make([]bool, len(parts))
   168  			required[dataCount] = true
   169  
   170  			require.NoError(t, c.ReconstructParts(parts, required, nil))
   171  
   172  			old := deepCopy(t, oldParts[dataCount])
   173  			old.SetSignature(nil)
   174  			require.Equal(t, old, parts[dataCount])
   175  
   176  			for i := dataCount + 1; i < dataCount+parityCount; i++ {
   177  				require.Nil(t, parts[i])
   178  			}
   179  		})
   180  		t.Run("from parity", func(t *testing.T) {
   181  			oldParts := parts
   182  			parts := cloneSlice(parts)
   183  			for i := range parityCount {
   184  				parts[i] = nil
   185  			}
   186  
   187  			required := make([]bool, len(parts))
   188  			required[0] = true
   189  
   190  			require.NoError(t, c.ReconstructParts(parts, required, nil))
   191  
   192  			old := deepCopy(t, oldParts[0])
   193  			old.SetSignature(nil)
   194  			require.Equal(t, old, parts[0])
   195  
   196  			for i := 1; i < parityCount; i++ {
   197  				require.Nil(t, parts[i])
   198  			}
   199  		})
   200  	})
   201  }
   202  
   203  func newObject(t *testing.T, size uint64, pk *keys.PrivateKey) *objectSDK.Object {
   204  	// Use transformer to form object to avoid potential bugs with yet another helper object creation in tests.
   205  	tt := &testTarget{}
   206  	p := transformer.NewPayloadSizeLimiter(transformer.Params{
   207  		Key:                    &pk.PrivateKey,
   208  		NextTargetInit:         func() transformer.ObjectWriter { return tt },
   209  		NetworkState:           dummyEpochSource(123),
   210  		MaxSize:                size + 1,
   211  		WithoutHomomorphicHash: true,
   212  	})
   213  	cnr := cidtest.ID()
   214  	ver := version.Current()
   215  	hdr := objectSDK.New()
   216  	hdr.SetContainerID(cnr)
   217  	hdr.SetType(objectSDK.TypeRegular)
   218  	hdr.SetVersion(&ver)
   219  
   220  	var owner user.ID
   221  	user.IDFromKey(&owner, pk.PrivateKey.PublicKey)
   222  	hdr.SetOwnerID(owner)
   223  
   224  	var attr objectSDK.Attribute
   225  	attr.SetKey("somekey")
   226  	attr.SetValue("somevalue")
   227  	hdr.SetAttributes(attr)
   228  
   229  	expectedPayload := make([]byte, size)
   230  	_, _ = rand.Read(expectedPayload)
   231  	writeObject(t, context.Background(), p, hdr, expectedPayload)
   232  	require.Len(t, tt.objects, 1)
   233  	return tt.objects[0]
   234  }
   235  
   236  func writeObject(t *testing.T, ctx context.Context, target transformer.ChunkedObjectWriter, header *objectSDK.Object, payload []byte) *transformer.AccessIdentifiers {
   237  	require.NoError(t, target.WriteHeader(ctx, header))
   238  
   239  	_, err := target.Write(ctx, payload)
   240  	require.NoError(t, err)
   241  
   242  	ids, err := target.Close(ctx)
   243  	require.NoError(t, err)
   244  
   245  	return ids
   246  }
   247  
   248  func verifyReconstruction(t *testing.T, original, reconstructed *objectSDK.Object) {
   249  	require.True(t, reconstructed.VerifyIDSignature())
   250  	reconstructed.ToV2().SetMarshalData(nil)
   251  	original.ToV2().SetMarshalData(nil)
   252  
   253  	require.Equal(t, original, reconstructed)
   254  }
   255  
   256  func deepCopy(t *testing.T, obj *objectSDK.Object) *objectSDK.Object {
   257  	data, err := obj.Marshal()
   258  	require.NoError(t, err)
   259  
   260  	res := objectSDK.New()
   261  	require.NoError(t, res.Unmarshal(data))
   262  	return res
   263  }
   264  
   265  func cloneSlice[T any](src []T) []T {
   266  	dst := make([]T, len(src))
   267  	copy(dst, src)
   268  	return dst
   269  }
   270  
   271  type dummyEpochSource uint64
   272  
   273  func (s dummyEpochSource) CurrentEpoch() uint64 {
   274  	return uint64(s)
   275  }
   276  
   277  type testTarget struct {
   278  	objects []*objectSDK.Object
   279  }
   280  
   281  func (tt *testTarget) WriteObject(_ context.Context, o *objectSDK.Object) error {
   282  	tt.objects = append(tt.objects, o)
   283  	return nil // AccessIdentifiers should not be used.
   284  }