github.com/shoshinnikita/budget-manager@v0.7.1-0.20220131195411-8c46ff1c6778/internal/db/base/search.go (about)

     1  package base
     2  
     3  import (
     4  	"context"
     5  	"strings"
     6  	"time"
     7  
     8  	common "github.com/ShoshinNikita/budget-manager/internal/db"
     9  	"github.com/ShoshinNikita/budget-manager/internal/db/base/internal/sqlx"
    10  	"github.com/ShoshinNikita/budget-manager/internal/db/base/types"
    11  	"github.com/ShoshinNikita/budget-manager/internal/pkg/money"
    12  )
    13  
    14  func (db DB) SearchSpends(ctx context.Context, args common.SearchSpendsArgs) ([]common.Spend, error) {
    15  	var spends []struct {
    16  		ID    uint         `db:"id"`
    17  		Year  int          `db:"year"`
    18  		Month time.Month   `db:"month"`
    19  		Day   int          `db:"day"`
    20  		Title string       `db:"title"`
    21  		Notes types.String `db:"notes"`
    22  		Cost  money.Money  `db:"cost"`
    23  
    24  		Type SpendType `db:"type"`
    25  	}
    26  	err := db.db.RunInTransaction(ctx, func(tx *sqlx.Tx) error {
    27  		query, sqlArgs := db.buildSearchSpendsQuery(args)
    28  		return tx.Select(&spends, query, sqlArgs...)
    29  	})
    30  	if err != nil {
    31  		return nil, err
    32  	}
    33  
    34  	// Convert the internal model to the common one
    35  	res := make([]common.Spend, 0, len(spends))
    36  	for _, s := range spends {
    37  		res = append(res, common.Spend{
    38  			ID:    s.ID,
    39  			Year:  s.Year,
    40  			Month: s.Month,
    41  			Day:   s.Day,
    42  			Title: s.Title,
    43  			Type:  s.Type.ToCommon(),
    44  			Notes: string(s.Notes),
    45  			Cost:  s.Cost,
    46  		})
    47  	}
    48  	return res, nil
    49  }
    50  
    51  // buildSearchSpendsQuery builds a query to search for spends
    52  //nolint:funlen
    53  func (DB) buildSearchSpendsQuery(args common.SearchSpendsArgs) (string, []interface{}) {
    54  	var (
    55  		wheres    []string
    56  		whereArgs []interface{}
    57  	)
    58  	addWhere := func(where string, args ...interface{}) {
    59  		wheres = append(wheres, where)
    60  		whereArgs = append(whereArgs, args...)
    61  	}
    62  
    63  	query := "SELECT "
    64  
    65  	query += strings.Join([]string{
    66  		`spend.id AS id`,
    67  		`month.year AS year`,
    68  		`month.month AS month`,
    69  		`day.day AS day`,
    70  		`spend.title AS title`,
    71  		`spend.notes AS notes`,
    72  		`spend.cost AS cost`,
    73  		`spend_type.id AS "type.id"`,
    74  		`spend_type.name AS "type.name"`,
    75  		`spend_type.parent_id AS "type.parent_id"`,
    76  	}, ", ")
    77  
    78  	query += " FROM spends AS spend "
    79  
    80  	query += strings.Join([]string{
    81  		`INNER JOIN days AS day ON day.id = spend.day_id`,
    82  		`INNER JOIN months AS month ON month.id = day.month_id`,
    83  		`LEFT JOIN spend_types AS spend_type ON spend_type.id = spend.type_id`,
    84  	}, " ")
    85  
    86  	if args.Title != "" {
    87  		title := "%" + args.Title + "%"
    88  		if args.TitleExactly {
    89  			title = args.Title
    90  		}
    91  		addWhere("LOWER(spend.title) LIKE ?", title)
    92  	}
    93  
    94  	if args.Notes != "" {
    95  		notes := "%" + args.Notes + "%"
    96  		if args.NotesExactly {
    97  			notes = args.Notes
    98  		}
    99  		addWhere("LOWER(spend.notes) LIKE ?", notes)
   100  	}
   101  
   102  	if q, args := getQueryToFilterByTime(args.After, args.Before); q != "" {
   103  		addWhere(q, args...)
   104  	}
   105  
   106  	switch {
   107  	case args.MinCost != 0 && args.MaxCost != 0:
   108  		addWhere("spend.cost BETWEEN ? AND ?", int(args.MinCost), int(args.MaxCost))
   109  	case args.MinCost != 0:
   110  		addWhere("spend.cost >= ?", int(args.MinCost))
   111  	case args.MaxCost != 0:
   112  		addWhere("spend.cost <= ?", int(args.MaxCost))
   113  	}
   114  
   115  	if len(args.TypeIDs) != 0 {
   116  		var (
   117  			orWheres    []string
   118  			typeIDsArgs []interface{}
   119  		)
   120  
   121  		typeIDs := args.TypeIDs
   122  		for i, id := range typeIDs {
   123  			if id == 0 {
   124  				// Search for spends without type
   125  				orWheres = append(orWheres, "spend.type_id IS NULL")
   126  				typeIDs = append(typeIDs[:i], typeIDs[i+1:]...)
   127  				break
   128  			}
   129  		}
   130  
   131  		if len(typeIDs) != 0 {
   132  			inPlaceholders := strings.Repeat("?,", len(typeIDs))
   133  			inPlaceholders = inPlaceholders[:len(inPlaceholders)-1]
   134  
   135  			orWheres = append(orWheres, "spend.type_id IN ("+inPlaceholders+")")
   136  			for _, id := range typeIDs {
   137  				typeIDsArgs = append(typeIDsArgs, int(id))
   138  			}
   139  		}
   140  
   141  		addWhere("("+strings.Join(orWheres, " OR ")+")", typeIDsArgs...)
   142  	}
   143  
   144  	var orders []string
   145  	switch args.Sort {
   146  	case common.SortSpendsByDate:
   147  		orders = []string{"month.year", "month.month", "day.day"}
   148  	case common.SortSpendsByTitle:
   149  		orders = []string{"spend.title"}
   150  	case common.SortSpendsByCost:
   151  		orders = []string{"spend.cost"}
   152  	}
   153  	if args.Order == common.OrderByDesc {
   154  		for i := range orders {
   155  			orders[i] += " DESC"
   156  		}
   157  	}
   158  	orders = append(orders, "spend.id")
   159  
   160  	// Build the final query
   161  	if len(wheres) != 0 {
   162  		query += " WHERE " + strings.Join(wheres, " AND ")
   163  	}
   164  	query += " ORDER BY " + strings.Join(orders, ", ")
   165  
   166  	return query, whereArgs
   167  }
   168  
   169  func getQueryToFilterByTime(after, before time.Time) (where string, args []interface{}) {
   170  	convertTime := func(t time.Time) int {
   171  		return t.Year()*10000 + int(t.Month())*100 + t.Day()
   172  	}
   173  
   174  	// It is a db-agnostic solution to compare dates
   175  	where = "month.year*10000 + month.month*100 + day.day"
   176  
   177  	switch {
   178  	case !after.IsZero() && !before.IsZero():
   179  		where += " BETWEEN ? AND ?"
   180  		args = []interface{}{convertTime(after), convertTime(before)}
   181  
   182  	case !after.IsZero():
   183  		where += " >= ?"
   184  		args = []interface{}{convertTime(after)}
   185  
   186  	case !before.IsZero():
   187  		where += " <= ?"
   188  		args = []interface{}{convertTime(before)}
   189  
   190  	default:
   191  		return "", nil
   192  	}
   193  	return where, args
   194  }