eintopf.info@v0.13.16/service/invitation/store_sql.go (about)

     1  // Copyright (C) 2022 The Eintopf authors
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <https://www.gnu.org/licenses/>.
    15  
    16  package invitation
    17  
    18  import (
    19  	"context"
    20  	"database/sql"
    21  	"fmt"
    22  	"strings"
    23  
    24  	"github.com/google/uuid"
    25  	"github.com/jmoiron/sqlx"
    26  )
    27  
    28  // NewSqlStore returns a new sql db invitation store.
    29  func NewSqlStore(db *sqlx.DB) *SqlStore {
    30  	return &SqlStore{db: db}
    31  }
    32  
    33  type SqlStore struct {
    34  	db *sqlx.DB
    35  }
    36  
    37  func (s *SqlStore) RunMigrations(ctx context.Context) error {
    38  	_, err := s.db.ExecContext(ctx, `
    39          CREATE TABLE IF NOT EXISTS invitations (
    40              id varchar(32) NOT NULL PRIMARY KEY UNIQUE,
    41              token varchar(128) NOT NULL,
    42              created_at timestamp,
    43              created_by VARCHAR(32),
    44              used_by VARCHAR(32)
    45          );
    46      `)
    47  	return err
    48  }
    49  
    50  func (s *SqlStore) Create(ctx context.Context, newInvitation *NewInvitation) (*Invitation, error) {
    51  	invitation := InvitationFromNewInvitation(newInvitation, uuid.New().String())
    52  	_, err := s.db.NamedExecContext(ctx, `
    53          INSERT INTO invitations (
    54              id,
    55              token,
    56              created_at,
    57              created_by,
    58              used_by
    59          ) VALUES (
    60              :id,
    61              :token,
    62              :created_at,
    63              :created_by,
    64              :used_by
    65          )
    66      `, invitation)
    67  	if err != nil {
    68  		return nil, err
    69  	}
    70  
    71  	return invitation, nil
    72  }
    73  
    74  func (s *SqlStore) Update(ctx context.Context, invitation *Invitation) (*Invitation, error) {
    75  	_, err := s.db.NamedExecContext(ctx, `
    76          UPDATE
    77              invitations
    78          SET
    79              token=:token,
    80              created_at=:created_at,
    81              created_by=:created_by,
    82              used_by=:used_by
    83          WHERE
    84              id=:id
    85      `, invitation)
    86  	if err != nil {
    87  		return nil, err
    88  	}
    89  
    90  	return invitation, err
    91  }
    92  
    93  func (s *SqlStore) Delete(ctx context.Context, id string) error {
    94  	_, err := s.db.ExecContext(ctx, "DELETE FROM invitations WHERE id = $1", id)
    95  	if err != nil {
    96  		return err
    97  	}
    98  
    99  	return nil
   100  }
   101  
   102  func (s *SqlStore) FindByID(ctx context.Context, id string) (*Invitation, error) {
   103  	invitation := &Invitation{}
   104  	err := s.db.GetContext(ctx, invitation, `
   105          SELECT *
   106          FROM invitations
   107          WHERE invitations.id = $1
   108      `, id)
   109  	if err == sql.ErrNoRows {
   110  		return nil, nil
   111  	}
   112  	if err != nil {
   113  		return nil, err
   114  	}
   115  
   116  	return invitation, nil
   117  }
   118  
   119  var sortableFields = map[string]string{
   120  	"id":        "id",
   121  	"token":     "token",
   122  	"createdAt": "created_at",
   123  	"createdBy": "created_by",
   124  	"usedBy":    "used_by",
   125  }
   126  
   127  func (s *SqlStore) Find(ctx context.Context, params *FindParams) ([]Invitation, int, error) {
   128  	query := `SELECT invitations.* FROM invitations`
   129  	whereStatements := []string{}
   130  	sqlParams := make(map[string]interface{})
   131  	if params != nil {
   132  		if params.Filters != nil {
   133  			if params.Filters.ID != nil {
   134  				whereStatements = append(whereStatements, "invitations.id=:id")
   135  				sqlParams["id"] = params.Filters.ID
   136  			}
   137  			if params.Filters.Token != nil {
   138  				whereStatements = append(whereStatements, "invitations.token=:token")
   139  				sqlParams["token"] = params.Filters.Token
   140  			}
   141  			if params.Filters.CreatedAt != nil {
   142  				whereStatements = append(whereStatements, "invitations.created_at=:created_at")
   143  				sqlParams["created_at"] = params.Filters.CreatedAt
   144  			}
   145  			if params.Filters.CreatedBy != nil {
   146  				whereStatements = append(whereStatements, "invitations.created_by=:created_by")
   147  				sqlParams["created_by"] = params.Filters.CreatedBy
   148  			}
   149  			if params.Filters.UsedBy != nil {
   150  				whereStatements = append(whereStatements, "invitations.used_by=:used_by")
   151  				sqlParams["used_by"] = params.Filters.UsedBy
   152  			}
   153  			if len(whereStatements) > 0 {
   154  				query += " WHERE " + strings.Join(whereStatements, " AND ")
   155  			}
   156  		}
   157  
   158  		if params.Sort != "" {
   159  			sort, ok := sortableFields[params.Sort]
   160  			if !ok {
   161  				return nil, 0, fmt.Errorf("find invitations: invalid sort field: %s", params.Sort)
   162  			}
   163  			order := "ASC"
   164  			if params.Order == "DESC" {
   165  				order = "DESC"
   166  			}
   167  			query += fmt.Sprintf(" ORDER BY %s %s", sort, order)
   168  		}
   169  		if params.Limit > 0 {
   170  			query += fmt.Sprintf(" LIMIT %d", params.Limit)
   171  		}
   172  		if params.Offset > 0 {
   173  			query += fmt.Sprintf(" OFFSET %d", params.Offset)
   174  		}
   175  	}
   176  	rows, err := s.db.NamedQueryContext(ctx, query, sqlParams)
   177  	if err != nil && err != sql.ErrNoRows {
   178  		return nil, 0, fmt.Errorf("find invitations: %s", err)
   179  	}
   180  	defer rows.Close()
   181  	invitations := make([]Invitation, 0)
   182  	for rows.Next() {
   183  		invitation := Invitation{}
   184  		rows.StructScan(&invitation)
   185  
   186  		invitations = append(invitations, invitation)
   187  	}
   188  
   189  	totalQuery := `SELECT COUNT(*) as total FROM invitations`
   190  	if len(whereStatements) > 0 {
   191  		totalQuery += " WHERE " + strings.Join(whereStatements, " AND ")
   192  	}
   193  	totalInvitations := struct {
   194  		Total int `db:"total"`
   195  	}{}
   196  	rows2, err := s.db.NamedQueryContext(ctx, totalQuery, sqlParams)
   197  	if err != nil && err != sql.ErrNoRows {
   198  		return nil, 0, fmt.Errorf("find invitations: %s", err)
   199  	}
   200  	defer rows2.Close()
   201  	if rows2.Next() {
   202  		rows2.StructScan(&totalInvitations)
   203  	}
   204  
   205  	return invitations, totalInvitations.Total, nil
   206  }