github.com/diggerhq/digger/libs@v0.0.0-20240604170430-9d61cdf01cc5/locking/azure/storage_account.go (about) 1 package azure 2 3 import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "os" 8 "strings" 9 10 "github.com/Azure/azure-sdk-for-go/sdk/azidentity" 11 "github.com/Azure/azure-sdk-for-go/sdk/data/aztables" 12 ) 13 14 const ( 15 TABLE_NAME = "DIGGERLOCK" 16 ) 17 18 var ( 19 SERVICE_URL_FORMAT = "https://%s.table.core.windows.net" 20 ) 21 22 type StorageAccount struct { 23 tableClient *aztables.Client 24 svcClient *aztables.ServiceClient 25 } 26 27 func NewStorageAccountLock() (*StorageAccount, error) { 28 authMethod := os.Getenv("DIGGER_AZURE_AUTH_METHOD") 29 if authMethod == "" { 30 return nil, fmt.Errorf("'DIGGER_AZURE_AUTH_METHOD' environment variable must be set to either 'SHARED_KEY' or 'CONNECTION_STRING' or 'CLIENT_SECRET'") 31 } 32 33 svcClient, err := getServiceClient(authMethod) 34 if err != nil { 35 return nil, err 36 } 37 38 sal := &StorageAccount{ 39 svcClient: svcClient, 40 tableClient: svcClient.NewClient(TABLE_NAME), 41 } 42 43 if err := sal.createTableIfNotExists(); err != nil { 44 return nil, fmt.Errorf("error while creating table: %v", err) 45 } 46 return sal, nil 47 } 48 49 func (sal *StorageAccount) Lock(transactionId int, resource string) (bool, error) { 50 resource = normalizeResourceName(resource) 51 entity := aztables.EDMEntity{ 52 Properties: map[string]interface{}{ 53 "transaction_id": transactionId, 54 }, 55 Entity: aztables.Entity{ 56 PartitionKey: "digger", 57 RowKey: resource, 58 }, 59 } 60 b, err := json.Marshal(entity) 61 if err != nil { 62 return false, fmt.Errorf("could not marshall entity: %v", err) 63 } 64 65 _, err = sal.tableClient.AddEntity(context.Background(), b, nil) 66 if err != nil { 67 if strings.Contains(err.Error(), "EntityAlreadyExists") { 68 return false, nil 69 } 70 return false, fmt.Errorf("could not add entity: \n%v", err) 71 } 72 73 return true, nil 74 } 75 76 func (sal *StorageAccount) Unlock(resource string) (bool, error) { 77 resource = normalizeResourceName(resource) 78 _, err := sal.tableClient.DeleteEntity(context.Background(), "digger", resource, nil) 79 if err != nil { 80 return false, fmt.Errorf("could not delete lock: %v", err) 81 } 82 83 return true, nil 84 } 85 86 func (sal *StorageAccount) GetLock(resource string) (*int, error) { 87 resource = normalizeResourceName(resource) 88 filterQuery := fmt.Sprintf("PartitionKey eq 'digger' and RowKey eq '%s'", resource) 89 selectQuery := "RowKey,PartitionKey,transaction_id" 90 listOpts := aztables.ListEntitiesOptions{ 91 Filter: &filterQuery, 92 Select: &selectQuery, 93 } 94 95 entitiesPager := sal.tableClient.NewListEntitiesPager(&listOpts) 96 for entitiesPager.More() { 97 res, err := entitiesPager.NextPage(context.Background()) 98 if err != nil { 99 return nil, fmt.Errorf("could not retrieve the entities: %v", err) 100 } 101 102 for _, e := range res.Entities { 103 var entity aztables.EDMEntity 104 err := json.Unmarshal(e, &entity) 105 if err != nil { 106 return nil, fmt.Errorf("could not unmarshall entity: %v", err) 107 } 108 109 transactionId := int(entity.Properties["transaction_id"].(int32)) 110 return &transactionId, nil 111 } 112 } 113 114 // Lock doesn't exist 115 return nil, nil 116 } 117 118 func getServiceClient(authMethod string) (*aztables.ServiceClient, error) { 119 if authMethod == "SHARED_KEY" { 120 return getSharedKeySvcClient() 121 } 122 123 if authMethod == "CONNECTION_STRING" { 124 return getConnStringSvcClient() 125 } 126 127 if authMethod == "CLIENT_SECRET" { 128 return getClientSecretSvcClient() 129 } 130 131 return nil, fmt.Errorf("could not initialize service client, because no valid authentication method was found") 132 } 133 134 func getSharedKeySvcClient() (*aztables.ServiceClient, error) { 135 key := os.Getenv("DIGGER_AZURE_SHARED_KEY") 136 saName := os.Getenv("DIGGER_AZURE_SA_NAME") 137 if saName == "" || key == "" { 138 return nil, fmt.Errorf("you must set 'DIGGER_AZURE_SA_NAME' and 'DIGGER_AZURE_SHARED_KEY' environment variable when using shared key authentication") 139 } 140 141 sharedCreds, err := aztables.NewSharedKeyCredential(saName, key) 142 if err != nil { 143 return nil, fmt.Errorf("could not create shared key credentials: %v", err) 144 } 145 146 serviceURL := getServiceURL(saName) 147 svcClient, err := aztables.NewServiceClientWithSharedKey(serviceURL, sharedCreds, nil) 148 if err != nil { 149 return nil, fmt.Errorf("could not create service client with shared key authentication: %v", err) 150 } 151 return svcClient, nil 152 } 153 154 func getConnStringSvcClient() (*aztables.ServiceClient, error) { 155 connStr := os.Getenv("DIGGER_AZURE_CONNECTION_STRING") 156 if connStr == "" { 157 return nil, fmt.Errorf("you must set 'DIGGER_AZURE_CONNECTION_STRING' when using connection string authentication") 158 } 159 160 svcClient, err := aztables.NewServiceClientFromConnectionString(connStr, nil) 161 if err != nil { 162 return nil, fmt.Errorf("could not create service client with connection string authentication: %v", err) 163 } 164 return svcClient, err 165 } 166 167 func getClientSecretSvcClient() (*aztables.ServiceClient, error) { 168 tenantId := os.Getenv("DIGGER_AZURE_TENANT_ID") 169 clientId := os.Getenv("DIGGER_AZURE_CLIENT_ID") 170 secret := os.Getenv("DIGGER_AZURE_CLIENT_SECRET") 171 saName := os.Getenv("DIGGER_AZURE_SA_NAME") 172 173 if clientId == "" || secret == "" || tenantId == "" || saName == "" { 174 return nil, fmt.Errorf("you must set 'DIGGER_AZURE_CLIENT_ID' and 'DIGGER_AZURE_CLIENT_SECRET' and 'DIGGER_AZURE_TENANT_ID' and 'DIGGER_AZURE_SA_NAME' when using client secret authentication") 175 } 176 177 serviceURL := getServiceURL(saName) 178 cred, err := azidentity.NewClientSecretCredential(tenantId, clientId, secret, nil) 179 if err != nil { 180 return nil, fmt.Errorf("could not create create client secret credential: %v", err) 181 } 182 183 svcClient, err := aztables.NewServiceClient(serviceURL, cred, nil) 184 if err != nil { 185 return nil, fmt.Errorf("could not create service client with client secret authentication: %v", err) 186 } 187 return svcClient, nil 188 } 189 190 func (sal *StorageAccount) createTableIfNotExists() error { 191 exists, err := sal.isTableExists(TABLE_NAME) 192 if err != nil { 193 return err 194 } 195 196 if exists { 197 return nil 198 } 199 200 // Table doesn't exist, we create it 201 _, err = sal.tableClient.CreateTable(context.TODO(), nil) 202 if err != nil { 203 return fmt.Errorf("could not create table: %v", err) 204 } 205 206 return nil 207 } 208 209 func (sal *StorageAccount) isTableExists(table string) (bool, error) { 210 tablesPager := sal.svcClient.NewListTablesPager(nil) 211 for tablesPager.More() { 212 res, err := tablesPager.NextPage(context.Background()) 213 if err != nil { 214 return false, fmt.Errorf("could not retrieve the tables: %v", err) 215 } 216 217 for _, t := range res.Tables { 218 if *t.Name == table { 219 return true, nil 220 } 221 } 222 } 223 224 return false, nil 225 } 226 227 func getServiceURL(saName string) string { 228 return fmt.Sprintf(SERVICE_URL_FORMAT, saName) 229 } 230 231 func normalizeResourceName(resourceName string) string { 232 resourceName = strings.ReplaceAll(resourceName, "/", "-") 233 return strings.ReplaceAll(resourceName, "#", "-") 234 }