github.com/treeverse/lakefs@v1.24.1-0.20240520134607-95648127bfb0/pkg/block/azure/multipart_block_writer.go (about) 1 package azure 2 3 import ( 4 "bufio" 5 "context" 6 "encoding/base64" 7 "encoding/hex" 8 "fmt" 9 "io" 10 "sort" 11 "strconv" 12 "strings" 13 14 "github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming" 15 "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob" 16 "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blockblob" 17 "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container" 18 "github.com/google/uuid" 19 "github.com/treeverse/lakefs/pkg/block" 20 "github.com/treeverse/lakefs/pkg/logging" 21 ) 22 23 type MultipartBlockWriter struct { 24 reader *block.HashingReader // the reader that would be passed to copyFromReader, this is needed in order to get size and md5 25 // to is the location we are writing our chunks to. 26 to *blockblob.Client 27 toIDs *blockblob.Client 28 toSizes *blockblob.Client 29 etag string 30 } 31 32 func NewMultipartBlockWriter(reader *block.HashingReader, containerURL container.Client, objName string) *MultipartBlockWriter { 33 return &MultipartBlockWriter{ 34 reader: reader, 35 to: containerURL.NewBlockBlobClient(objName), 36 toIDs: containerURL.NewBlockBlobClient(objName + idSuffix), 37 toSizes: containerURL.NewBlockBlobClient(objName + sizeSuffix), 38 } 39 } 40 41 func (m *MultipartBlockWriter) StageBlock(ctx context.Context, base64BlockID string, body io.ReadSeekCloser, options *blockblob.StageBlockOptions) (blockblob.StageBlockResponse, error) { 42 return m.to.StageBlock(ctx, base64BlockID, body, options) 43 } 44 45 func (m *MultipartBlockWriter) CommitBlockList(ctx context.Context, ids []string, options *blockblob.CommitBlockListOptions) (blockblob.CommitBlockListResponse, error) { 46 m.etag = "\"" + hex.EncodeToString(m.reader.Md5.Sum(nil)) + "\"" 47 base64Etag := base64.StdEncoding.EncodeToString([]byte(m.etag)) 48 49 // write to blockIDs 50 pd := strings.Join(ids, "\n") + "\n" 51 var leaseAccessConditions *blob.LeaseAccessConditions 52 if options.AccessConditions != nil { 53 leaseAccessConditions = options.AccessConditions.LeaseAccessConditions 54 } 55 _, err := m.toIDs.StageBlock(ctx, base64Etag, streaming.NopCloser(strings.NewReader(pd)), &blockblob.StageBlockOptions{ 56 LeaseAccessConditions: leaseAccessConditions, 57 }) 58 if err != nil { 59 return blockblob.CommitBlockListResponse{}, fmt.Errorf("failed staging part data: %w", err) 60 } 61 // write block sizes 62 sd := strconv.Itoa(int(m.reader.CopiedSize)) + "\n" 63 _, err = m.toSizes.StageBlock(ctx, base64Etag, streaming.NopCloser(strings.NewReader(sd)), &blockblob.StageBlockOptions{ 64 LeaseAccessConditions: leaseAccessConditions, 65 }) 66 if err != nil { 67 return blockblob.CommitBlockListResponse{}, fmt.Errorf("failed staging part data: %w", err) 68 } 69 70 return blockblob.CommitBlockListResponse{}, err 71 } 72 73 func (m *MultipartBlockWriter) Upload(_ context.Context, _ io.ReadSeekCloser, _ *blockblob.UploadOptions) (blockblob.UploadResponse, error) { 74 panic("Should not be called") 75 } 76 77 func completeMultipart(ctx context.Context, parts []block.MultipartPart, container container.Client, objName string) (*block.CompleteMultiPartUploadResponse, error) { 78 sort.Slice(parts, func(i, j int) bool { 79 return parts[i].PartNumber < parts[j].PartNumber 80 }) 81 // extract staging blockIDs 82 metaBlockIDs := make([]string, len(parts)) 83 for i, part := range parts { 84 // add Quotations marks (") if missing, Etags sent by spark include Quotations marks, Etags sent aws cli don't include Quotations marks 85 etag := strings.Trim(part.ETag, "\"") 86 etag = "\"" + etag + "\"" 87 base64Etag := base64.StdEncoding.EncodeToString([]byte(etag)) 88 metaBlockIDs[i] = base64Etag 89 } 90 91 stageBlockIDs, err := getMultipartIDs(ctx, container, objName, metaBlockIDs) 92 if err != nil { 93 return nil, err 94 } 95 96 size, err := getMultipartSize(ctx, container, objName, metaBlockIDs) 97 if err != nil { 98 return nil, err 99 } 100 blobURL := container.NewBlockBlobClient(objName) 101 102 res, err := blobURL.CommitBlockList(ctx, stageBlockIDs, nil) 103 if err != nil { 104 return nil, err 105 } 106 etag := string(*res.ETag) 107 return &block.CompleteMultiPartUploadResponse{ 108 ETag: etag, 109 ContentLength: size, 110 }, nil 111 } 112 113 func getMultipartIDs(ctx context.Context, container container.Client, objName string, base64BlockIDs []string) ([]string, error) { 114 blobURL := container.NewBlockBlobClient(objName + idSuffix) 115 _, err := blobURL.CommitBlockList(ctx, base64BlockIDs, nil) 116 if err != nil { 117 return nil, err 118 } 119 120 downloadResponse, err := blobURL.DownloadStream(ctx, nil) 121 if err != nil { 122 return nil, err 123 } 124 bodyStream := downloadResponse.Body 125 defer func() { 126 _ = bodyStream.Close() 127 }() 128 scanner := bufio.NewScanner(bodyStream) 129 ids := make([]string, 0) 130 for scanner.Scan() { 131 id := scanner.Text() 132 ids = append(ids, id) 133 } 134 135 // remove 136 _, err = blobURL.Delete(ctx, nil) 137 if err != nil { 138 logging.FromContext(ctx).WithField("blob_url", blobURL.URL()).WithError(err).Warn("Failed to delete multipart ids data file") 139 } 140 return ids, nil 141 } 142 143 func getMultipartSize(ctx context.Context, container container.Client, objName string, base64BlockIDs []string) (int64, error) { 144 blobURL := container.NewBlockBlobClient(objName + sizeSuffix) 145 _, err := blobURL.CommitBlockList(ctx, base64BlockIDs, nil) 146 if err != nil { 147 return 0, err 148 } 149 150 downloadResponse, err := blobURL.DownloadStream(ctx, nil) 151 if err != nil { 152 return 0, err 153 } 154 bodyStream := downloadResponse.Body 155 defer func() { 156 _ = bodyStream.Close() 157 }() 158 scanner := bufio.NewScanner(bodyStream) 159 size := 0 160 for scanner.Scan() { 161 s := scanner.Text() 162 stageSize, err := strconv.Atoi(s) 163 if err != nil { 164 return 0, err 165 } 166 size += stageSize 167 } 168 169 // remove 170 _, err = blobURL.Delete(ctx, nil) 171 if err != nil { 172 logging.FromContext(ctx).WithField("blob_url", blobURL.URL()).WithError(err).Warn("Failed to delete multipart size data file") 173 } 174 return int64(size), nil 175 } 176 177 func copyPartRange(ctx context.Context, clientCache *ClientCache, destinationKey, sourceKey BlobURLInfo, startPosition, count int64) (*block.UploadPartResponse, error) { 178 destinationContainer, err := clientCache.NewContainerClient(destinationKey.StorageAccountName, destinationKey.ContainerName) 179 if err != nil { 180 return nil, fmt.Errorf("copy part: get destination client: %w", err) 181 } 182 sourceContainer, err := clientCache.NewContainerClient(sourceKey.StorageAccountName, sourceKey.ContainerName) 183 if err != nil { 184 return nil, fmt.Errorf("copy part: get source client: %w", err) 185 } 186 base64BlockID := generateRandomBlockID() 187 destinationBlob := destinationContainer.NewBlockBlobClient(destinationKey.BlobURL) 188 sourceBlob := sourceContainer.NewBlockBlobClient(sourceKey.BlobURL) 189 190 stageBlockResponse, err := destinationBlob.StageBlockFromURL(ctx, base64BlockID, sourceBlob.URL(), 191 &blockblob.StageBlockFromURLOptions{ 192 Range: blob.HTTPRange{ 193 Offset: startPosition, 194 Count: count, 195 }, 196 }) 197 if err != nil { 198 return nil, err 199 } 200 201 // add size, etag 202 etag := "\"" + hex.EncodeToString(stageBlockResponse.ContentMD5) + "\"" 203 base64Etag := base64.StdEncoding.EncodeToString([]byte(etag)) 204 // stage id data 205 blobIDsBlob := destinationContainer.NewBlockBlobClient(destinationKey.BlobURL + idSuffix) 206 _, err = blobIDsBlob.StageBlock(ctx, base64Etag, streaming.NopCloser(strings.NewReader(base64BlockID+"\n")), nil) 207 if err != nil { 208 return nil, fmt.Errorf("failed staging part data: %w", err) 209 } 210 211 // stage size data 212 sizeData := fmt.Sprintf("%d\n", count) 213 blobSizesBlob := destinationContainer.NewBlockBlobClient(destinationKey.BlobURL + sizeSuffix) 214 _, err = blobSizesBlob.StageBlock(ctx, base64Etag, streaming.NopCloser(strings.NewReader(sizeData)), nil) 215 if err != nil { 216 return nil, fmt.Errorf("failed staging part data: %w", err) 217 } 218 219 return &block.UploadPartResponse{ 220 ETag: strings.Trim(etag, `"`), 221 }, nil 222 } 223 224 func generateRandomBlockID() string { 225 uu := uuid.New() 226 u := [64]byte{} 227 copy(u[:], uu[:]) 228 return base64.StdEncoding.EncodeToString(u[:]) 229 }