github.com/pf-qiu/concourse/v6@v6.7.3-0.20201207032516-1f455d73275f/atc/db/wall.go (about)

     1  package db
     2  
     3  import (
     4  	"database/sql"
     5  
     6  	"github.com/pf-qiu/concourse/v6/atc"
     7  
     8  	sq "github.com/Masterminds/squirrel"
     9  )
    10  
    11  //go:generate counterfeiter . Wall
    12  
    13  type Wall interface {
    14  	SetWall(atc.Wall) error
    15  	GetWall() (atc.Wall, error)
    16  	Clear() error
    17  }
    18  
    19  type wall struct {
    20  	conn  Conn
    21  	clock Clock
    22  }
    23  
    24  func NewWall(conn Conn, clock Clock) Wall {
    25  	return &wall{
    26  		conn: conn,
    27  		clock: clock,
    28  	}
    29  }
    30  
    31  func (w wall) SetWall(wall atc.Wall) error {
    32  	tx, err := w.conn.Begin()
    33  	if err != nil {
    34  		return err
    35  	}
    36  
    37  	defer Rollback(tx)
    38  
    39  	_, err = psql.Delete("wall").RunWith(tx).Exec()
    40  	if err != nil {
    41  		return err
    42  	}
    43  
    44  	query := psql.Insert("wall").
    45  		Columns("message")
    46  
    47  	if wall.TTL != 0 {
    48  		expiresAt := w.clock.Now().Add(wall.TTL)
    49  		query = query.Columns("expires_at").Values(wall.Message, expiresAt)
    50  	} else {
    51  		query = query.Values(wall.Message)
    52  	}
    53  
    54  	_, err = query.RunWith(tx).Exec()
    55  	if err != nil {
    56  		return err
    57  	}
    58  
    59  	err = tx.Commit()
    60  	if err != nil {
    61  		return err
    62  	}
    63  
    64  	return nil
    65  }
    66  
    67  func (w wall) GetWall() (atc.Wall, error) {
    68  	var wall atc.Wall
    69  
    70  	row := psql.Select("message", "expires_at").
    71  		From("wall").
    72  		Where(sq.Or{
    73  			sq.Gt{"expires_at": w.clock.Now()},
    74  			sq.Eq{"expires_at": nil},
    75  		}).
    76  		RunWith(w.conn).QueryRow()
    77  
    78  	err := w.scanWall(&wall, row)
    79  	if err != nil && err != sql.ErrNoRows {
    80  		return atc.Wall{}, err
    81  	}
    82  
    83  	return wall, nil
    84  }
    85  
    86  func (w *wall) scanWall(wall *atc.Wall, scan scannable) error {
    87  	var expiresAt sql.NullTime
    88  
    89  	err := scan.Scan(&wall.Message, &expiresAt)
    90  	if err != nil {
    91  		return err
    92  	}
    93  
    94  	if expiresAt.Valid {
    95  		wall.TTL = w.clock.Until(expiresAt.Time)
    96  	}
    97  
    98  	return nil
    99  }
   100  
   101  func (w wall) Clear() error {
   102  	_, err := psql.Delete("wall").RunWith(w.conn).Exec()
   103  	if err != nil {
   104  		return err
   105  	}
   106  	return nil
   107  }