github.com/opentofu/opentofu@v1.7.1/internal/backend/remote-state/pg/client.go (about)

     1  // Copyright (c) The OpenTofu Authors
     2  // SPDX-License-Identifier: MPL-2.0
     3  // Copyright (c) 2023 HashiCorp, Inc.
     4  // SPDX-License-Identifier: MPL-2.0
     5  
     6  package pg
     7  
     8  import (
     9  	"crypto/md5"
    10  	"database/sql"
    11  	"fmt"
    12  
    13  	uuid "github.com/hashicorp/go-uuid"
    14  	_ "github.com/lib/pq"
    15  	"github.com/opentofu/opentofu/internal/states/remote"
    16  	"github.com/opentofu/opentofu/internal/states/statemgr"
    17  )
    18  
    19  // RemoteClient is a remote client that stores data in a Postgres database
    20  type RemoteClient struct {
    21  	Client     *sql.DB
    22  	Name       string
    23  	SchemaName string
    24  
    25  	info *statemgr.LockInfo
    26  }
    27  
    28  func (c *RemoteClient) Get() (*remote.Payload, error) {
    29  	query := `SELECT data FROM %s.%s WHERE name = $1`
    30  	row := c.Client.QueryRow(fmt.Sprintf(query, c.SchemaName, statesTableName), c.Name)
    31  	var data []byte
    32  	err := row.Scan(&data)
    33  	switch {
    34  	case err == sql.ErrNoRows:
    35  		// No existing state returns empty.
    36  		return nil, nil
    37  	case err != nil:
    38  		return nil, err
    39  	default:
    40  		md5 := md5.Sum(data)
    41  		return &remote.Payload{
    42  			Data: data,
    43  			MD5:  md5[:],
    44  		}, nil
    45  	}
    46  }
    47  
    48  func (c *RemoteClient) Put(data []byte) error {
    49  	query := `INSERT INTO %s.%s (name, data) VALUES ($1, $2)
    50  		ON CONFLICT (name) DO UPDATE
    51  		SET data = $2 WHERE %s.name = $1`
    52  	_, err := c.Client.Exec(fmt.Sprintf(query, c.SchemaName, statesTableName, statesTableName), c.Name, data)
    53  	if err != nil {
    54  		return err
    55  	}
    56  	return nil
    57  }
    58  
    59  func (c *RemoteClient) Delete() error {
    60  	query := `DELETE FROM %s.%s WHERE name = $1`
    61  	_, err := c.Client.Exec(fmt.Sprintf(query, c.SchemaName, statesTableName), c.Name)
    62  	if err != nil {
    63  		return err
    64  	}
    65  	return nil
    66  }
    67  
    68  func (c *RemoteClient) Lock(info *statemgr.LockInfo) (string, error) {
    69  	var err error
    70  	var lockID string
    71  
    72  	if info.ID == "" {
    73  		lockID, err = uuid.GenerateUUID()
    74  		if err != nil {
    75  			return "", err
    76  		}
    77  		info.ID = lockID
    78  	}
    79  
    80  	// Local helper function so we can call it multiple places
    81  	//
    82  	lockUnlock := func(pgLockId string) error {
    83  		query := `SELECT pg_advisory_unlock(%s)`
    84  		row := c.Client.QueryRow(fmt.Sprintf(query, pgLockId))
    85  		var didUnlock []byte
    86  		err := row.Scan(&didUnlock)
    87  		if err != nil {
    88  			return &statemgr.LockError{Info: info, Err: err}
    89  		}
    90  		return nil
    91  	}
    92  
    93  	// Try to acquire locks for the existing row `id` and the creation lock `-1`.
    94  	query := `SELECT %s.id, pg_try_advisory_lock(%s.id), pg_try_advisory_lock(-1) FROM %s.%s WHERE %s.name = $1`
    95  	row := c.Client.QueryRow(fmt.Sprintf(query, statesTableName, statesTableName, c.SchemaName, statesTableName, statesTableName), c.Name)
    96  	var pgLockId, didLock, didLockForCreate []byte
    97  	err = row.Scan(&pgLockId, &didLock, &didLockForCreate)
    98  	switch {
    99  	case err == sql.ErrNoRows:
   100  		// No rows means we're creating the workspace. Take the creation lock.
   101  		innerRow := c.Client.QueryRow(`SELECT pg_try_advisory_lock(-1)`)
   102  		var innerDidLock []byte
   103  		err := innerRow.Scan(&innerDidLock)
   104  		if err != nil {
   105  			return "", &statemgr.LockError{Info: info, Err: err}
   106  		}
   107  		if string(innerDidLock) == "false" {
   108  			return "", &statemgr.LockError{Info: info, Err: fmt.Errorf("Already locked for workspace creation: %s", c.Name)}
   109  		}
   110  		info.Path = "-1"
   111  	case err != nil:
   112  		return "", &statemgr.LockError{Info: info, Err: err}
   113  	case string(didLock) == "false":
   114  		// Existing workspace is already locked. Release the attempted creation lock.
   115  		lockUnlock("-1")
   116  		return "", &statemgr.LockError{Info: info, Err: fmt.Errorf("Workspace is already locked: %s", c.Name)}
   117  	case string(didLockForCreate) == "false":
   118  		// Someone has the creation lock already. Release the existing workspace because it might not be safe to touch.
   119  		lockUnlock(string(pgLockId))
   120  		return "", &statemgr.LockError{Info: info, Err: fmt.Errorf("Cannot lock workspace; already locked for workspace creation: %s", c.Name)}
   121  	default:
   122  		// Existing workspace is now locked. Release the attempted creation lock.
   123  		lockUnlock("-1")
   124  		info.Path = string(pgLockId)
   125  	}
   126  	c.info = info
   127  
   128  	return info.ID, nil
   129  }
   130  
   131  func (c *RemoteClient) getLockInfo() (*statemgr.LockInfo, error) {
   132  	return c.info, nil
   133  }
   134  
   135  func (c *RemoteClient) Unlock(id string) error {
   136  	if c.info != nil && c.info.Path != "" {
   137  		query := `SELECT pg_advisory_unlock(%s)`
   138  		row := c.Client.QueryRow(fmt.Sprintf(query, c.info.Path))
   139  		var didUnlock []byte
   140  		err := row.Scan(&didUnlock)
   141  		if err != nil {
   142  			return &statemgr.LockError{Info: c.info, Err: err}
   143  		}
   144  		c.info = nil
   145  	}
   146  	return nil
   147  }