eintopf.info@v0.13.16/service/place/store_sql.go (about)

     1  // Copyright (C) 2022 The Eintopf authors
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <https://www.gnu.org/licenses/>.
    15  
    16  package place
    17  
    18  import (
    19  	"context"
    20  	"database/sql"
    21  	"fmt"
    22  	"strings"
    23  
    24  	"eintopf.info/internal/crud"
    25  	"eintopf.info/service/dbmigration"
    26  	"github.com/google/uuid"
    27  	"github.com/jmoiron/sqlx"
    28  )
    29  
    30  // NewSqlStore returns a new sql db place store.
    31  func NewSqlStore(db *sqlx.DB, migrationService dbmigration.Service) (*SqlStore, error) {
    32  	store := &SqlStore{db: db, migrationService: migrationService}
    33  	if err := store.runMigrations(context.Background()); err != nil {
    34  		return nil, err
    35  	}
    36  	return store, nil
    37  }
    38  
    39  type SqlStore struct {
    40  	db               *sqlx.DB
    41  	migrationService dbmigration.Service
    42  }
    43  
    44  func (s *SqlStore) runMigrations(ctx context.Context) error {
    45  	return s.migrationService.RunMigrations(ctx, []dbmigration.Migration{
    46  		dbmigration.NewMigration("createPlacesTable", s.createPlacesTable, nil),
    47  		dbmigration.NewMigration("createPlacesOwnedByTable", s.createPlacesOwnedByTable, nil),
    48  	})
    49  }
    50  
    51  func (s *SqlStore) createPlacesTable(ctx context.Context) error {
    52  	_, err := s.db.ExecContext(ctx, `
    53          CREATE TABLE IF NOT EXISTS places (
    54              id varchar(32) NOT NULL PRIMARY KEY UNIQUE,
    55              deactivated boolean DEFAULT FALSE,
    56              published boolean DEFAULT FALSE,
    57              name varchar(64) NOT NULL UNIQUE,
    58              description varchar(512),
    59              link varchar(128),
    60              email varchar(128),
    61              image varchar(128),
    62              address varchar(128),
    63              lat FLOAT(24),
    64              lng FLOAT(24)
    65          );
    66      `)
    67  	return err
    68  }
    69  
    70  func (s *SqlStore) createPlacesOwnedByTable(ctx context.Context) error {
    71  	_, err := s.db.ExecContext(ctx, `
    72          CREATE TABLE IF NOT EXISTS places_owned_by (
    73              ID INTEGER PRIMARY KEY AUTOINCREMENT,
    74              place_id varchar(32),
    75              user_id varchar(32)
    76          );
    77      `)
    78  	return err
    79  }
    80  
    81  func (s *SqlStore) Create(ctx context.Context, newPlace *NewPlace) (*Place, error) {
    82  	place := PlaceFromNewPlace(newPlace, uuid.New().String())
    83  	_, err := s.db.NamedExecContext(ctx, `
    84          INSERT INTO places (
    85              id,
    86              published,
    87              deactivated,
    88              name,
    89              description,
    90              email,
    91              image,
    92              link,
    93              address,
    94              lat,
    95              lng
    96          ) VALUES (
    97              :id,
    98              :published,
    99              :deactivated,
   100              :name,
   101              :description,
   102              :email,
   103              :image,
   104              :link,
   105              :address,
   106              :lat,
   107              :lng
   108          )
   109      `, place)
   110  	if err != nil {
   111  		return nil, err
   112  	}
   113  
   114  	err = s.insertOwnedByForPlace(ctx, place)
   115  	if err != nil {
   116  		return nil, err
   117  	}
   118  
   119  	return place, nil
   120  }
   121  
   122  func (s *SqlStore) insertOwnedByForPlace(ctx context.Context, place *Place) error {
   123  	for _, ownedBy := range place.OwnedBy {
   124  		_, err := s.db.ExecContext(ctx, `
   125              INSERT INTO places_owned_by (
   126                  place_id,
   127                  user_id
   128              ) VALUES (
   129                  $1,
   130                  $2
   131              )
   132          `, place.ID, ownedBy)
   133  		if err != nil {
   134  			return err
   135  		}
   136  	}
   137  	return nil
   138  }
   139  
   140  func (s *SqlStore) Update(ctx context.Context, place *Place) (*Place, error) {
   141  	_, err := s.db.NamedExecContext(ctx, `
   142          UPDATE
   143              places
   144          SET
   145              published=:published,
   146              deactivated=:deactivated,
   147              name=:name,
   148              description=:description,
   149              email=:email,
   150              image=:image,
   151              link=:link,
   152              address=:address,
   153              lat=:lat,
   154              lng=:lng
   155          WHERE
   156              id=:id
   157      `, place)
   158  	if err != nil {
   159  		return nil, err
   160  	}
   161  
   162  	err = s.deleteOwnedByForPlace(ctx, place.ID)
   163  	if err != nil {
   164  		return nil, err
   165  	}
   166  	err = s.insertOwnedByForPlace(ctx, place)
   167  	if err != nil {
   168  		return nil, err
   169  	}
   170  
   171  	return place, err
   172  }
   173  
   174  func (s *SqlStore) Delete(ctx context.Context, id string) error {
   175  	_, err := s.db.ExecContext(ctx, "DELETE FROM places WHERE id = $1", id)
   176  	if err != nil {
   177  		return err
   178  	}
   179  
   180  	err = s.deleteOwnedByForPlace(ctx, id)
   181  	if err != nil {
   182  		return err
   183  	}
   184  
   185  	return nil
   186  }
   187  
   188  func (s *SqlStore) deleteOwnedByForPlace(ctx context.Context, id string) error {
   189  	_, err := s.db.ExecContext(ctx, "DELETE FROM places_owned_by WHERE place_id = $1", id)
   190  	if err != nil {
   191  		return err
   192  	}
   193  	return nil
   194  }
   195  
   196  func (s *SqlStore) FindByID(ctx context.Context, id string) (*Place, error) {
   197  	place := &Place{}
   198  	err := s.db.GetContext(ctx, place, `
   199          SELECT *
   200          FROM places
   201          WHERE places.id = $1
   202      `, id)
   203  	if err == sql.ErrNoRows {
   204  		return nil, nil
   205  	}
   206  	if err != nil {
   207  		return nil, err
   208  	}
   209  
   210  	s.findOwnedByForPlace(ctx, place)
   211  	if err != nil {
   212  		return nil, err
   213  	}
   214  
   215  	return place, nil
   216  }
   217  
   218  func (s *SqlStore) findOwnedByForPlace(ctx context.Context, place *Place) error {
   219  	ownedBy := []string{}
   220  	err := s.db.SelectContext(ctx, &ownedBy, `
   221          SELECT user_id
   222          FROM places_owned_by
   223          WHERE place_id = $1
   224      `, place.ID)
   225  	if err == sql.ErrNoRows {
   226  		return nil
   227  	}
   228  	if err != nil {
   229  		return err
   230  	}
   231  	place.OwnedBy = ownedBy
   232  	return nil
   233  }
   234  
   // sortableFields is the whitelist of values Find accepts for params.Sort,
// mapped to the column each one orders by. Only values from this map are
// interpolated into the ORDER BY clause.
var sortableFields = map[string]string{
	"deactivated": "deactivated",
	"id":          "id",
	"name":        "name",
	"published":   "published",
}
   241  
   242  func (s *SqlStore) Find(ctx context.Context, params *crud.FindParams[FindFilters]) ([]*Place, int, error) {
   243  	query := `
   244          SELECT places.*
   245          FROM places
   246          `
   247  	joinQuery := ""
   248  	whereStatements := []string{}
   249  	sqlParams := make(map[string]interface{})
   250  	if params != nil {
   251  		if params.Filters != nil {
   252  			if params.Filters.ID != nil {
   253  				whereStatements = append(whereStatements, "places.id=:id")
   254  				sqlParams["id"] = params.Filters.ID
   255  			}
   256  			if params.Filters.NotID != nil {
   257  				whereStatements = append(whereStatements, "places.id!=:notID")
   258  				sqlParams["notID"] = params.Filters.NotID
   259  			}
   260  			if params.Filters.Deactivated != nil {
   261  				whereStatements = append(whereStatements, "places.deactivated=:deactivated")
   262  				sqlParams["deactivated"] = params.Filters.Deactivated
   263  			}
   264  			if params.Filters.Published != nil {
   265  				whereStatements = append(whereStatements, "places.published=:published")
   266  				sqlParams["published"] = params.Filters.Published
   267  			}
   268  			if params.Filters.Name != nil {
   269  				whereStatements = append(whereStatements, "places.name=:name")
   270  				sqlParams["name"] = params.Filters.Name
   271  			}
   272  			if params.Filters.LikeName != nil {
   273  				whereStatements = append(whereStatements, "places.name LIKE :likeName")
   274  				sqlParams["likeName"] = fmt.Sprintf("%%%s%%", *params.Filters.LikeName)
   275  			}
   276  			if params.Filters.Description != nil {
   277  				whereStatements = append(whereStatements, "places.description=:description")
   278  				sqlParams["Description"] = params.Filters.Description
   279  			}
   280  			if params.Filters.OwnedBy != nil {
   281  				joinQuery = " LEFT JOIN places_owned_by ON places.id = places_owned_by.place_id"
   282  				ownedByStatements := make([]string, len(params.Filters.OwnedBy))
   283  				for i, ownedBy := range params.Filters.OwnedBy {
   284  					ownedByRef := fmt.Sprintf("ownedBy%d", i)
   285  					ownedByStatements[i] = fmt.Sprintf("places_owned_by.user_id=:%s", ownedByRef)
   286  					sqlParams[ownedByRef] = ownedBy
   287  				}
   288  				whereStatements = append(whereStatements, fmt.Sprintf("(%s)", strings.Join(ownedByStatements, " OR ")))
   289  			}
   290  
   291  			if joinQuery != "" {
   292  				query += " " + joinQuery
   293  			}
   294  			if len(whereStatements) > 0 {
   295  				query += " WHERE " + strings.Join(whereStatements, " AND ")
   296  			}
   297  			if joinQuery != "" {
   298  				query += " GROUP BY places.id"
   299  			}
   300  		}
   301  
   302  		if params.Sort != "" {
   303  			sort, ok := sortableFields[params.Sort]
   304  			if !ok {
   305  				return nil, 0, fmt.Errorf("find places: invalid sort field: %s", params.Sort)
   306  			}
   307  			order := "ASC"
   308  			if params.Order == "DESC" {
   309  				order = "DESC"
   310  			}
   311  			query += fmt.Sprintf(" ORDER BY %s %s", sort, order)
   312  		}
   313  		if params.Limit > 0 {
   314  			query += fmt.Sprintf(" LIMIT %d", params.Limit)
   315  		}
   316  		if params.Offset > 0 {
   317  			query += fmt.Sprintf(" OFFSET %d", params.Offset)
   318  		}
   319  	}
   320  	rows, err := s.db.NamedQueryContext(ctx, query, sqlParams)
   321  	if err != nil && err != sql.ErrNoRows {
   322  		return nil, 0, fmt.Errorf("find places: %s", err)
   323  	}
   324  	defer rows.Close()
   325  	places := make([]*Place, 0)
   326  	for rows.Next() {
   327  		place := &Place{}
   328  		rows.StructScan(place)
   329  
   330  		places = append(places, place)
   331  	}
   332  
   333  	for i, place := range places {
   334  		err = s.findOwnedByForPlace(ctx, place)
   335  		if err != nil {
   336  			return nil, 0, err
   337  		}
   338  		places[i] = place
   339  	}
   340  
   341  	totalQuery := `
   342          SELECT COUNT(*) as total
   343          FROM places
   344      `
   345  	if len(whereStatements) > 0 {
   346  		if joinQuery != "" {
   347  			totalQuery += " " + joinQuery
   348  		}
   349  		totalQuery += " WHERE " + strings.Join(whereStatements, " AND ")
   350  		if joinQuery != "" {
   351  			totalQuery += " GROUP BY places.id"
   352  		}
   353  	}
   354  	totalPlaces := struct {
   355  		Total int `db:"total"`
   356  	}{}
   357  	rows2, err := s.db.NamedQueryContext(ctx, totalQuery, sqlParams)
   358  	if err != nil && err != sql.ErrNoRows {
   359  		return nil, 0, fmt.Errorf("find places: total: %s", err)
   360  	}
   361  	defer rows2.Close()
   362  	if rows2.Next() {
   363  		rows2.StructScan(&totalPlaces)
   364  	}
   365  
   366  	return places, totalPlaces.Total, nil
   367  }