github.com/shoshinnikita/budget-manager@v0.7.1-0.20220131195411-8c46ff1c6778/internal/db/base/month.go (about)

     1  package base
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"database/sql"
     7  	"fmt"
     8  	"strings"
     9  	"time"
    10  
    11  	common "github.com/ShoshinNikita/budget-manager/internal/db"
    12  	"github.com/ShoshinNikita/budget-manager/internal/db/base/internal/sqlx"
    13  	"github.com/ShoshinNikita/budget-manager/internal/pkg/errors"
    14  	"github.com/ShoshinNikita/budget-manager/internal/pkg/money"
    15  )
    16  
// MonthOverview is the db representation of a row of the "months" table;
// fields are mapped to columns by their "db" tags.
//
// The money fields are derived values: they are recomputed by recomputeMonth
// and written back by recomputeAndUpdateMonth.
type MonthOverview struct {
	ID          uint        `db:"id"`
	Year        int         `db:"year"`
	Month       time.Month  `db:"month"`
	DailyBudget money.Money `db:"daily_budget"` // (total income - monthly payments) / number of days
	TotalIncome money.Money `db:"total_income"` // sum of all incomes of the month
	TotalSpend  money.Money `db:"total_spend"`  // monthly payments and spends accumulated with Sub (negative for positive costs)
	Result      money.Money `db:"result"`       // TotalIncome + TotalSpend
}
    26  
    27  func (m MonthOverview) ToCommon() common.MonthOverview {
    28  	return common.MonthOverview{
    29  		ID:          m.ID,
    30  		Year:        m.Year,
    31  		Month:       m.Month,
    32  		TotalIncome: m.TotalIncome,
    33  		TotalSpend:  m.TotalSpend,
    34  		DailyBudget: m.DailyBudget,
    35  		Result:      m.Result,
    36  	}
    37  }
    38  
// Month is a MonthOverview together with its related rows.
//
// The slices are excluded from column mapping (db:"-"). getFullMonth fills
// them with separate queries:
//   - incomes and monthly payments, ordered by id;
//   - days, ordered by day number, with each day's Spends attached.
type Month struct {
	MonthOverview

	Incomes         []Income         `db:"-"`
	MonthlyPayments []MonthlyPayment `db:"-"`
	Days            []Day            `db:"-"`
}
    46  
    47  func (m Month) ToCommon() common.Month {
    48  	return common.Month{
    49  		MonthOverview: m.MonthOverview.ToCommon(),
    50  		//
    51  		Incomes: func() []common.Income {
    52  			incomes := make([]common.Income, 0, len(m.Incomes))
    53  			for i := range m.Incomes {
    54  				incomes = append(incomes, m.Incomes[i].ToCommon(m.Year, m.Month))
    55  			}
    56  			return incomes
    57  		}(),
    58  		MonthlyPayments: func() []common.MonthlyPayment {
    59  			mp := make([]common.MonthlyPayment, 0, len(m.MonthlyPayments))
    60  			for i := range m.MonthlyPayments {
    61  				mp = append(mp, m.MonthlyPayments[i].ToCommon(m.Year, m.Month))
    62  			}
    63  			return mp
    64  		}(),
    65  		Days: func() []common.Day {
    66  			days := make([]common.Day, 0, len(m.Days))
    67  			for i := range m.Days {
    68  				days = append(days, m.Days[i].ToCommon(m.Year, m.Month))
    69  			}
    70  			return days
    71  		}(),
    72  	}
    73  }
    74  
    75  func (db DB) GetMonthByDate(ctx context.Context, year int, month time.Month) (common.Month, error) {
    76  	var m Month
    77  	err := db.db.RunInTransaction(ctx, func(tx *sqlx.Tx) (err error) {
    78  		m, err = getFullMonth(tx, "year = ? AND month = ?", year, month)
    79  		return err
    80  	})
    81  	if err != nil {
    82  		if errors.Is(err, sql.ErrNoRows) {
    83  			err = common.ErrMonthNotExist
    84  		}
    85  		return common.Month{}, err
    86  	}
    87  
    88  	return m.ToCommon(), nil
    89  }
    90  
    91  // GetMonths returns month overviews for passed years
    92  func (db DB) GetMonths(ctx context.Context, years ...int) ([]common.MonthOverview, error) {
    93  	var m []MonthOverview
    94  	err := db.db.RunInTransaction(ctx, func(tx *sqlx.Tx) error {
    95  		return tx.SelectQuery(&m, sqlx.In(`SELECT * FROM months WHERE year IN (?) ORDER BY id ASC`, years))
    96  	})
    97  	if err != nil {
    98  		return nil, err
    99  	}
   100  	if len(m) == 0 {
   101  		return nil, nil
   102  	}
   103  
   104  	res := make([]common.MonthOverview, 0, len(m))
   105  	for i := range m {
   106  		res = append(res, m[i].ToCommon())
   107  	}
   108  	return res, nil
   109  }
   110  
   111  // InitMonth inits a month and days for the passed date
   112  func (db *DB) InitMonth(ctx context.Context, year int, month time.Month) error {
   113  	return db.db.RunInTransaction(ctx, func(tx *sqlx.Tx) error {
   114  		var count int
   115  		err := tx.Get(&count, `SELECT COUNT(*) FROM months WHERE year = ? AND month = ?`, year, month)
   116  		if err != nil {
   117  			return errors.Wrap(err, "couldn't check if the current month exists")
   118  		}
   119  		if count != 0 {
   120  			// The month is already created
   121  			return nil
   122  		}
   123  
   124  		// We have to init the current month
   125  
   126  		var monthID uint
   127  		err = tx.Get(&monthID, `INSERT INTO months(year, month) VALUES(?, ?) RETURNING id`, year, month)
   128  		if err != nil {
   129  			return errors.Wrap(err, "couldn't init the current month")
   130  		}
   131  
   132  		daysNumber := daysInMonth(year, month)
   133  
   134  		query := `INSERT INTO days(month_id, day) VALUES ` + strings.Repeat("(?, ?), ", daysNumber)
   135  		query = query[:len(query)-2]
   136  
   137  		sqlArgs := make([]interface{}, 0, daysNumber*2)
   138  		for i := 0; i < daysNumber; i++ {
   139  			sqlArgs = append(sqlArgs, monthID, i+1)
   140  		}
   141  		if _, err = tx.Exec(query, sqlArgs...); err != nil {
   142  			return errors.Wrap(err, "couldn't insert days for the current month")
   143  		}
   144  		return nil
   145  	})
   146  }
   147  
   148  func (db DB) recomputeAndUpdateMonth(tx *sqlx.Tx, monthID uint) (err error) {
   149  	defer func() {
   150  		if err != nil {
   151  			err = errors.Wrap(err, "couldn't recompute the month budget")
   152  		}
   153  	}()
   154  
   155  	m, err := getFullMonth(tx, "id = ?", monthID)
   156  	if err != nil {
   157  		return errors.Wrap(err, "couldn't select month")
   158  	}
   159  
   160  	m = recomputeMonth(m)
   161  
   162  	// Update Month
   163  	_, err = tx.Exec(
   164  		`UPDATE months SET daily_budget = ?, total_income = ?, total_spend = ?, result = ? WHERE id = ?`,
   165  		m.DailyBudget, m.TotalIncome, m.TotalSpend, m.Result, m.ID,
   166  	)
   167  
   168  	// Update Days with the following db-agnostic query:
   169  	//
   170  	//	UPDATE days SET saldo = CASE id
   171  	//		WHEN 1 THEN 100
   172  	//		WHEN 2 THEN 200
   173  	//		WHEN 3 THEN 300
   174  	//	END
   175  	//	WHERE id IN (1, 2, 3)
   176  	//
   177  	var (
   178  		query  = &bytes.Buffer{}
   179  		dayIDs = make([]int, 0, len(m.Days))
   180  	)
   181  	query.WriteString("UPDATE days SET saldo = CASE id\n")
   182  	for _, day := range m.Days {
   183  		fmt.Fprintf(query, "WHEN %d THEN %d\n", day.ID, int(day.Saldo))
   184  		dayIDs = append(dayIDs, int(day.ID))
   185  	}
   186  	query.WriteString("END\n")
   187  	query.WriteString("WHERE id IN (?)")
   188  
   189  	if _, err := tx.ExecQuery(sqlx.In(query.String(), dayIDs)); err != nil {
   190  		return errors.Wrap(err, "couldn't update days")
   191  	}
   192  
   193  	return nil
   194  }
   195  
   196  func recomputeMonth(m Month) Month {
   197  	// Update Total Income
   198  	m.TotalIncome = 0
   199  	for _, in := range m.Incomes {
   200  		m.TotalIncome = m.TotalIncome.Add(in.Income)
   201  	}
   202  
   203  	// Update Total Spends and Daily Budget
   204  
   205  	var monthlyPaymentsCost money.Money
   206  	for _, mp := range m.MonthlyPayments {
   207  		monthlyPaymentsCost = monthlyPaymentsCost.Sub(mp.Cost)
   208  	}
   209  
   210  	var spendsCost money.Money
   211  	for _, day := range m.Days {
   212  		for _, spend := range day.Spends {
   213  			spendsCost = spendsCost.Sub(spend.Cost)
   214  		}
   215  	}
   216  
   217  	// Use "Add" because monthlyPaymentCost and TotalSpend are negative
   218  	m.DailyBudget = m.TotalIncome.Add(monthlyPaymentsCost).Div(int64(len(m.Days)))
   219  	m.TotalSpend = monthlyPaymentsCost.Add(spendsCost)
   220  	m.Result = m.TotalIncome.Add(m.TotalSpend)
   221  
   222  	// Update Saldos (it is accumulated)
   223  	saldo := m.DailyBudget
   224  	for i := range m.Days {
   225  		m.Days[i].Saldo = saldo
   226  		for _, spend := range m.Days[i].Spends {
   227  			m.Days[i].Saldo = m.Days[i].Saldo.Sub(spend.Cost)
   228  		}
   229  		saldo = m.Days[i].Saldo + m.DailyBudget
   230  	}
   231  
   232  	return m
   233  }
   234  
   235  func getFullMonth(tx *sqlx.Tx, whereCond string, args ...interface{}) (m Month, err error) {
   236  	err = tx.Get(&m.MonthOverview, `SELECT * from months WHERE `+whereCond, args...)
   237  	if err != nil {
   238  		return Month{}, errors.Wrap(err, "couldn't select month")
   239  	}
   240  
   241  	err = tx.Select(&m.Incomes, `SELECT * FROM incomes WHERE month_id = ? ORDER BY id`, m.ID)
   242  	if err != nil {
   243  		return Month{}, errors.Wrap(err, "couldn't select incomes")
   244  	}
   245  	err = tx.Select(
   246  		&m.MonthlyPayments, `
   247  		SELECT
   248  			monthly_payments.*,
   249  			spend_types.id AS "type.id",
   250  			spend_types.name AS "type.name",
   251  			spend_types.parent_id AS "type.parent_id"
   252  		FROM monthly_payments
   253  		LEFT JOIN spend_types ON spend_types.id = monthly_payments.type_id
   254  		WHERE monthly_payments.month_id = ?
   255  		ORDER BY monthly_payments.id`, m.ID,
   256  	)
   257  	if err != nil {
   258  		return Month{}, errors.Wrap(err, "couldn't select monthly payments")
   259  	}
   260  
   261  	err = tx.Select(&m.Days, `SELECT * FROM days WHERE month_id = ? ORDER BY day`, m.ID)
   262  	if err != nil {
   263  		return Month{}, errors.Wrap(err, "couldn't select days")
   264  	}
   265  
   266  	dayIndexes := make(map[uint]int) // day id -> slice index
   267  	dayIDs := make([]int, 0, len(m.Days))
   268  	for i, d := range m.Days {
   269  		dayIndexes[d.ID] = i
   270  		dayIDs = append(dayIDs, int(d.ID))
   271  	}
   272  	var allSpends []Spend
   273  	err = tx.SelectQuery(&allSpends, sqlx.In(`
   274  		SELECT
   275  			spends.*,
   276  			spend_types.id AS "type.id",
   277  			spend_types.name AS "type.name",
   278  			spend_types.parent_id AS "type.parent_id"
   279  		FROM spends
   280  		LEFT JOIN spend_types ON spend_types.id = spends.type_id
   281  		WHERE spends.day_id IN (?)
   282  		ORDER BY spends.id`, dayIDs,
   283  	))
   284  	if err != nil {
   285  		return Month{}, errors.Wrap(err, "couldn't select spends")
   286  	}
   287  	for _, s := range allSpends {
   288  		dayIndex, ok := dayIndexes[s.DayID]
   289  		if !ok {
   290  			// Just in case
   291  			return Month{}, errors.Errorf("spend with id %d has unexpected day id: %d", s.ID, s.DayID)
   292  		}
   293  		m.Days[dayIndex].Spends = append(m.Days[dayIndex].Spends, s)
   294  	}
   295  
   296  	return m, nil
   297  }