github.com/diggerhq/digger/libs@v0.0.0-20240604170430-9d61cdf01cc5/locking/locking.go (about)

     1  package locking
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"github.com/diggerhq/digger/libs/comment_utils/utils"
     8  	"github.com/diggerhq/digger/libs/locking/aws"
     9  	"github.com/diggerhq/digger/libs/locking/azure"
    10  	"github.com/diggerhq/digger/libs/locking/gcp"
    11  	"log"
    12  	"os"
    13  	"strconv"
    14  	"strings"
    15  
    16  	"github.com/diggerhq/digger/libs/comment_utils/reporting"
    17  	"github.com/diggerhq/digger/libs/locking/aws/envprovider"
    18  	"github.com/diggerhq/digger/libs/orchestrator"
    19  
    20  	"cloud.google.com/go/storage"
    21  	awssdk "github.com/aws/aws-sdk-go-v2/aws"
    22  	"github.com/aws/aws-sdk-go-v2/config"
    23  	"github.com/aws/aws-sdk-go-v2/service/dynamodb"
    24  	"github.com/aws/aws-sdk-go-v2/service/sts"
    25  )
    26  
// PullRequestLock binds a project's lock to a specific pull request.
// The backend lock (InternalLock) records the owning PR number as its
// transaction id, and outcomes are published as comments through Reporter.
type PullRequestLock struct {
	// InternalLock is the backend lock; PrNumber is passed to it as the transaction id.
	InternalLock     Lock
	// CIService is queried to find out whether the PR currently holding the lock is closed.
	CIService        orchestrator.PullRequestService
	// Reporter publishes locking/unlocking result comments.
	Reporter         reporting.Reporter
	// ProjectName and ProjectNamespace together form the lock id ("<namespace>#<name>").
	ProjectName      string
	ProjectNamespace string
	// PrNumber is the pull request on whose behalf the lock is acquired or released.
	PrNumber         int
}
    35  
    36  type NoOpLock struct {
    37  }
    38  
    39  func (noOpLock NoOpLock) Lock(transactionId int, resource string) (bool, error) {
    40  	return true, nil
    41  }
    42  
    43  func (noOpLock NoOpLock) Unlock(resource string) (bool, error) {
    44  	return true, nil
    45  }
    46  
    47  func (noOpLock NoOpLock) GetLock(resource string) (*int, error) {
    48  	return nil, nil
    49  }
    50  
    51  func (projectLock *PullRequestLock) Lock() (bool, error) {
    52  	lockId := projectLock.LockId()
    53  	log.Printf("Lock %s\n", lockId)
    54  
    55  	noHangingLocks, err := projectLock.verifyNoHangingLocks()
    56  
    57  	if err != nil {
    58  		return false, err
    59  	}
    60  
    61  	if !noHangingLocks {
    62  		return false, nil
    63  	}
    64  
    65  	existingLockTransactionId, err := projectLock.InternalLock.GetLock(lockId)
    66  	if err != nil {
    67  		log.Printf("failed to get lock: %v\n", err)
    68  		return false, err
    69  	}
    70  	if existingLockTransactionId != nil {
    71  		if *existingLockTransactionId == projectLock.PrNumber {
    72  			return true, nil
    73  		} else {
    74  			transactionIdStr := strconv.Itoa(*existingLockTransactionId)
    75  			comment := "Project " + projectLock.projectId() + " locked by another PR #" + transactionIdStr + " (failed to acquire lock " + projectLock.ProjectNamespace + "). The locking plan must be applied or discarded before future plans can execute"
    76  
    77  			reportLockingFailed(projectLock.Reporter, comment)
    78  			return false, fmt.Errorf(comment)
    79  		}
    80  	}
    81  	lockAcquired, err := projectLock.InternalLock.Lock(projectLock.PrNumber, lockId)
    82  	if err != nil {
    83  		return false, err
    84  	}
    85  
    86  	_, isNoOpLock := projectLock.InternalLock.(*NoOpLock)
    87  
    88  	if lockAcquired && !isNoOpLock {
    89  		comment := "Project " + projectLock.projectId() + " has been locked by PR #" + strconv.Itoa(projectLock.PrNumber)
    90  		reportingLockingSuccess(projectLock.Reporter, comment)
    91  		log.Println("project " + projectLock.projectId() + " locked successfully. PR # " + strconv.Itoa(projectLock.PrNumber))
    92  
    93  	}
    94  	return lockAcquired, nil
    95  }
    96  
    97  func reportingLockingSuccess(r reporting.Reporter, comment string) {
    98  	if r.SupportsMarkdown() {
    99  		_, _, err := r.Report(comment, utils.AsCollapsibleComment("Locking successful", false))
   100  		if err != nil {
   101  			log.Println("failed to publish comment: " + err.Error())
   102  		}
   103  	} else {
   104  		_, _, err := r.Report(comment, utils.AsComment("Locking successful"))
   105  		if err != nil {
   106  			log.Println("failed to publish comment: " + err.Error())
   107  		}
   108  	}
   109  }
   110  
   111  func reportLockingFailed(r reporting.Reporter, comment string) {
   112  	if r.SupportsMarkdown() {
   113  		_, _, err := r.Report(comment, utils.AsCollapsibleComment("Locking failed", false))
   114  		if err != nil {
   115  			log.Println("failed to publish comment: " + err.Error())
   116  		}
   117  	} else {
   118  		_, _, err := r.Report(comment, utils.AsComment("Locking failed"))
   119  		if err != nil {
   120  			log.Println("failed to publish comment: " + err.Error())
   121  		}
   122  	}
   123  }
   124  
   125  func (projectLock *PullRequestLock) verifyNoHangingLocks() (bool, error) {
   126  	lockId := projectLock.LockId()
   127  	transactionId, err := projectLock.InternalLock.GetLock(lockId)
   128  
   129  	if err != nil {
   130  		return false, err
   131  	}
   132  
   133  	if transactionId != nil {
   134  		if *transactionId != projectLock.PrNumber {
   135  			isPrClosed, err := projectLock.CIService.IsClosed(*transactionId)
   136  			if err != nil {
   137  				return false, fmt.Errorf("failed to check if PR holding a lock is closed: %w", err)
   138  			}
   139  			if isPrClosed {
   140  				_, err := projectLock.InternalLock.Unlock(lockId)
   141  				if err != nil {
   142  					return false, fmt.Errorf("failed to unlock a lock acquired by closed PR %v: %w", transactionId, err)
   143  				}
   144  				return true, nil
   145  			}
   146  			transactionIdStr := strconv.Itoa(*transactionId)
   147  			comment := "Project " + projectLock.projectId() + " locked by another PR #" + transactionIdStr + "(failed to acquire lock " + projectLock.ProjectName + "). The locking plan must be applied or discarded before future plans can execute"
   148  			reportLockingFailed(projectLock.Reporter, comment)
   149  			return false, fmt.Errorf(comment)
   150  		}
   151  		return true, nil
   152  	}
   153  	return true, nil
   154  }
   155  
   156  func (projectLock *PullRequestLock) Unlock() (bool, error) {
   157  	lockId := projectLock.LockId()
   158  	log.Printf("Unlock %s\n", lockId)
   159  	lock, err := projectLock.InternalLock.GetLock(lockId)
   160  	if err != nil {
   161  		return false, err
   162  	}
   163  
   164  	if lock != nil {
   165  		transactionId := *lock
   166  		if projectLock.PrNumber == transactionId {
   167  			lockReleased, err := projectLock.InternalLock.Unlock(lockId)
   168  			if err != nil {
   169  				return false, err
   170  			}
   171  			if lockReleased {
   172  				comment := "Project unlocked (" + projectLock.projectId() + ")."
   173  				reportSuccessfulUnlocking(projectLock.Reporter, comment)
   174  
   175  				log.Println("Project unlocked")
   176  				return true, nil
   177  			}
   178  		}
   179  	}
   180  	return false, nil
   181  }
   182  
   183  func reportSuccessfulUnlocking(r reporting.Reporter, comment string) {
   184  	if r.SupportsMarkdown() {
   185  		_, _, err := r.Report(comment, utils.AsCollapsibleComment("Unlocking successful", false))
   186  		if err != nil {
   187  			log.Println("failed to publish comment: " + err.Error())
   188  		}
   189  	} else {
   190  		_, _, err := r.Report(comment, utils.AsComment("Unlocking successful"))
   191  		if err != nil {
   192  			log.Println("failed to publish comment: " + err.Error())
   193  		}
   194  	}
   195  }
   196  
   197  func (projectLock *PullRequestLock) ForceUnlock() error {
   198  	lockId := projectLock.LockId()
   199  	log.Printf("ForceUnlock %s\n", lockId)
   200  	lock, err := projectLock.InternalLock.GetLock(lockId)
   201  	if err != nil {
   202  		return err
   203  	}
   204  	if lock != nil {
   205  		lockReleased, err := projectLock.InternalLock.Unlock(lockId)
   206  		if err != nil {
   207  			return err
   208  		}
   209  
   210  		if lockReleased {
   211  			comment := "Project unlocked (" + projectLock.projectId() + ")."
   212  			reportSuccessfulUnlocking(projectLock.Reporter, comment)
   213  			log.Println("Project unlocked")
   214  		}
   215  		return nil
   216  	}
   217  	return nil
   218  }
   219  
   220  func (projectLock *PullRequestLock) projectId() string {
   221  	return projectLock.ProjectNamespace + "#" + projectLock.ProjectName
   222  }
   223  
   224  func (projectLock *PullRequestLock) LockId() string {
   225  	return projectLock.ProjectNamespace + "#" + projectLock.ProjectName
   226  }
   227  
   228  // DoEnvVarsExist return true if any of env vars do exist, false otherwise
   229  func DoEnvVarsExist(envVars []string) bool {
   230  	result := false
   231  	for _, key := range envVars {
   232  		value := os.Getenv(key)
   233  		if value != "" {
   234  			result = true
   235  		}
   236  	}
   237  	return result
   238  }
   239  
   240  func GetLock() (Lock, error) {
   241  	awsRegion := strings.ToLower(os.Getenv("AWS_REGION"))
   242  	awsProfile := strings.ToLower(os.Getenv("AWS_PROFILE"))
   243  	lockProvider := strings.ToLower(os.Getenv("LOCK_PROVIDER"))
   244  	disableLocking := strings.ToLower(os.Getenv("DISABLE_LOCKING")) == "true"
   245  
   246  	if disableLocking {
   247  		log.Println("Using NoOp lock provider.")
   248  		return &NoOpLock{}, nil
   249  	}
   250  	if lockProvider == "" || lockProvider == "aws" {
   251  		log.Println("Using AWS lock provider.")
   252  
   253  		// https://aws.github.io/aws-sdk-go-v2/docs/configuring-sdk/
   254  		// https://aws.github.io/aws-sdk-go-v2/docs/migrating/
   255  		keysToCheck := []string{"DIGGER_AWS_ACCESS_KEY_ID", "AWS_ACCESS_KEY_ID", "AWS_ACCESS_KEY"}
   256  		awsCredsProvided := DoEnvVarsExist(keysToCheck)
   257  
   258  		var cfg awssdk.Config
   259  		var err error
   260  		if awsCredsProvided {
   261  			cfg, err = config.LoadDefaultConfig(context.Background(),
   262  				config.WithSharedConfigProfile(awsProfile),
   263  				config.WithRegion(awsRegion),
   264  				config.WithCredentialsProvider(&envprovider.EnvProvider{}))
   265  			if err != nil {
   266  				return nil, err
   267  			}
   268  		} else {
   269  			log.Printf("Using keyless aws digger_config\n")
   270  			cfg, err = config.LoadDefaultConfig(context.Background(), config.WithRegion(awsRegion))
   271  			if err != nil {
   272  				return nil, err
   273  			}
   274  		}
   275  
   276  		stsClient := sts.NewFromConfig(cfg)
   277  		result, err := stsClient.GetCallerIdentity(context.Background(), &sts.GetCallerIdentityInput{})
   278  		if err != nil {
   279  			return nil, fmt.Errorf("failed to connect to AWS account. %v", err)
   280  		}
   281  		log.Printf("Successfully connected to AWS account %s, user Id: %s\n", *result.Account, *result.UserId)
   282  
   283  		dynamoDb := dynamodb.NewFromConfig(cfg)
   284  		dynamoDbLock := aws.DynamoDbLock{DynamoDb: dynamoDb}
   285  		return &dynamoDbLock, nil
   286  	} else if lockProvider == "gcp" {
   287  		log.Println("Using GCP lock provider.")
   288  		ctx, client := gcp.GetGoogleStorageClient()
   289  		defer func(client *storage.Client) {
   290  			err := client.Close()
   291  			if err != nil {
   292  				log.Fatalf("Failed to close Google Storage client: %v", err)
   293  			}
   294  		}(client)
   295  
   296  		bucketName := strings.ToLower(os.Getenv("GOOGLE_STORAGE_LOCK_BUCKET"))
   297  		if bucketName == "" {
   298  			return nil, errors.New("GOOGLE_STORAGE_LOCK_BUCKET is not set")
   299  		}
   300  		bucket := client.Bucket(bucketName)
   301  		lock := gcp.GoogleStorageLock{Client: client, Bucket: bucket, Context: ctx}
   302  		return &lock, nil
   303  	} else if lockProvider == "azure" {
   304  		return azure.NewStorageAccountLock()
   305  	}
   306  
   307  	return nil, errors.New("failed to find lock provider")
   308  }