github.com/diggerhq/digger/libs@v0.0.0-20240604170430-9d61cdf01cc5/locking/azure/storage_account.go (about)

     1  package azure
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"os"
     8  	"strings"
     9  
    10  	"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
    11  	"github.com/Azure/azure-sdk-for-go/sdk/data/aztables"
    12  )
    13  
const (
	// TABLE_NAME is the name of the storage table used by this package. It is
	// the table the StorageAccount table client is bound to and the name
	// looked up before the table is created.
	TABLE_NAME = "DIGGERLOCK"
)
    17  
var (
	// SERVICE_URL_FORMAT is the fmt format string used by getServiceURL to
	// build the table service URL; its single %s verb receives the storage
	// account name.
	SERVICE_URL_FORMAT = "https://%s.table.core.windows.net"
)
    21  
// StorageAccount is a lock implementation that stores its locks as entities
// in an Azure storage table. Instances are built by NewStorageAccountLock.
type StorageAccount struct {
	// tableClient is the client bound to the TABLE_NAME table; it is used to
	// add, delete and list lock entities and to create the table.
	tableClient *aztables.Client
	// svcClient is the account-level client; it is used to list the existing
	// tables and to derive tableClient.
	svcClient   *aztables.ServiceClient
}
    26  
    27  func NewStorageAccountLock() (*StorageAccount, error) {
    28  	authMethod := os.Getenv("DIGGER_AZURE_AUTH_METHOD")
    29  	if authMethod == "" {
    30  		return nil, fmt.Errorf("'DIGGER_AZURE_AUTH_METHOD' environment variable must be set to either 'SHARED_KEY' or 'CONNECTION_STRING' or 'CLIENT_SECRET'")
    31  	}
    32  
    33  	svcClient, err := getServiceClient(authMethod)
    34  	if err != nil {
    35  		return nil, err
    36  	}
    37  
    38  	sal := &StorageAccount{
    39  		svcClient:   svcClient,
    40  		tableClient: svcClient.NewClient(TABLE_NAME),
    41  	}
    42  
    43  	if err := sal.createTableIfNotExists(); err != nil {
    44  		return nil, fmt.Errorf("error while creating table: %v", err)
    45  	}
    46  	return sal, nil
    47  }
    48  
    49  func (sal *StorageAccount) Lock(transactionId int, resource string) (bool, error) {
    50  	resource = normalizeResourceName(resource)
    51  	entity := aztables.EDMEntity{
    52  		Properties: map[string]interface{}{
    53  			"transaction_id": transactionId,
    54  		},
    55  		Entity: aztables.Entity{
    56  			PartitionKey: "digger",
    57  			RowKey:       resource,
    58  		},
    59  	}
    60  	b, err := json.Marshal(entity)
    61  	if err != nil {
    62  		return false, fmt.Errorf("could not marshall entity: %v", err)
    63  	}
    64  
    65  	_, err = sal.tableClient.AddEntity(context.Background(), b, nil)
    66  	if err != nil {
    67  		if strings.Contains(err.Error(), "EntityAlreadyExists") {
    68  			return false, nil
    69  		}
    70  		return false, fmt.Errorf("could not add entity: \n%v", err)
    71  	}
    72  
    73  	return true, nil
    74  }
    75  
    76  func (sal *StorageAccount) Unlock(resource string) (bool, error) {
    77  	resource = normalizeResourceName(resource)
    78  	_, err := sal.tableClient.DeleteEntity(context.Background(), "digger", resource, nil)
    79  	if err != nil {
    80  		return false, fmt.Errorf("could not delete lock: %v", err)
    81  	}
    82  
    83  	return true, nil
    84  }
    85  
    86  func (sal *StorageAccount) GetLock(resource string) (*int, error) {
    87  	resource = normalizeResourceName(resource)
    88  	filterQuery := fmt.Sprintf("PartitionKey eq 'digger' and RowKey eq '%s'", resource)
    89  	selectQuery := "RowKey,PartitionKey,transaction_id"
    90  	listOpts := aztables.ListEntitiesOptions{
    91  		Filter: &filterQuery,
    92  		Select: &selectQuery,
    93  	}
    94  
    95  	entitiesPager := sal.tableClient.NewListEntitiesPager(&listOpts)
    96  	for entitiesPager.More() {
    97  		res, err := entitiesPager.NextPage(context.Background())
    98  		if err != nil {
    99  			return nil, fmt.Errorf("could not retrieve the entities: %v", err)
   100  		}
   101  
   102  		for _, e := range res.Entities {
   103  			var entity aztables.EDMEntity
   104  			err := json.Unmarshal(e, &entity)
   105  			if err != nil {
   106  				return nil, fmt.Errorf("could not unmarshall entity: %v", err)
   107  			}
   108  
   109  			transactionId := int(entity.Properties["transaction_id"].(int32))
   110  			return &transactionId, nil
   111  		}
   112  	}
   113  
   114  	// Lock doesn't exist
   115  	return nil, nil
   116  }
   117  
   118  func getServiceClient(authMethod string) (*aztables.ServiceClient, error) {
   119  	if authMethod == "SHARED_KEY" {
   120  		return getSharedKeySvcClient()
   121  	}
   122  
   123  	if authMethod == "CONNECTION_STRING" {
   124  		return getConnStringSvcClient()
   125  	}
   126  
   127  	if authMethod == "CLIENT_SECRET" {
   128  		return getClientSecretSvcClient()
   129  	}
   130  
   131  	return nil, fmt.Errorf("could not initialize service client, because no valid authentication method was found")
   132  }
   133  
   134  func getSharedKeySvcClient() (*aztables.ServiceClient, error) {
   135  	key := os.Getenv("DIGGER_AZURE_SHARED_KEY")
   136  	saName := os.Getenv("DIGGER_AZURE_SA_NAME")
   137  	if saName == "" || key == "" {
   138  		return nil, fmt.Errorf("you must set 'DIGGER_AZURE_SA_NAME' and 'DIGGER_AZURE_SHARED_KEY' environment variable when using shared key authentication")
   139  	}
   140  
   141  	sharedCreds, err := aztables.NewSharedKeyCredential(saName, key)
   142  	if err != nil {
   143  		return nil, fmt.Errorf("could not create shared key credentials: %v", err)
   144  	}
   145  
   146  	serviceURL := getServiceURL(saName)
   147  	svcClient, err := aztables.NewServiceClientWithSharedKey(serviceURL, sharedCreds, nil)
   148  	if err != nil {
   149  		return nil, fmt.Errorf("could not create service client with shared key authentication: %v", err)
   150  	}
   151  	return svcClient, nil
   152  }
   153  
   154  func getConnStringSvcClient() (*aztables.ServiceClient, error) {
   155  	connStr := os.Getenv("DIGGER_AZURE_CONNECTION_STRING")
   156  	if connStr == "" {
   157  		return nil, fmt.Errorf("you must set 'DIGGER_AZURE_CONNECTION_STRING' when using connection string authentication")
   158  	}
   159  
   160  	svcClient, err := aztables.NewServiceClientFromConnectionString(connStr, nil)
   161  	if err != nil {
   162  		return nil, fmt.Errorf("could not create service client with connection string authentication: %v", err)
   163  	}
   164  	return svcClient, err
   165  }
   166  
   167  func getClientSecretSvcClient() (*aztables.ServiceClient, error) {
   168  	tenantId := os.Getenv("DIGGER_AZURE_TENANT_ID")
   169  	clientId := os.Getenv("DIGGER_AZURE_CLIENT_ID")
   170  	secret := os.Getenv("DIGGER_AZURE_CLIENT_SECRET")
   171  	saName := os.Getenv("DIGGER_AZURE_SA_NAME")
   172  
   173  	if clientId == "" || secret == "" || tenantId == "" || saName == "" {
   174  		return nil, fmt.Errorf("you must set 'DIGGER_AZURE_CLIENT_ID' and 'DIGGER_AZURE_CLIENT_SECRET' and 'DIGGER_AZURE_TENANT_ID' and 'DIGGER_AZURE_SA_NAME' when using client secret authentication")
   175  	}
   176  
   177  	serviceURL := getServiceURL(saName)
   178  	cred, err := azidentity.NewClientSecretCredential(tenantId, clientId, secret, nil)
   179  	if err != nil {
   180  		return nil, fmt.Errorf("could not create create client secret credential: %v", err)
   181  	}
   182  
   183  	svcClient, err := aztables.NewServiceClient(serviceURL, cred, nil)
   184  	if err != nil {
   185  		return nil, fmt.Errorf("could not create service client with client secret authentication: %v", err)
   186  	}
   187  	return svcClient, nil
   188  }
   189  
   190  func (sal *StorageAccount) createTableIfNotExists() error {
   191  	exists, err := sal.isTableExists(TABLE_NAME)
   192  	if err != nil {
   193  		return err
   194  	}
   195  
   196  	if exists {
   197  		return nil
   198  	}
   199  
   200  	// Table doesn't exist, we create it
   201  	_, err = sal.tableClient.CreateTable(context.TODO(), nil)
   202  	if err != nil {
   203  		return fmt.Errorf("could not create table: %v", err)
   204  	}
   205  
   206  	return nil
   207  }
   208  
   209  func (sal *StorageAccount) isTableExists(table string) (bool, error) {
   210  	tablesPager := sal.svcClient.NewListTablesPager(nil)
   211  	for tablesPager.More() {
   212  		res, err := tablesPager.NextPage(context.Background())
   213  		if err != nil {
   214  			return false, fmt.Errorf("could not retrieve the tables: %v", err)
   215  		}
   216  
   217  		for _, t := range res.Tables {
   218  			if *t.Name == table {
   219  				return true, nil
   220  			}
   221  		}
   222  	}
   223  
   224  	return false, nil
   225  }
   226  
   227  func getServiceURL(saName string) string {
   228  	return fmt.Sprintf(SERVICE_URL_FORMAT, saName)
   229  }
   230  
   231  func normalizeResourceName(resourceName string) string {
   232  	resourceName = strings.ReplaceAll(resourceName, "/", "-")
   233  	return strings.ReplaceAll(resourceName, "#", "-")
   234  }