github.com/koko1123/flow-go-1@v0.29.6/module/executiondatasync/execution_data/downloader.go

package execution_data

import (
	"bytes"
	"context"
	"errors"
	"fmt"

	"github.com/ipfs/go-cid"
	"golang.org/x/sync/errgroup"

	"github.com/koko1123/flow-go-1/model/flow"
	"github.com/koko1123/flow-go-1/module"
	"github.com/koko1123/flow-go-1/module/blobs"
	"github.com/koko1123/flow-go-1/network"
)

// BlobSizeLimitExceededError is returned when a blob exceeds the maximum size allowed.
type BlobSizeLimitExceededError struct {
	cid cid.Cid
}

func (e *BlobSizeLimitExceededError) Error() string {
	return fmt.Sprintf("blob %v exceeds maximum blob size", e.cid.String())
}
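
// Since Download propagates this error through %w wrapping, callers can detect
// it with the standard errors package. A minimal sketch (err as returned by
// Download):
//
//	var limitErr *BlobSizeLimitExceededError
//	if errors.As(err, &limitErr) {
//		// a blob in the tree was larger than the configured maximum size
//	}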

// Downloader is used to download execution data blobs from the network via a blob service.
type Downloader interface {
	module.ReadyDoneAware

	// Download downloads and returns a BlockExecutionData from the network.
	// The returned error will be:
	// - MalformedDataError if some level of the blob tree cannot be properly deserialized
	// - BlobNotFoundError if some CID in the blob tree could not be found from the blob service
	// - BlobSizeLimitExceededError if some blob in the blob tree exceeds the maximum allowed size
	Download(ctx context.Context, executionDataID flow.Identifier) (*BlockExecutionData, error)
}
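
// A minimal usage sketch (assuming bs is a network.BlobService, ctx a
// context.Context, and executionDataID a flow.Identifier taken from an
// execution result):
//
//	downloader := NewDownloader(bs)
//	<-downloader.Ready()
//	bed, err := downloader.Download(ctx, executionDataID)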

// downloader implements Downloader on top of a network.BlobService.
type downloader struct {
	blobService network.BlobService
	maxBlobSize int
	serializer  Serializer
}

// DownloaderOption configures optional parameters of a downloader.
type DownloaderOption func(*downloader)

// WithSerializer sets the Serializer used to deserialize downloaded blobs,
// overriding DefaultSerializer.
func WithSerializer(serializer Serializer) DownloaderOption {
	return func(d *downloader) {
		d.serializer = serializer
	}
}

// NewDownloader creates a new downloader backed by the given blob service.
func NewDownloader(blobService network.BlobService, opts ...DownloaderOption) *downloader {
	d := &downloader{
		blobService: blobService,
		maxBlobSize: DefaultMaxBlobSize,
		serializer:  DefaultSerializer,
	}

	for _, opt := range opts {
		opt(d)
	}

	return d
}
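
// A non-default encoding can be supplied via the option, for example (sketch;
// customSerializer is any assumed Serializer implementation):
//
//	d := NewDownloader(blobService, WithSerializer(customSerializer))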

// Ready returns a channel that is closed when the underlying blob service is ready.
func (d *downloader) Ready() <-chan struct{} {
	return d.blobService.Ready()
}

// Done returns a channel that is closed when the underlying blob service has shut down.
func (d *downloader) Done() <-chan struct{} {
	return d.blobService.Done()
}

// Download downloads the blob tree identified by executionDataID from the network and returns
// the deserialized BlockExecutionData struct.
// During normal operation, the returned error will be:
// - MalformedDataError if some level of the blob tree cannot be properly deserialized
// - BlobNotFoundError if some CID in the blob tree could not be found from the blob service
// - BlobSizeLimitExceededError if some blob in the blob tree exceeds the maximum allowed size
func (d *downloader) Download(ctx context.Context, executionDataID flow.Identifier) (*BlockExecutionData, error) {
	blobGetter := d.blobService.GetSession(ctx)

	// First, download the root execution data record which contains a list of chunk execution data
	// blobs included in the original record.
	edRoot, err := d.getExecutionDataRoot(ctx, executionDataID, blobGetter)
	if err != nil {
		return nil, fmt.Errorf("failed to get execution data root: %w", err)
	}

	g, gCtx := errgroup.WithContext(ctx)

	// Next, download each of the chunk execution data blobs concurrently.
	chunkExecutionDatas := make([]*ChunkExecutionData, len(edRoot.ChunkExecutionDataIDs))
	for i, chunkDataID := range edRoot.ChunkExecutionDataIDs {
		// capture the loop variables so each goroutine operates on its own copy
		i := i
		chunkDataID := chunkDataID

		g.Go(func() error {
			ced, err := d.getChunkExecutionData(
				gCtx,
				chunkDataID,
				blobGetter,
			)
			if err != nil {
				return fmt.Errorf("failed to get chunk execution data at index %d: %w", i, err)
			}

			chunkExecutionDatas[i] = ced

			return nil
		})
	}

	if err := g.Wait(); err != nil {
		return nil, err
	}

	// Finally, recombine the chunk execution data into the original record.
	bed := &BlockExecutionData{
		BlockID:             edRoot.BlockID,
		ChunkExecutionDatas: chunkExecutionDatas,
	}

	return bed, nil
}

// getExecutionDataRoot downloads and deserializes the root blob of the execution data blob tree.
// Expected errors: BlobNotFoundError, BlobSizeLimitExceededError, and MalformedDataError.
func (d *downloader) getExecutionDataRoot(
	ctx context.Context,
	rootID flow.Identifier,
	blobGetter network.BlobGetter,
) (*BlockExecutionDataRoot, error) {
	rootCid := flow.IdToCid(rootID)

	blob, err := blobGetter.GetBlob(ctx, rootCid)
	if err != nil {
		if errors.Is(err, network.ErrBlobNotFound) {
			return nil, NewBlobNotFoundError(rootCid)
		}

		return nil, fmt.Errorf("failed to get root blob: %w", err)
	}

	blobSize := len(blob.RawData())

	if blobSize > d.maxBlobSize {
		return nil, &BlobSizeLimitExceededError{blob.Cid()}
	}

	v, err := d.serializer.Deserialize(bytes.NewBuffer(blob.RawData()))
	if err != nil {
		return nil, NewMalformedDataError(err)
	}

	edRoot, ok := v.(*BlockExecutionDataRoot)
	if !ok {
		return nil, NewMalformedDataError(fmt.Errorf("execution data root blob does not deserialize to a BlockExecutionDataRoot, got %T instead", v))
	}

	return edRoot, nil
}

// getChunkExecutionData downloads the blob tree rooted at chunkExecutionDataID and deserializes
// it into a ChunkExecutionData.
func (d *downloader) getChunkExecutionData(
	ctx context.Context,
	chunkExecutionDataID cid.Cid,
	blobGetter network.BlobGetter,
) (*ChunkExecutionData, error) {
	cids := []cid.Cid{chunkExecutionDataID}

	// iteratively process each level of the blob tree until a ChunkExecutionData is returned or
	// an error is encountered
	for i := 0; ; i++ {
		v, err := d.getBlobs(ctx, blobGetter, cids)
		if err != nil {
			return nil, fmt.Errorf("failed to get level %d of blob tree: %w", i, err)
		}

		switch v := v.(type) {
		case *ChunkExecutionData:
			return v, nil
		case *[]cid.Cid:
			// an internal node: its CIDs identify the blobs of the next level down
			cids = *v
		default:
			return nil, NewMalformedDataError(fmt.Errorf("blob tree contains unexpected type %T at level %d", v, i))
		}
	}
}
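
// For a chunk whose serialized data spans multiple blobs, the tree walked above
// looks roughly like this (sketch; the shapes follow the switch cases in the loop):
//
//	level 0: root CID        -> *ChunkExecutionData (small chunk) or *[]cid.Cid (CIDs of level 1)
//	level 1: [cid, cid, ...] -> *[]cid.Cid (CIDs of level 2) or *ChunkExecutionData
//	...
//	level n: [cid, cid, ...] -> *ChunkExecutionData (leaf)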

// getBlobs gets the given CIDs from the blob service, reassembles the blobs, and deserializes
// the reassembled data into an object.
func (d *downloader) getBlobs(ctx context.Context, blobGetter network.BlobGetter, cids []cid.Cid) (interface{}, error) {
	blobCh, errCh := d.retrieveBlobs(ctx, blobGetter, cids)
	bcr := blobs.NewBlobChannelReader(blobCh)
	v, deserializeErr := d.serializer.Deserialize(bcr)

	// the retrieval error takes precedence, since a failed or truncated retrieval will
	// generally also cause deserialization to fail
	err := <-errCh
	if err != nil {
		return nil, err
	}

	if deserializeErr != nil {
		return nil, NewMalformedDataError(deserializeErr)
	}

	return v, nil
}

// retrieveBlobs asynchronously retrieves the blobs for the given CIDs with the given BlobGetter.
// Blobs are sent to the returned blob channel in the same order as cids. The returned error
// channel is buffered and receives exactly one value (possibly nil) once retrieval finishes.
func (d *downloader) retrieveBlobs(parent context.Context, blobGetter network.BlobGetter, cids []cid.Cid) (<-chan blobs.Blob, <-chan error) {
	blobsOut := make(chan blobs.Blob, len(cids))
	errCh := make(chan error, 1)

	go func() {
		var err error

		ctx, cancel := context.WithCancel(parent)
		defer cancel()
		defer close(blobsOut)
		defer func() {
			errCh <- err
			close(errCh)
		}()

		blobChan := blobGetter.GetBlobs(ctx, cids) // initiate a batch request for the given CIDs
		cachedBlobs := make(map[cid.Cid]blobs.Blob)
		cidCounts := make(map[cid.Cid]int) // used to account for duplicate CIDs

		for _, c := range cids {
			cidCounts[c]++
		}

		// for each cid, find the corresponding blob from the incoming blob channel and send it to
		// the outgoing blob channel in the proper order
		for _, c := range cids {
			blob, ok := cachedBlobs[c]
			if !ok {
				if blob, err = d.findBlob(blobChan, c, cachedBlobs); err != nil {
					// the blob channel may be closed as a result of the context being canceled,
					// in which case we should return the context error.
					if ctxErr := ctx.Err(); ctxErr != nil {
						err = ctxErr
					}

					return
				}
			}

			cidCounts[c]--

			// drop the cached copy once the last occurrence of a duplicate CID has been sent
			if cidCounts[c] == 0 {
				delete(cachedBlobs, c)
				delete(cidCounts, c)
			}

			blobsOut <- blob
		}
	}()

	return blobsOut, errCh
}
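
// The two channels are intended to be consumed together, as getBlobs does
// above. A sketch of the expected pattern:
//
//	blobCh, errCh := d.retrieveBlobs(ctx, blobGetter, cids)
//	for blob := range blobCh {
//		// blobs arrive in the same order as the requested cids
//	}
//	if err := <-errCh; err != nil {
//		// handle retrieval failure
//	}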

// findBlob retrieves blobs from the given channel, caching them along the way, until it either
// finds the target blob or exhausts the channel.
func (d *downloader) findBlob(
	blobChan <-chan blobs.Blob,
	target cid.Cid,
	cache map[cid.Cid]blobs.Blob,
) (blobs.Blob, error) {
	// Note: blobs are returned as they are found, in no particular order
	for blob := range blobChan {
		// reject any blob that exceeds the maximum allowed size
		blobSize := len(blob.RawData())
		if blobSize > d.maxBlobSize {
			return nil, &BlobSizeLimitExceededError{blob.Cid()}
		}

		cache[blob.Cid()] = blob

		if blob.Cid() == target {
			return blob, nil
		}
	}

	// the channel was exhausted without finding the target blob
	return nil, NewBlobNotFoundError(target)
}