github.com/rudderlabs/rudder-go-kit@v0.30.0/filemanager/azureblobmanager.go (about)

     1  package filemanager
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"net/url"
     9  	"os"
    10  	"path"
    11  	"strings"
    12  	"time"
    13  
    14  	"github.com/Azure/azure-storage-blob-go/azblob"
    15  
    16  	"github.com/rudderlabs/rudder-go-kit/logger"
    17  )
    18  
    19  type AzureBlobConfig struct {
    20  	Container      string
    21  	Prefix         string
    22  	AccountName    string
    23  	AccountKey     string
    24  	SASToken       string
    25  	EndPoint       *string
    26  	ForcePathStyle *bool
    27  	DisableSSL     *bool
    28  	UseSASTokens   bool
    29  }
    30  
    31  // NewAzureBlobManager creates a new file manager for Azure Blob Storage
    32  func NewAzureBlobManager(config map[string]interface{}, log logger.Logger, defaultTimeout func() time.Duration) (*azureBlobManager, error) {
    33  	return &azureBlobManager{
    34  		baseManager: &baseManager{
    35  			logger:         log,
    36  			defaultTimeout: defaultTimeout,
    37  		},
    38  		config: azureBlobConfig(config),
    39  	}, nil
    40  }
    41  
    42  func (manager *azureBlobManager) ListFilesWithPrefix(ctx context.Context, startAfter, prefix string, maxItems int64) ListSession {
    43  	return &azureBlobListSession{
    44  		baseListSession: &baseListSession{
    45  			ctx:        ctx,
    46  			startAfter: startAfter,
    47  			prefix:     prefix,
    48  			maxItems:   maxItems,
    49  		},
    50  		manager: manager,
    51  	}
    52  }
    53  
    54  // Upload passed in file to Azure Blob Storage
    55  func (manager *azureBlobManager) Upload(ctx context.Context, file *os.File, prefixes ...string) (UploadedFile, error) {
    56  	containerURL, err := manager.getContainerURL()
    57  	if err != nil {
    58  		return UploadedFile{}, err
    59  	}
    60  
    61  	ctx, cancel := context.WithTimeout(ctx, manager.getTimeout())
    62  	defer cancel()
    63  
    64  	if manager.createContainer() {
    65  		_, err = containerURL.Create(ctx, azblob.Metadata{}, azblob.PublicAccessNone)
    66  		err = manager.suppressMinorErrors(err)
    67  		if err != nil {
    68  			return UploadedFile{}, err
    69  		}
    70  	}
    71  
    72  	fileName := path.Join(manager.config.Prefix, path.Join(prefixes...), path.Base(file.Name()))
    73  
    74  	// Here's how to upload a blob.
    75  	blobURL := containerURL.NewBlockBlobURL(fileName)
    76  	_, err = azblob.UploadFileToBlockBlob(ctx, file, blobURL, azblob.UploadToBlockBlobOptions{
    77  		BlockSize:   4 * 1024 * 1024,
    78  		Parallelism: 16,
    79  	})
    80  	if err != nil {
    81  		return UploadedFile{}, err
    82  	}
    83  
    84  	return UploadedFile{Location: manager.blobLocation(&blobURL), ObjectName: fileName}, nil
    85  }
    86  
    87  func (manager *azureBlobManager) Download(ctx context.Context, output *os.File, key string) error {
    88  	containerURL, err := manager.getContainerURL()
    89  	if err != nil {
    90  		return err
    91  	}
    92  
    93  	blobURL := containerURL.NewBlockBlobURL(key)
    94  
    95  	ctx, cancel := context.WithTimeout(ctx, manager.getTimeout())
    96  	defer cancel()
    97  
    98  	// Here's how to download the blob
    99  	downloadResponse, err := blobURL.Download(ctx, 0, azblob.CountToEnd, azblob.BlobAccessConditions{}, false, azblob.ClientProvidedKeyOptions{})
   100  	if err != nil {
   101  		return err
   102  	}
   103  
   104  	// NOTE: automatically retries are performed if the connection fails
   105  	bodyStream := downloadResponse.Body(azblob.RetryReaderOptions{MaxRetryRequests: 20})
   106  
   107  	// read the body into a buffer
   108  	downloadedData := bytes.Buffer{}
   109  	_, err = downloadedData.ReadFrom(bodyStream)
   110  	if err != nil {
   111  		return err
   112  	}
   113  
   114  	_, err = output.Write(downloadedData.Bytes())
   115  	return err
   116  }
   117  
   118  func (manager *azureBlobManager) Delete(ctx context.Context, keys []string) (err error) {
   119  	containerURL, err := manager.getContainerURL()
   120  	if err != nil {
   121  		return err
   122  	}
   123  
   124  	for _, key := range keys {
   125  		blobURL := containerURL.NewBlockBlobURL(key)
   126  
   127  		_ctx, cancel := context.WithTimeout(ctx, manager.getTimeout())
   128  		_, err := blobURL.Delete(_ctx, azblob.DeleteSnapshotsOptionNone, azblob.BlobAccessConditions{})
   129  		if err != nil {
   130  			cancel()
   131  			return err
   132  		}
   133  		cancel()
   134  	}
   135  	return
   136  }
   137  
   138  func (manager *azureBlobManager) Prefix() string {
   139  	return manager.config.Prefix
   140  }
   141  
   142  func (manager *azureBlobManager) GetObjectNameFromLocation(location string) (string, error) {
   143  	strToken := strings.Split(location, fmt.Sprintf("%s/", manager.config.Container))
   144  	return strToken[len(strToken)-1], nil
   145  }
   146  
   147  func (manager *azureBlobManager) GetDownloadKeyFromFileLocation(location string) string {
   148  	str := strings.Split(location, fmt.Sprintf("%s/", manager.config.Container))
   149  	return str[len(str)-1]
   150  }
   151  
   152  func (manager *azureBlobManager) suppressMinorErrors(err error) error {
   153  	if err != nil {
   154  		if storageError, ok := err.(azblob.StorageError); ok { // This error is a Service-specific
   155  			switch storageError.ServiceCode() { // Compare serviceCode to ServiceCodeXxx constants
   156  			case azblob.ServiceCodeContainerAlreadyExists:
   157  				manager.logger.Debug("Received 409. Container already exists")
   158  				return nil
   159  			}
   160  		}
   161  	}
   162  	return err
   163  }
   164  
   165  func (manager *azureBlobManager) getBaseURL() *url.URL {
   166  	protocol := "https"
   167  	if manager.config.DisableSSL != nil && *manager.config.DisableSSL {
   168  		protocol = "http"
   169  	}
   170  
   171  	endpoint := "blob.core.windows.net"
   172  	if manager.config.EndPoint != nil && *manager.config.EndPoint != "" {
   173  		endpoint = *manager.config.EndPoint
   174  	}
   175  
   176  	baseURL := url.URL{
   177  		Scheme: protocol,
   178  		Host:   fmt.Sprintf("%s.%s", manager.config.AccountName, endpoint),
   179  	}
   180  
   181  	if manager.config.UseSASTokens {
   182  		baseURL.RawQuery = manager.config.SASToken
   183  	}
   184  
   185  	if manager.config.ForcePathStyle != nil && *manager.config.ForcePathStyle {
   186  		baseURL.Host = endpoint
   187  		baseURL.Path = fmt.Sprintf("/%s/", manager.config.AccountName)
   188  	}
   189  
   190  	return &baseURL
   191  }
   192  
   193  func (manager *azureBlobManager) getContainerURL() (azblob.ContainerURL, error) {
   194  	if manager.config.Container == "" {
   195  		return azblob.ContainerURL{}, errors.New("no container configured")
   196  	}
   197  
   198  	credential, err := manager.getCredentials()
   199  	if err != nil {
   200  		return azblob.ContainerURL{}, err
   201  	}
   202  
   203  	p := azblob.NewPipeline(credential, azblob.PipelineOptions{})
   204  
   205  	// From the Azure portal, get your storage account blob service URL endpoint.
   206  	baseURL := manager.getBaseURL()
   207  	serviceURL := azblob.NewServiceURL(*baseURL, p)
   208  	containerURL := serviceURL.NewContainerURL(manager.config.Container)
   209  
   210  	return containerURL, nil
   211  }
   212  
   213  func (manager *azureBlobManager) getCredentials() (azblob.Credential, error) {
   214  	if manager.config.UseSASTokens {
   215  		return azblob.NewAnonymousCredential(), nil
   216  	}
   217  
   218  	accountName, accountKey := manager.config.AccountName, manager.config.AccountKey
   219  	if accountName == "" || accountKey == "" {
   220  		return nil, errors.New("either accountName or accountKey is empty")
   221  	}
   222  
   223  	// Create a default request pipeline using your storage account name and account key.
   224  	return azblob.NewSharedKeyCredential(accountName, accountKey)
   225  }
   226  
   227  func (manager *azureBlobManager) createContainer() bool {
   228  	return !manager.config.UseSASTokens
   229  }
   230  
   231  func (manager *azureBlobManager) blobLocation(blobURL *azblob.BlockBlobURL) string {
   232  	if !manager.config.UseSASTokens {
   233  		return blobURL.String()
   234  	}
   235  
   236  	// Reset SAS Query parameters
   237  	blobURLParts := azblob.NewBlobURLParts(blobURL.URL())
   238  	blobURLParts.SAS = azblob.SASQueryParameters{}
   239  	newBlobURL := blobURLParts.URL()
   240  	return newBlobURL.String()
   241  }
   242  
   243  type azureBlobManager struct {
   244  	*baseManager
   245  	config *AzureBlobConfig
   246  }
   247  
   248  func azureBlobConfig(config map[string]interface{}) *AzureBlobConfig {
   249  	var containerName, accountName, accountKey, sasToken, prefix string
   250  	var endPoint *string
   251  	var forcePathStyle, disableSSL *bool
   252  	var useSASTokens bool
   253  	if config["containerName"] != nil {
   254  		tmp, ok := config["containerName"].(string)
   255  		if ok {
   256  			containerName = tmp
   257  		}
   258  	}
   259  	if config["prefix"] != nil {
   260  		tmp, ok := config["prefix"].(string)
   261  		if ok {
   262  			prefix = tmp
   263  		}
   264  	}
   265  	if config["accountName"] != nil {
   266  		tmp, ok := config["accountName"].(string)
   267  		if ok {
   268  			accountName = tmp
   269  		}
   270  	}
   271  	if config["useSASTokens"] != nil {
   272  		tmp, ok := config["useSASTokens"].(bool)
   273  		if ok {
   274  			useSASTokens = tmp
   275  		}
   276  	}
   277  	if config["sasToken"] != nil {
   278  		tmp, ok := config["sasToken"].(string)
   279  		if ok {
   280  			sasToken = strings.TrimPrefix(tmp, "?")
   281  		}
   282  	}
   283  	if config["accountKey"] != nil {
   284  		tmp, ok := config["accountKey"].(string)
   285  		if ok {
   286  			accountKey = tmp
   287  		}
   288  	}
   289  	if config["endPoint"] != nil {
   290  		tmp, ok := config["endPoint"].(string)
   291  		if ok {
   292  			endPoint = &tmp
   293  		}
   294  	}
   295  	if config["forcePathStyle"] != nil {
   296  		tmp, ok := config["forcePathStyle"].(bool)
   297  		if ok {
   298  			forcePathStyle = &tmp
   299  		}
   300  	}
   301  	if config["disableSSL"] != nil {
   302  		tmp, ok := config["disableSSL"].(bool)
   303  		if ok {
   304  			disableSSL = &tmp
   305  		}
   306  	}
   307  	return &AzureBlobConfig{
   308  		Container:      containerName,
   309  		Prefix:         prefix,
   310  		AccountName:    accountName,
   311  		AccountKey:     accountKey,
   312  		UseSASTokens:   useSASTokens,
   313  		SASToken:       sasToken,
   314  		EndPoint:       endPoint,
   315  		ForcePathStyle: forcePathStyle,
   316  		DisableSSL:     disableSSL,
   317  	}
   318  }
   319  
   320  type azureBlobListSession struct {
   321  	*baseListSession
   322  	manager *azureBlobManager
   323  
   324  	Marker azblob.Marker
   325  }
   326  
   327  func (l *azureBlobListSession) Next() (fileObjects []*FileInfo, err error) {
   328  	manager := l.manager
   329  	maxItems := l.maxItems
   330  
   331  	containerURL, err := manager.getContainerURL()
   332  	if err != nil {
   333  		return []*FileInfo{}, err
   334  	}
   335  
   336  	blobListingDetails := azblob.BlobListingDetails{
   337  		Metadata: true,
   338  	}
   339  	segmentOptions := azblob.ListBlobsSegmentOptions{
   340  		Details:    blobListingDetails,
   341  		Prefix:     l.prefix,
   342  		MaxResults: int32(l.maxItems),
   343  	}
   344  
   345  	ctx, cancel := context.WithTimeout(l.ctx, manager.getTimeout())
   346  	defer cancel()
   347  
   348  	// List the blobs in the container
   349  	var response *azblob.ListBlobsFlatSegmentResponse
   350  
   351  	// Checking if maxItems > 0 to avoid function calls which expect only maxItems to be returned and not more in the code
   352  	for maxItems > 0 && l.Marker.NotDone() {
   353  		response, err = containerURL.ListBlobsFlatSegment(ctx, l.Marker, segmentOptions)
   354  		if err != nil {
   355  			return
   356  		}
   357  		l.Marker = response.NextMarker
   358  
   359  		fileObjects = make([]*FileInfo, 0)
   360  		for idx := range response.Segment.BlobItems {
   361  			if strings.Compare(response.Segment.BlobItems[idx].Name, l.startAfter) > 0 {
   362  				fileObjects = append(fileObjects, &FileInfo{response.Segment.BlobItems[idx].Name, response.Segment.BlobItems[idx].Properties.LastModified})
   363  				maxItems--
   364  			}
   365  		}
   366  	}
   367  	return
   368  }