github.com/pf-qiu/concourse/v6@v6.7.3-0.20201207032516-1f455d73275f/atc/db/user_factory.go (about)

     1  package db
     2  
     3  import (
     4  	"time"
     5  
     6  	sq "github.com/Masterminds/squirrel"
     7  )
     8  
     9  //go:generate counterfeiter . UserFactory
    10  
    11  type UserFactory interface {
    12  	CreateOrUpdateUser(username, connector, sub string) error
    13  	GetAllUsers() ([]User, error)
    14  	GetAllUsersByLoginDate(LastLogin time.Time) ([]User, error)
    15  }
    16  
    17  type userFactory struct {
    18  	conn Conn
    19  }
    20  
    21  func NewUserFactory(conn Conn) UserFactory {
    22  	return &userFactory{
    23  		conn: conn,
    24  	}
    25  }
    26  
    27  func (f *userFactory) CreateOrUpdateUser(username, connector, sub string) error {
    28  	tx, err := f.conn.Begin()
    29  
    30  	if err != nil {
    31  		return err
    32  	}
    33  	defer Rollback(tx)
    34  
    35  	builder := psql.Insert("users").
    36  		Columns("username", "connector", "sub").
    37  		Values(username, connector, sub)
    38  
    39  	_, err = builder.Suffix(`ON CONFLICT (sub) DO UPDATE SET
    40  					username = EXCLUDED.username,
    41  					connector = EXCLUDED.connector,
    42  					sub = EXCLUDED.sub,
    43  					last_login = now()`).
    44  		RunWith(tx).
    45  		Exec()
    46  	if err != nil {
    47  		return err
    48  	}
    49  
    50  	err = tx.Commit()
    51  	if err != nil {
    52  		return err
    53  	}
    54  
    55  	return nil
    56  }
    57  
    58  func (f *userFactory) GetAllUsers() ([]User, error) {
    59  	rows, err := psql.Select("id", "username", "connector", "last_login").
    60  		From("users").
    61  		RunWith(f.conn).
    62  		Query()
    63  
    64  	if err != nil {
    65  		return nil, err
    66  	}
    67  
    68  	defer Close(rows)
    69  
    70  	var users []User
    71  
    72  	for rows.Next() {
    73  		var currUser user
    74  		err = rows.Scan(&currUser.id, &currUser.name, &currUser.connector, &currUser.lastLogin)
    75  
    76  		if err != nil {
    77  			return nil, err
    78  		}
    79  
    80  		users = append(users, currUser)
    81  	}
    82  	return users, nil
    83  }
    84  
    85  func (f *userFactory) GetAllUsersByLoginDate(lastLogin time.Time) ([]User, error) {
    86  	rows, err := psql.Select("id", "username", "connector", "last_login").
    87  		From("users").
    88  		Where(sq.GtOrEq{"last_login": lastLogin}).
    89  		RunWith(f.conn).
    90  		Query()
    91  
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  
    96  	defer Close(rows)
    97  
    98  	var users []User
    99  
   100  	for rows.Next() {
   101  		var currUser user
   102  		err = rows.Scan(&currUser.id, &currUser.name, &currUser.connector, &currUser.lastLogin)
   103  
   104  		if err != nil {
   105  			return nil, err
   106  		}
   107  
   108  		users = append(users, currUser)
   109  	}
   110  	return users, nil
   111  }