eintopf.info@v0.13.16/service/group/store_sql.go (about)

     1  // Copyright (C) 2022 The Eintopf authors
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <https://www.gnu.org/licenses/>.
    15  
    16  package group
    17  
    18  import (
    19  	"context"
    20  	"database/sql"
    21  	"fmt"
    22  	"strings"
    23  
    24  	"eintopf.info/internal/crud"
    25  	"eintopf.info/service/dbmigration"
    26  	"github.com/google/uuid"
    27  	"github.com/jmoiron/sqlx"
    28  )
    29  
    30  // NewSqlStore returns a new sql db group store.
    31  func NewSqlStore(db *sqlx.DB, migrationService dbmigration.Service) (*SqlStore, error) {
    32  	store := &SqlStore{db: db, migrationService: migrationService}
    33  	if err := store.runMigrations(context.Background()); err != nil {
    34  		return nil, err
    35  	}
    36  	return store, nil
    37  }
    38  
    39  type SqlStore struct {
    40  	db               *sqlx.DB
    41  	migrationService dbmigration.Service
    42  }
    43  
    44  func (s *SqlStore) runMigrations(ctx context.Context) error {
    45  	return s.migrationService.RunMigrations(ctx, []dbmigration.Migration{
    46  		dbmigration.NewMigration("createGroupsTable", s.createGroupsTable, nil),
    47  		dbmigration.NewMigration("createGroupsOwnedByTable", s.createGroupsOwnedByTable, nil),
    48  	})
    49  }
    50  
    51  func (s *SqlStore) createGroupsTable(ctx context.Context) error {
    52  	_, err := s.db.ExecContext(ctx, `
    53          CREATE TABLE IF NOT EXISTS groups (
    54              id varchar(32) NOT NULL PRIMARY KEY UNIQUE,
    55              deactivated boolean DEFAULT FALSE,
    56              published boolean DEFAULT FALSE,
    57              name varchar(64) NOT NULL UNIQUE,
    58              link varchar(64) DEFAULT "",
    59              email varchar(64) DEFAULT "",
    60              description varchar(512),
    61              image varchar(128)
    62          );
    63      `)
    64  	return err
    65  }
    66  
    67  func (s *SqlStore) createGroupsOwnedByTable(ctx context.Context) error {
    68  	_, err := s.db.ExecContext(ctx, `
    69          CREATE TABLE IF NOT EXISTS groups_owned_by (
    70              ID INTEGER PRIMARY KEY AUTOINCREMENT,
    71              group_id varchar(32),
    72              user_id varchar(32)
    73          );
    74      `)
    75  	return err
    76  }
    77  
    78  func (s *SqlStore) Create(ctx context.Context, newGroup *NewGroup) (*Group, error) {
    79  	group := GroupFromNewGroup(newGroup, uuid.New().String())
    80  	_, err := s.db.NamedExecContext(ctx, `
    81          INSERT INTO groups (
    82              id,
    83              published,
    84              deactivated,
    85              name,
    86              link,
    87              email,
    88              description,
    89              image
    90          ) VALUES (
    91              :id,
    92              :published,
    93              :deactivated,
    94              :name,
    95              :link,
    96              :email,
    97              :description,
    98              :image
    99          )
   100      `, group)
   101  	if err != nil {
   102  		return nil, err
   103  	}
   104  
   105  	err = s.insertOwnedByForGroup(ctx, group)
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  
   110  	return group, nil
   111  }
   112  
   113  func (s *SqlStore) insertOwnedByForGroup(ctx context.Context, group *Group) error {
   114  	for _, ownedBy := range group.OwnedBy {
   115  		_, err := s.db.ExecContext(ctx, `
   116              INSERT INTO groups_owned_by (
   117                  group_id,
   118                  user_id
   119              ) VALUES (
   120                  $1,
   121                  $2
   122              )
   123          `, group.ID, ownedBy)
   124  		if err != nil {
   125  			return err
   126  		}
   127  	}
   128  	return nil
   129  }
   130  
   131  func (s *SqlStore) Update(ctx context.Context, group *Group) (*Group, error) {
   132  	_, err := s.db.NamedExecContext(ctx, `
   133          UPDATE
   134              groups
   135          SET
   136              published=:published,
   137              deactivated=:deactivated,
   138              name=:name,
   139              link=:link,
   140              email=:email,
   141              description=:description,
   142              image=:image
   143          WHERE
   144              id=:id
   145      `, group)
   146  	if err != nil {
   147  		return nil, err
   148  	}
   149  
   150  	err = s.deleteOwnedByForGroup(ctx, group.ID)
   151  	if err != nil {
   152  		return nil, err
   153  	}
   154  	err = s.insertOwnedByForGroup(ctx, group)
   155  	if err != nil {
   156  		return nil, err
   157  	}
   158  
   159  	return group, err
   160  }
   161  
   162  func (s *SqlStore) Delete(ctx context.Context, id string) error {
   163  	_, err := s.db.ExecContext(ctx, "DELETE FROM groups WHERE id = $1", id)
   164  	if err != nil {
   165  		return err
   166  	}
   167  
   168  	err = s.deleteOwnedByForGroup(ctx, id)
   169  	if err != nil {
   170  		return err
   171  	}
   172  
   173  	return nil
   174  }
   175  
   176  func (s *SqlStore) deleteOwnedByForGroup(ctx context.Context, id string) error {
   177  	_, err := s.db.ExecContext(ctx, "DELETE FROM groups_owned_by WHERE group_id = $1", id)
   178  	if err != nil {
   179  		return err
   180  	}
   181  	return nil
   182  }
   183  
   184  func (s *SqlStore) FindByID(ctx context.Context, id string) (*Group, error) {
   185  	group := &Group{}
   186  	err := s.db.GetContext(ctx, group, `
   187          SELECT *
   188          FROM groups
   189          WHERE groups.id = $1
   190      `, id)
   191  	if err == sql.ErrNoRows {
   192  		return nil, nil
   193  	}
   194  	if err != nil {
   195  		return nil, err
   196  	}
   197  
   198  	s.findOwnedByForGroup(ctx, group)
   199  	if err != nil {
   200  		return nil, err
   201  	}
   202  
   203  	return group, nil
   204  }
   205  
   206  func (s *SqlStore) findOwnedByForGroup(ctx context.Context, group *Group) error {
   207  	ownedBy := []string{}
   208  	err := s.db.SelectContext(ctx, &ownedBy, `
   209          SELECT user_id
   210          FROM groups_owned_by
   211          WHERE group_id = $1
   212      `, group.ID)
   213  	if err == sql.ErrNoRows {
   214  		return nil
   215  	}
   216  	if err != nil {
   217  		return err
   218  	}
   219  	group.OwnedBy = ownedBy
   220  	return nil
   221  }
   222  
   223  var sortableFields = map[string]string{
   224  	"id":          "id",
   225  	"deactivated": "deactivated",
   226  	"published":   "published",
   227  	"name":        "name",
   228  	"link":        "link",
   229  	"email":       "email",
   230  }
   231  
   232  func (s *SqlStore) Find(ctx context.Context, params *crud.FindParams[FindFilters]) ([]*Group, int, error) {
   233  	query := `
   234          SELECT groups.*
   235          FROM groups
   236      `
   237  	joinQuery := ""
   238  	whereStatements := []string{}
   239  	sqlParams := make(map[string]interface{})
   240  	if params != nil {
   241  		if params.Filters != nil {
   242  			if params.Filters.ID != nil {
   243  				whereStatements = append(whereStatements, "groups.id=:id")
   244  				sqlParams["id"] = params.Filters.ID
   245  			}
   246  			if params.Filters.NotID != nil {
   247  				whereStatements = append(whereStatements, "groups.id!=:notID")
   248  				sqlParams["notID"] = params.Filters.NotID
   249  			}
   250  			if params.Filters.Deactivated != nil {
   251  				whereStatements = append(whereStatements, "groups.deactivated=:deactivated")
   252  				sqlParams["deactivated"] = params.Filters.Deactivated
   253  			}
   254  			if params.Filters.Published != nil {
   255  				whereStatements = append(whereStatements, "groups.published=:published")
   256  				sqlParams["published"] = params.Filters.Published
   257  			}
   258  			if params.Filters.Name != nil {
   259  				whereStatements = append(whereStatements, "groups.name=:name")
   260  				sqlParams["name"] = params.Filters.Name
   261  			}
   262  			if params.Filters.LikeName != nil {
   263  				whereStatements = append(whereStatements, "groups.name LIKE :likeName")
   264  				sqlParams["likeName"] = fmt.Sprintf("%%%s%%", *params.Filters.LikeName)
   265  			}
   266  			if params.Filters.Link != nil {
   267  				whereStatements = append(whereStatements, "groups.link=:link")
   268  				sqlParams["link"] = params.Filters.Link
   269  			}
   270  			if params.Filters.Email != nil {
   271  				whereStatements = append(whereStatements, "groups.email=:email")
   272  				sqlParams["email"] = params.Filters.Email
   273  			}
   274  			if params.Filters.Description != nil {
   275  				whereStatements = append(whereStatements, "groups.description=:description")
   276  				sqlParams["Description"] = params.Filters.Description
   277  			}
   278  			if params.Filters.OwnedBy != nil {
   279  				joinQuery += " JOIN groups_owned_by ON groups.id = groups_owned_by.group_id"
   280  				ownedByStatements := make([]string, len(params.Filters.OwnedBy))
   281  				for i, ownedBy := range params.Filters.OwnedBy {
   282  					ownedByRef := fmt.Sprintf("ownedBy%d", i)
   283  					ownedByStatements[i] = fmt.Sprintf("groups_owned_by.user_id=:%s", ownedByRef)
   284  					sqlParams[ownedByRef] = ownedBy
   285  				}
   286  				whereStatements = append(whereStatements, fmt.Sprintf("(%s)", strings.Join(ownedByStatements, " OR ")))
   287  			}
   288  
   289  			if joinQuery != "" {
   290  				query += " " + joinQuery
   291  			}
   292  			if len(whereStatements) > 0 {
   293  				query += " WHERE " + strings.Join(whereStatements, " AND ")
   294  			}
   295  			if joinQuery != "" {
   296  				query += " GROUP BY groups.id"
   297  			}
   298  		}
   299  
   300  		if params.Sort != "" {
   301  			sort, ok := sortableFields[params.Sort]
   302  			if !ok {
   303  				return nil, 0, fmt.Errorf("find groups: invalid sort field: %s", params.Sort)
   304  			}
   305  			order := "ASC"
   306  			if params.Order == "DESC" {
   307  				order = "DESC"
   308  			}
   309  			query += fmt.Sprintf(" ORDER BY %s %s", sort, order)
   310  		}
   311  		if params.Limit > 0 {
   312  			query += fmt.Sprintf(" LIMIT %d", params.Limit)
   313  		}
   314  		if params.Offset > 0 {
   315  			query += fmt.Sprintf(" OFFSET %d", params.Offset)
   316  		}
   317  	}
   318  	rows, err := s.db.NamedQueryContext(ctx, query, sqlParams)
   319  	if err != nil && err != sql.ErrNoRows {
   320  		return nil, 0, fmt.Errorf("find groups: %s", params.Sort)
   321  	}
   322  	defer rows.Close()
   323  	groups := make([]*Group, 0)
   324  	for rows.Next() {
   325  		group := &Group{}
   326  		rows.StructScan(group)
   327  
   328  		groups = append(groups, group)
   329  	}
   330  
   331  	for i, group := range groups {
   332  		err = s.findOwnedByForGroup(ctx, group)
   333  		if err != nil {
   334  			return nil, 0, err
   335  		}
   336  		groups[i] = group
   337  	}
   338  
   339  	totalQuery := `
   340          SELECT COUNT(*) as total
   341          FROM groups
   342      `
   343  	if len(whereStatements) > 0 {
   344  		if joinQuery != "" {
   345  			totalQuery += " " + joinQuery
   346  		}
   347  		totalQuery += " WHERE " + strings.Join(whereStatements, " AND ")
   348  		if joinQuery != "" {
   349  			totalQuery += " GROUP BY groups.id"
   350  		}
   351  	}
   352  	totalGroups := struct {
   353  		Total int `db:"total"`
   354  	}{}
   355  	rows2, err := s.db.NamedQueryContext(ctx, totalQuery, sqlParams)
   356  	if err != nil && err != sql.ErrNoRows {
   357  		return nil, 0, fmt.Errorf("find groups: total: %s", params.Sort)
   358  	}
   359  	defer rows2.Close()
   360  	if rows2.Next() {
   361  		rows2.StructScan(&totalGroups)
   362  	}
   363  
   364  	return groups, totalGroups.Total, nil
   365  }