github.com/weaviate/weaviate@v1.24.6/modules/backup-azure/client.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package modstgazure
    13  
    14  import (
    15  	"bytes"
    16  	"context"
    17  	"fmt"
    18  	"io"
    19  	"os"
    20  	"path"
    21  	"strings"
    22  	"time"
    23  
    24  	"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
    25  	"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
    26  	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
    27  	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror"
    28  	"github.com/pkg/errors"
    29  	"github.com/weaviate/weaviate/entities/backup"
    30  )
    31  
// azureClient bundles an Azure Blob Storage client with the settings the
// backup module needs to address objects.
//
// NOTE: newClient constructs this struct with positional composite literals,
// so the field order below must not be changed without updating it.
type azureClient struct {
	// client performs every blob operation (upload, download, delete).
	client *azblob.Client
	// config is a copy of the module configuration; its Container and
	// BackupPath fields are used to build object locations.
	config clientConfig
	// serviceURL is the prefix HomeDir puts in front of the container and
	// object name. newClient leaves it empty when a connection string without
	// a BlobEndpoint setting is used.
	serviceURL string
	// dataPath is the local directory that PutFile joins source paths onto and
	// that SourceDataPath returns.
	dataPath string
}
    38  
    39  func newClient(ctx context.Context, config *clientConfig, dataPath string) (*azureClient, error) {
    40  	connectionString := os.Getenv("AZURE_STORAGE_CONNECTION_STRING")
    41  	if connectionString != "" {
    42  		client, err := azblob.NewClientFromConnectionString(connectionString, nil)
    43  		if err != nil {
    44  			return nil, errors.Wrap(err, "create client using connection string")
    45  		}
    46  		serviceURL := ""
    47  		connectionStrings := strings.Split(connectionString, ";")
    48  		for _, str := range connectionStrings {
    49  			if strings.HasPrefix(str, "BlobEndpoint") {
    50  				blobEndpoint := strings.Split(str, "=")
    51  				if len(blobEndpoint) > 1 {
    52  					serviceURL = blobEndpoint[1]
    53  					if !strings.HasSuffix(serviceURL, "/") {
    54  						serviceURL = serviceURL + "/"
    55  					}
    56  				}
    57  			}
    58  		}
    59  		return &azureClient{client, *config, serviceURL, dataPath}, nil
    60  	}
    61  
    62  	// Your account name and key can be obtained from the Azure Portal.
    63  	accountName := os.Getenv("AZURE_STORAGE_ACCOUNT")
    64  	accountKey := os.Getenv("AZURE_STORAGE_KEY")
    65  
    66  	if accountName == "" {
    67  		return nil, errors.New("AZURE_STORAGE_ACCOUNT must be set")
    68  	}
    69  
    70  	// The service URL for blob endpoints is usually in the form: http(s)://<account>.blob.core.windows.net/
    71  	serviceURL := fmt.Sprintf("https://%s.blob.core.windows.net/", accountName)
    72  
    73  	if accountKey != "" {
    74  		cred, err := azblob.NewSharedKeyCredential(accountName, accountKey)
    75  		if err != nil {
    76  			return nil, err
    77  		}
    78  
    79  		client, err := azblob.NewClientWithSharedKeyCredential(serviceURL, cred, nil)
    80  		if err != nil {
    81  			return nil, err
    82  		}
    83  		return &azureClient{client, *config, serviceURL, dataPath}, nil
    84  	}
    85  
    86  	options := &azblob.ClientOptions{
    87  		ClientOptions: policy.ClientOptions{
    88  			Retry: policy.RetryOptions{
    89  				MaxRetries:    3,
    90  				RetryDelay:    4 * time.Second,
    91  				MaxRetryDelay: 120 * time.Second,
    92  			},
    93  		},
    94  	}
    95  
    96  	client, err := azblob.NewClientWithNoCredential(serviceURL, options)
    97  	if err != nil {
    98  		return nil, err
    99  	}
   100  	return &azureClient{client, *config, serviceURL, dataPath}, nil
   101  }
   102  
   103  func (a *azureClient) HomeDir(backupID string) string {
   104  	return a.serviceURL + path.Join(a.config.Container, a.makeObjectName(backupID))
   105  }
   106  
   107  func (a *azureClient) makeObjectName(parts ...string) string {
   108  	base := path.Join(parts...)
   109  	return path.Join(a.config.BackupPath, base)
   110  }
   111  
   112  func (a *azureClient) GetObject(ctx context.Context, backupID, key string) ([]byte, error) {
   113  	objectName := a.makeObjectName(backupID, key)
   114  
   115  	blobDownloadResponse, err := a.client.DownloadStream(ctx, a.config.Container, objectName, nil)
   116  	if err != nil {
   117  		if bloberror.HasCode(err, bloberror.BlobNotFound) {
   118  			return nil, backup.NewErrNotFound(errors.Wrapf(err, "get object '%s'", objectName))
   119  		}
   120  		return nil, backup.NewErrInternal(errors.Wrapf(err, "download stream for object '%s'", objectName))
   121  	}
   122  
   123  	reader := blobDownloadResponse.Body
   124  	downloadData, err := io.ReadAll(reader)
   125  	errClose := reader.Close()
   126  	if errClose != nil {
   127  		return nil, backup.NewErrInternal(errors.Wrapf(errClose, "close stream for object '%s'", objectName))
   128  	}
   129  	if err != nil {
   130  		return nil, backup.NewErrInternal(errors.Wrapf(err, "read stream for object '%s'", objectName))
   131  	}
   132  
   133  	return downloadData, nil
   134  }
   135  
   136  func (a *azureClient) PutFile(ctx context.Context, backupID, key, srcPath string) error {
   137  	filePath := path.Join(a.dataPath, srcPath)
   138  	file, err := os.Open(filePath)
   139  	if err != nil {
   140  		return backup.NewErrInternal(errors.Wrapf(err, "open file: %q", filePath))
   141  	}
   142  	defer file.Close()
   143  
   144  	objectName := a.makeObjectName(backupID, key)
   145  	_, err = a.client.UploadFile(ctx,
   146  		a.config.Container,
   147  		objectName,
   148  		file,
   149  		&azblob.UploadFileOptions{
   150  			Metadata: map[string]*string{"backupid": to.Ptr(backupID)},
   151  			Tags:     map[string]string{"backupid": backupID},
   152  		})
   153  	if err != nil {
   154  		return backup.NewErrInternal(errors.Wrapf(err, "upload file for object '%s'", objectName))
   155  	}
   156  
   157  	return nil
   158  }
   159  
   160  func (a *azureClient) PutObject(ctx context.Context, backupID, key string, data []byte) error {
   161  	objectName := a.makeObjectName(backupID, key)
   162  
   163  	reader := bytes.NewReader(data)
   164  	_, err := a.client.UploadStream(ctx,
   165  		a.config.Container,
   166  		objectName,
   167  		reader,
   168  		&azblob.UploadStreamOptions{
   169  			Metadata: map[string]*string{"backupid": to.Ptr(backupID)},
   170  			Tags:     map[string]string{"backupid": backupID},
   171  		})
   172  	if err != nil {
   173  		return backup.NewErrInternal(errors.Wrapf(err, "upload stream for object '%s'", objectName))
   174  	}
   175  
   176  	return nil
   177  }
   178  
   179  func (a *azureClient) Initialize(ctx context.Context, backupID string) error {
   180  	key := "access-check"
   181  
   182  	if err := a.PutObject(ctx, backupID, key, []byte("")); err != nil {
   183  		return errors.Wrap(err, "failed to access-check Azure backup module")
   184  	}
   185  
   186  	objectName := a.makeObjectName(backupID, key)
   187  	if _, err := a.client.DeleteBlob(ctx, a.config.Container, objectName, nil); err != nil {
   188  		return errors.Wrap(err, "failed to remove access-check Azure backup module")
   189  	}
   190  
   191  	return nil
   192  }
   193  
   194  func (a *azureClient) WriteToFile(ctx context.Context, backupID, key, destPath string) error {
   195  	dir := path.Dir(destPath)
   196  	if err := os.MkdirAll(dir, os.ModePerm); err != nil {
   197  		return errors.Wrapf(err, "make dir '%s'", dir)
   198  	}
   199  
   200  	file, err := os.Create(destPath)
   201  	if err != nil {
   202  		return backup.NewErrInternal(errors.Wrapf(err, "create file: %q", destPath))
   203  	}
   204  	defer file.Close()
   205  
   206  	objectName := a.makeObjectName(backupID, key)
   207  	_, err = a.client.DownloadFile(ctx, a.config.Container, objectName, file, nil)
   208  	if err != nil {
   209  		if bloberror.HasCode(err, bloberror.BlobNotFound) {
   210  			return backup.NewErrNotFound(errors.Wrapf(err, "get object '%s'", objectName))
   211  		}
   212  		return backup.NewErrInternal(errors.Wrapf(err, "download file for object '%s'", objectName))
   213  	}
   214  
   215  	return nil
   216  }
   217  
   218  func (a *azureClient) Write(ctx context.Context, backupID, key string, r io.ReadCloser) (written int64, err error) {
   219  	path := a.makeObjectName(backupID, key)
   220  	reader := &reader{src: r}
   221  	defer func() {
   222  		r.Close()
   223  		written = int64(reader.count)
   224  	}()
   225  
   226  	if _, err = a.client.UploadStream(ctx,
   227  		a.config.Container,
   228  		path,
   229  		reader,
   230  		&azblob.UploadStreamOptions{
   231  			Metadata: map[string]*string{"backupid": to.Ptr(backupID)},
   232  			Tags:     map[string]string{"backupid": backupID},
   233  		}); err != nil {
   234  		err = fmt.Errorf("upload stream %q: %w", path, err)
   235  	}
   236  
   237  	return
   238  }
   239  
   240  func (a *azureClient) Read(ctx context.Context, backupID, key string, w io.WriteCloser) (int64, error) {
   241  	defer w.Close()
   242  
   243  	path := a.makeObjectName(backupID, key)
   244  	resp, err := a.client.DownloadStream(ctx, a.config.Container, path, nil)
   245  	if err != nil {
   246  		err = fmt.Errorf("find object %q: %w", path, err)
   247  		if bloberror.HasCode(err, bloberror.BlobNotFound) {
   248  			err = backup.NewErrNotFound(err)
   249  		}
   250  		return 0, err
   251  	}
   252  	defer resp.Body.Close()
   253  
   254  	read, err := io.Copy(w, resp.Body)
   255  	if err != nil {
   256  		return read, fmt.Errorf("io.copy %q: %w", path, err)
   257  	}
   258  
   259  	return read, nil
   260  }
   261  
// SourceDataPath returns the local data path the client was created with.
func (a *azureClient) SourceDataPath() string {
	return a.dataPath
}
   265  
   266  // reader is a wrapper used to count number of written bytes
   267  // Unlike GCS and S3 Azure Interface does not provide this information
   268  type reader struct {
   269  	src   io.Reader
   270  	count int
   271  }
   272  
   273  func (r *reader) Read(p []byte) (n int, err error) {
   274  	n, err = r.src.Read(p)
   275  	r.count += n
   276  	return
   277  }