github.com/treeverse/lakefs@v1.24.1-0.20240520134607-95648127bfb0/pkg/block/azure/multipart_block_writer.go (about)

     1  package azure
     2  
     3  import (
     4  	"bufio"
     5  	"context"
     6  	"encoding/base64"
     7  	"encoding/hex"
     8  	"fmt"
     9  	"io"
    10  	"sort"
    11  	"strconv"
    12  	"strings"
    13  
    14  	"github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming"
    15  	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob"
    16  	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blockblob"
    17  	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container"
    18  	"github.com/google/uuid"
    19  	"github.com/treeverse/lakefs/pkg/block"
    20  	"github.com/treeverse/lakefs/pkg/logging"
    21  )
    22  
// MultipartBlockWriter stages blocks for a multipart upload to a block blob.
// Besides the destination blob itself, it writes bookkeeping data into two
// companion blobs (objName+idSuffix and objName+sizeSuffix) so that
// completeMultipart can later reconstruct the block list and total size.
type MultipartBlockWriter struct {
	reader *block.HashingReader // the reader that would be passed to copyFromReader, this is needed in order to get size and md5
	// to is the location we are writing our chunks to.
	to      *blockblob.Client
	toIDs   *blockblob.Client // companion blob accumulating staged block IDs, one per line
	toSizes *blockblob.Client // companion blob accumulating part sizes, one per line
	etag    string            // quoted hex MD5 of the copied data; set by CommitBlockList
}
    31  
    32  func NewMultipartBlockWriter(reader *block.HashingReader, containerURL container.Client, objName string) *MultipartBlockWriter {
    33  	return &MultipartBlockWriter{
    34  		reader:  reader,
    35  		to:      containerURL.NewBlockBlobClient(objName),
    36  		toIDs:   containerURL.NewBlockBlobClient(objName + idSuffix),
    37  		toSizes: containerURL.NewBlockBlobClient(objName + sizeSuffix),
    38  	}
    39  }
    40  
// StageBlock stages a block for the destination blob by delegating directly
// to the underlying block blob client.
func (m *MultipartBlockWriter) StageBlock(ctx context.Context, base64BlockID string, body io.ReadSeekCloser, options *blockblob.StageBlockOptions) (blockblob.StageBlockResponse, error) {
	return m.to.StageBlock(ctx, base64BlockID, body, options)
}
    44  
    45  func (m *MultipartBlockWriter) CommitBlockList(ctx context.Context, ids []string, options *blockblob.CommitBlockListOptions) (blockblob.CommitBlockListResponse, error) {
    46  	m.etag = "\"" + hex.EncodeToString(m.reader.Md5.Sum(nil)) + "\""
    47  	base64Etag := base64.StdEncoding.EncodeToString([]byte(m.etag))
    48  
    49  	// write to blockIDs
    50  	pd := strings.Join(ids, "\n") + "\n"
    51  	var leaseAccessConditions *blob.LeaseAccessConditions
    52  	if options.AccessConditions != nil {
    53  		leaseAccessConditions = options.AccessConditions.LeaseAccessConditions
    54  	}
    55  	_, err := m.toIDs.StageBlock(ctx, base64Etag, streaming.NopCloser(strings.NewReader(pd)), &blockblob.StageBlockOptions{
    56  		LeaseAccessConditions: leaseAccessConditions,
    57  	})
    58  	if err != nil {
    59  		return blockblob.CommitBlockListResponse{}, fmt.Errorf("failed staging part data: %w", err)
    60  	}
    61  	// write block sizes
    62  	sd := strconv.Itoa(int(m.reader.CopiedSize)) + "\n"
    63  	_, err = m.toSizes.StageBlock(ctx, base64Etag, streaming.NopCloser(strings.NewReader(sd)), &blockblob.StageBlockOptions{
    64  		LeaseAccessConditions: leaseAccessConditions,
    65  	})
    66  	if err != nil {
    67  		return blockblob.CommitBlockListResponse{}, fmt.Errorf("failed staging part data: %w", err)
    68  	}
    69  
    70  	return blockblob.CommitBlockListResponse{}, err
    71  }
    72  
// Upload is present to satisfy the block writer interface but must never be
// reached in the multipart flow - parts go through StageBlock/CommitBlockList.
func (m *MultipartBlockWriter) Upload(_ context.Context, _ io.ReadSeekCloser, _ *blockblob.UploadOptions) (blockblob.UploadResponse, error) {
	panic("Should not be called")
}
    76  
    77  func completeMultipart(ctx context.Context, parts []block.MultipartPart, container container.Client, objName string) (*block.CompleteMultiPartUploadResponse, error) {
    78  	sort.Slice(parts, func(i, j int) bool {
    79  		return parts[i].PartNumber < parts[j].PartNumber
    80  	})
    81  	// extract staging blockIDs
    82  	metaBlockIDs := make([]string, len(parts))
    83  	for i, part := range parts {
    84  		// add Quotations marks (") if missing, Etags sent by spark include Quotations marks, Etags sent aws cli don't include Quotations marks
    85  		etag := strings.Trim(part.ETag, "\"")
    86  		etag = "\"" + etag + "\""
    87  		base64Etag := base64.StdEncoding.EncodeToString([]byte(etag))
    88  		metaBlockIDs[i] = base64Etag
    89  	}
    90  
    91  	stageBlockIDs, err := getMultipartIDs(ctx, container, objName, metaBlockIDs)
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  
    96  	size, err := getMultipartSize(ctx, container, objName, metaBlockIDs)
    97  	if err != nil {
    98  		return nil, err
    99  	}
   100  	blobURL := container.NewBlockBlobClient(objName)
   101  
   102  	res, err := blobURL.CommitBlockList(ctx, stageBlockIDs, nil)
   103  	if err != nil {
   104  		return nil, err
   105  	}
   106  	etag := string(*res.ETag)
   107  	return &block.CompleteMultiPartUploadResponse{
   108  		ETag:          etag,
   109  		ContentLength: size,
   110  	}, nil
   111  }
   112  
   113  func getMultipartIDs(ctx context.Context, container container.Client, objName string, base64BlockIDs []string) ([]string, error) {
   114  	blobURL := container.NewBlockBlobClient(objName + idSuffix)
   115  	_, err := blobURL.CommitBlockList(ctx, base64BlockIDs, nil)
   116  	if err != nil {
   117  		return nil, err
   118  	}
   119  
   120  	downloadResponse, err := blobURL.DownloadStream(ctx, nil)
   121  	if err != nil {
   122  		return nil, err
   123  	}
   124  	bodyStream := downloadResponse.Body
   125  	defer func() {
   126  		_ = bodyStream.Close()
   127  	}()
   128  	scanner := bufio.NewScanner(bodyStream)
   129  	ids := make([]string, 0)
   130  	for scanner.Scan() {
   131  		id := scanner.Text()
   132  		ids = append(ids, id)
   133  	}
   134  
   135  	// remove
   136  	_, err = blobURL.Delete(ctx, nil)
   137  	if err != nil {
   138  		logging.FromContext(ctx).WithField("blob_url", blobURL.URL()).WithError(err).Warn("Failed to delete multipart ids data file")
   139  	}
   140  	return ids, nil
   141  }
   142  
   143  func getMultipartSize(ctx context.Context, container container.Client, objName string, base64BlockIDs []string) (int64, error) {
   144  	blobURL := container.NewBlockBlobClient(objName + sizeSuffix)
   145  	_, err := blobURL.CommitBlockList(ctx, base64BlockIDs, nil)
   146  	if err != nil {
   147  		return 0, err
   148  	}
   149  
   150  	downloadResponse, err := blobURL.DownloadStream(ctx, nil)
   151  	if err != nil {
   152  		return 0, err
   153  	}
   154  	bodyStream := downloadResponse.Body
   155  	defer func() {
   156  		_ = bodyStream.Close()
   157  	}()
   158  	scanner := bufio.NewScanner(bodyStream)
   159  	size := 0
   160  	for scanner.Scan() {
   161  		s := scanner.Text()
   162  		stageSize, err := strconv.Atoi(s)
   163  		if err != nil {
   164  			return 0, err
   165  		}
   166  		size += stageSize
   167  	}
   168  
   169  	// remove
   170  	_, err = blobURL.Delete(ctx, nil)
   171  	if err != nil {
   172  		logging.FromContext(ctx).WithField("blob_url", blobURL.URL()).WithError(err).Warn("Failed to delete multipart size data file")
   173  	}
   174  	return int64(size), nil
   175  }
   176  
   177  func copyPartRange(ctx context.Context, clientCache *ClientCache, destinationKey, sourceKey BlobURLInfo, startPosition, count int64) (*block.UploadPartResponse, error) {
   178  	destinationContainer, err := clientCache.NewContainerClient(destinationKey.StorageAccountName, destinationKey.ContainerName)
   179  	if err != nil {
   180  		return nil, fmt.Errorf("copy part: get destination client: %w", err)
   181  	}
   182  	sourceContainer, err := clientCache.NewContainerClient(sourceKey.StorageAccountName, sourceKey.ContainerName)
   183  	if err != nil {
   184  		return nil, fmt.Errorf("copy part: get source client: %w", err)
   185  	}
   186  	base64BlockID := generateRandomBlockID()
   187  	destinationBlob := destinationContainer.NewBlockBlobClient(destinationKey.BlobURL)
   188  	sourceBlob := sourceContainer.NewBlockBlobClient(sourceKey.BlobURL)
   189  
   190  	stageBlockResponse, err := destinationBlob.StageBlockFromURL(ctx, base64BlockID, sourceBlob.URL(),
   191  		&blockblob.StageBlockFromURLOptions{
   192  			Range: blob.HTTPRange{
   193  				Offset: startPosition,
   194  				Count:  count,
   195  			},
   196  		})
   197  	if err != nil {
   198  		return nil, err
   199  	}
   200  
   201  	// add size, etag
   202  	etag := "\"" + hex.EncodeToString(stageBlockResponse.ContentMD5) + "\""
   203  	base64Etag := base64.StdEncoding.EncodeToString([]byte(etag))
   204  	// stage id data
   205  	blobIDsBlob := destinationContainer.NewBlockBlobClient(destinationKey.BlobURL + idSuffix)
   206  	_, err = blobIDsBlob.StageBlock(ctx, base64Etag, streaming.NopCloser(strings.NewReader(base64BlockID+"\n")), nil)
   207  	if err != nil {
   208  		return nil, fmt.Errorf("failed staging part data: %w", err)
   209  	}
   210  
   211  	// stage size data
   212  	sizeData := fmt.Sprintf("%d\n", count)
   213  	blobSizesBlob := destinationContainer.NewBlockBlobClient(destinationKey.BlobURL + sizeSuffix)
   214  	_, err = blobSizesBlob.StageBlock(ctx, base64Etag, streaming.NopCloser(strings.NewReader(sizeData)), nil)
   215  	if err != nil {
   216  		return nil, fmt.Errorf("failed staging part data: %w", err)
   217  	}
   218  
   219  	return &block.UploadPartResponse{
   220  		ETag: strings.Trim(etag, `"`),
   221  	}, nil
   222  }
   223  
   224  func generateRandomBlockID() string {
   225  	uu := uuid.New()
   226  	u := [64]byte{}
   227  	copy(u[:], uu[:])
   228  	return base64.StdEncoding.EncodeToString(u[:])
   229  }