github.com/weaviate/weaviate@v1.24.6/modules/backup-azure/client.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package modstgazure 13 14 import ( 15 "bytes" 16 "context" 17 "fmt" 18 "io" 19 "os" 20 "path" 21 "strings" 22 "time" 23 24 "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" 25 "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" 26 "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" 27 "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror" 28 "github.com/pkg/errors" 29 "github.com/weaviate/weaviate/entities/backup" 30 ) 31 32 type azureClient struct { 33 client *azblob.Client 34 config clientConfig 35 serviceURL string 36 dataPath string 37 } 38 39 func newClient(ctx context.Context, config *clientConfig, dataPath string) (*azureClient, error) { 40 connectionString := os.Getenv("AZURE_STORAGE_CONNECTION_STRING") 41 if connectionString != "" { 42 client, err := azblob.NewClientFromConnectionString(connectionString, nil) 43 if err != nil { 44 return nil, errors.Wrap(err, "create client using connection string") 45 } 46 serviceURL := "" 47 connectionStrings := strings.Split(connectionString, ";") 48 for _, str := range connectionStrings { 49 if strings.HasPrefix(str, "BlobEndpoint") { 50 blobEndpoint := strings.Split(str, "=") 51 if len(blobEndpoint) > 1 { 52 serviceURL = blobEndpoint[1] 53 if !strings.HasSuffix(serviceURL, "/") { 54 serviceURL = serviceURL + "/" 55 } 56 } 57 } 58 } 59 return &azureClient{client, *config, serviceURL, dataPath}, nil 60 } 61 62 // Your account name and key can be obtained from the Azure Portal. 63 accountName := os.Getenv("AZURE_STORAGE_ACCOUNT") 64 accountKey := os.Getenv("AZURE_STORAGE_KEY") 65 66 if accountName == "" { 67 return nil, errors.New("AZURE_STORAGE_ACCOUNT must be set") 68 } 69 70 // The service URL for blob endpoints is usually in the form: http(s)://<account>.blob.core.windows.net/ 71 serviceURL := fmt.Sprintf("https://%s.blob.core.windows.net/", accountName) 72 73 if accountKey != "" { 74 cred, err := azblob.NewSharedKeyCredential(accountName, accountKey) 75 if err != nil { 76 return nil, err 77 } 78 79 client, err := azblob.NewClientWithSharedKeyCredential(serviceURL, cred, nil) 80 if err != nil { 81 return nil, err 82 } 83 return &azureClient{client, *config, serviceURL, dataPath}, nil 84 } 85 86 options := &azblob.ClientOptions{ 87 ClientOptions: policy.ClientOptions{ 88 Retry: policy.RetryOptions{ 89 MaxRetries: 3, 90 RetryDelay: 4 * time.Second, 91 MaxRetryDelay: 120 * time.Second, 92 }, 93 }, 94 } 95 96 client, err := azblob.NewClientWithNoCredential(serviceURL, options) 97 if err != nil { 98 return nil, err 99 } 100 return &azureClient{client, *config, serviceURL, dataPath}, nil 101 } 102 103 func (a *azureClient) HomeDir(backupID string) string { 104 return a.serviceURL + path.Join(a.config.Container, a.makeObjectName(backupID)) 105 } 106 107 func (a *azureClient) makeObjectName(parts ...string) string { 108 base := path.Join(parts...) 109 return path.Join(a.config.BackupPath, base) 110 } 111 112 func (a *azureClient) GetObject(ctx context.Context, backupID, key string) ([]byte, error) { 113 objectName := a.makeObjectName(backupID, key) 114 115 blobDownloadResponse, err := a.client.DownloadStream(ctx, a.config.Container, objectName, nil) 116 if err != nil { 117 if bloberror.HasCode(err, bloberror.BlobNotFound) { 118 return nil, backup.NewErrNotFound(errors.Wrapf(err, "get object '%s'", objectName)) 119 } 120 return nil, backup.NewErrInternal(errors.Wrapf(err, "download stream for object '%s'", objectName)) 121 } 122 123 reader := blobDownloadResponse.Body 124 downloadData, err := io.ReadAll(reader) 125 errClose := reader.Close() 126 if errClose != nil { 127 return nil, backup.NewErrInternal(errors.Wrapf(errClose, "close stream for object '%s'", objectName)) 128 } 129 if err != nil { 130 return nil, backup.NewErrInternal(errors.Wrapf(err, "read stream for object '%s'", objectName)) 131 } 132 133 return downloadData, nil 134 } 135 136 func (a *azureClient) PutFile(ctx context.Context, backupID, key, srcPath string) error { 137 filePath := path.Join(a.dataPath, srcPath) 138 file, err := os.Open(filePath) 139 if err != nil { 140 return backup.NewErrInternal(errors.Wrapf(err, "open file: %q", filePath)) 141 } 142 defer file.Close() 143 144 objectName := a.makeObjectName(backupID, key) 145 _, err = a.client.UploadFile(ctx, 146 a.config.Container, 147 objectName, 148 file, 149 &azblob.UploadFileOptions{ 150 Metadata: map[string]*string{"backupid": to.Ptr(backupID)}, 151 Tags: map[string]string{"backupid": backupID}, 152 }) 153 if err != nil { 154 return backup.NewErrInternal(errors.Wrapf(err, "upload file for object '%s'", objectName)) 155 } 156 157 return nil 158 } 159 160 func (a *azureClient) PutObject(ctx context.Context, backupID, key string, data []byte) error { 161 objectName := a.makeObjectName(backupID, key) 162 163 reader := bytes.NewReader(data) 164 _, err := a.client.UploadStream(ctx, 165 a.config.Container, 166 objectName, 167 reader, 168 &azblob.UploadStreamOptions{ 169 Metadata: map[string]*string{"backupid": to.Ptr(backupID)}, 170 Tags: map[string]string{"backupid": backupID}, 171 }) 172 if err != nil { 173 return backup.NewErrInternal(errors.Wrapf(err, "upload stream for object '%s'", objectName)) 174 } 175 176 return nil 177 } 178 179 func (a *azureClient) Initialize(ctx context.Context, backupID string) error { 180 key := "access-check" 181 182 if err := a.PutObject(ctx, backupID, key, []byte("")); err != nil { 183 return errors.Wrap(err, "failed to access-check Azure backup module") 184 } 185 186 objectName := a.makeObjectName(backupID, key) 187 if _, err := a.client.DeleteBlob(ctx, a.config.Container, objectName, nil); err != nil { 188 return errors.Wrap(err, "failed to remove access-check Azure backup module") 189 } 190 191 return nil 192 } 193 194 func (a *azureClient) WriteToFile(ctx context.Context, backupID, key, destPath string) error { 195 dir := path.Dir(destPath) 196 if err := os.MkdirAll(dir, os.ModePerm); err != nil { 197 return errors.Wrapf(err, "make dir '%s'", dir) 198 } 199 200 file, err := os.Create(destPath) 201 if err != nil { 202 return backup.NewErrInternal(errors.Wrapf(err, "create file: %q", destPath)) 203 } 204 defer file.Close() 205 206 objectName := a.makeObjectName(backupID, key) 207 _, err = a.client.DownloadFile(ctx, a.config.Container, objectName, file, nil) 208 if err != nil { 209 if bloberror.HasCode(err, bloberror.BlobNotFound) { 210 return backup.NewErrNotFound(errors.Wrapf(err, "get object '%s'", objectName)) 211 } 212 return backup.NewErrInternal(errors.Wrapf(err, "download file for object '%s'", objectName)) 213 } 214 215 return nil 216 } 217 218 func (a *azureClient) Write(ctx context.Context, backupID, key string, r io.ReadCloser) (written int64, err error) { 219 path := a.makeObjectName(backupID, key) 220 reader := &reader{src: r} 221 defer func() { 222 r.Close() 223 written = int64(reader.count) 224 }() 225 226 if _, err = a.client.UploadStream(ctx, 227 a.config.Container, 228 path, 229 reader, 230 &azblob.UploadStreamOptions{ 231 Metadata: map[string]*string{"backupid": to.Ptr(backupID)}, 232 Tags: map[string]string{"backupid": backupID}, 233 }); err != nil { 234 err = fmt.Errorf("upload stream %q: %w", path, err) 235 } 236 237 return 238 } 239 240 func (a *azureClient) Read(ctx context.Context, backupID, key string, w io.WriteCloser) (int64, error) { 241 defer w.Close() 242 243 path := a.makeObjectName(backupID, key) 244 resp, err := a.client.DownloadStream(ctx, a.config.Container, path, nil) 245 if err != nil { 246 err = fmt.Errorf("find object %q: %w", path, err) 247 if bloberror.HasCode(err, bloberror.BlobNotFound) { 248 err = backup.NewErrNotFound(err) 249 } 250 return 0, err 251 } 252 defer resp.Body.Close() 253 254 read, err := io.Copy(w, resp.Body) 255 if err != nil { 256 return read, fmt.Errorf("io.copy %q: %w", path, err) 257 } 258 259 return read, nil 260 } 261 262 func (a *azureClient) SourceDataPath() string { 263 return a.dataPath 264 } 265 266 // reader is a wrapper used to count number of written bytes 267 // Unlike GCS and S3 Azure Interface does not provide this information 268 type reader struct { 269 src io.Reader 270 count int 271 } 272 273 func (r *reader) Read(p []byte) (n int, err error) { 274 n, err = r.src.Read(p) 275 r.count += n 276 return 277 }