github.com/shoshinnikita/budget-manager@v0.7.1-0.20220131195411-8c46ff1c6778/internal/db/base/month.go (about) 1 package base 2 3 import ( 4 "bytes" 5 "context" 6 "database/sql" 7 "fmt" 8 "strings" 9 "time" 10 11 common "github.com/ShoshinNikita/budget-manager/internal/db" 12 "github.com/ShoshinNikita/budget-manager/internal/db/base/internal/sqlx" 13 "github.com/ShoshinNikita/budget-manager/internal/pkg/errors" 14 "github.com/ShoshinNikita/budget-manager/internal/pkg/money" 15 ) 16 17 type MonthOverview struct { 18 ID uint `db:"id"` 19 Year int `db:"year"` 20 Month time.Month `db:"month"` 21 DailyBudget money.Money `db:"daily_budget"` 22 TotalIncome money.Money `db:"total_income"` 23 TotalSpend money.Money `db:"total_spend"` 24 Result money.Money `db:"result"` 25 } 26 27 func (m MonthOverview) ToCommon() common.MonthOverview { 28 return common.MonthOverview{ 29 ID: m.ID, 30 Year: m.Year, 31 Month: m.Month, 32 TotalIncome: m.TotalIncome, 33 TotalSpend: m.TotalSpend, 34 DailyBudget: m.DailyBudget, 35 Result: m.Result, 36 } 37 } 38 39 type Month struct { 40 MonthOverview 41 42 Incomes []Income `db:"-"` 43 MonthlyPayments []MonthlyPayment `db:"-"` 44 Days []Day `db:"-"` 45 } 46 47 func (m Month) ToCommon() common.Month { 48 return common.Month{ 49 MonthOverview: m.MonthOverview.ToCommon(), 50 // 51 Incomes: func() []common.Income { 52 incomes := make([]common.Income, 0, len(m.Incomes)) 53 for i := range m.Incomes { 54 incomes = append(incomes, m.Incomes[i].ToCommon(m.Year, m.Month)) 55 } 56 return incomes 57 }(), 58 MonthlyPayments: func() []common.MonthlyPayment { 59 mp := make([]common.MonthlyPayment, 0, len(m.MonthlyPayments)) 60 for i := range m.MonthlyPayments { 61 mp = append(mp, m.MonthlyPayments[i].ToCommon(m.Year, m.Month)) 62 } 63 return mp 64 }(), 65 Days: func() []common.Day { 66 days := make([]common.Day, 0, len(m.Days)) 67 for i := range m.Days { 68 days = append(days, m.Days[i].ToCommon(m.Year, m.Month)) 69 } 70 return days 71 }(), 72 } 73 } 74 75 func (db DB) GetMonthByDate(ctx context.Context, year int, month time.Month) (common.Month, error) { 76 var m Month 77 err := db.db.RunInTransaction(ctx, func(tx *sqlx.Tx) (err error) { 78 m, err = getFullMonth(tx, "year = ? AND month = ?", year, month) 79 return err 80 }) 81 if err != nil { 82 if errors.Is(err, sql.ErrNoRows) { 83 err = common.ErrMonthNotExist 84 } 85 return common.Month{}, err 86 } 87 88 return m.ToCommon(), nil 89 } 90 91 // GetMonths returns month overviews for passed years 92 func (db DB) GetMonths(ctx context.Context, years ...int) ([]common.MonthOverview, error) { 93 var m []MonthOverview 94 err := db.db.RunInTransaction(ctx, func(tx *sqlx.Tx) error { 95 return tx.SelectQuery(&m, sqlx.In(`SELECT * FROM months WHERE year IN (?) ORDER BY id ASC`, years)) 96 }) 97 if err != nil { 98 return nil, err 99 } 100 if len(m) == 0 { 101 return nil, nil 102 } 103 104 res := make([]common.MonthOverview, 0, len(m)) 105 for i := range m { 106 res = append(res, m[i].ToCommon()) 107 } 108 return res, nil 109 } 110 111 // InitMonth inits a month and days for the passed date 112 func (db *DB) InitMonth(ctx context.Context, year int, month time.Month) error { 113 return db.db.RunInTransaction(ctx, func(tx *sqlx.Tx) error { 114 var count int 115 err := tx.Get(&count, `SELECT COUNT(*) FROM months WHERE year = ? AND month = ?`, year, month) 116 if err != nil { 117 return errors.Wrap(err, "couldn't check if the current month exists") 118 } 119 if count != 0 { 120 // The month is already created 121 return nil 122 } 123 124 // We have to init the current month 125 126 var monthID uint 127 err = tx.Get(&monthID, `INSERT INTO months(year, month) VALUES(?, ?) RETURNING id`, year, month) 128 if err != nil { 129 return errors.Wrap(err, "couldn't init the current month") 130 } 131 132 daysNumber := daysInMonth(year, month) 133 134 query := `INSERT INTO days(month_id, day) VALUES ` + strings.Repeat("(?, ?), ", daysNumber) 135 query = query[:len(query)-2] 136 137 sqlArgs := make([]interface{}, 0, daysNumber*2) 138 for i := 0; i < daysNumber; i++ { 139 sqlArgs = append(sqlArgs, monthID, i+1) 140 } 141 if _, err = tx.Exec(query, sqlArgs...); err != nil { 142 return errors.Wrap(err, "couldn't insert days for the current month") 143 } 144 return nil 145 }) 146 } 147 148 func (db DB) recomputeAndUpdateMonth(tx *sqlx.Tx, monthID uint) (err error) { 149 defer func() { 150 if err != nil { 151 err = errors.Wrap(err, "couldn't recompute the month budget") 152 } 153 }() 154 155 m, err := getFullMonth(tx, "id = ?", monthID) 156 if err != nil { 157 return errors.Wrap(err, "couldn't select month") 158 } 159 160 m = recomputeMonth(m) 161 162 // Update Month 163 _, err = tx.Exec( 164 `UPDATE months SET daily_budget = ?, total_income = ?, total_spend = ?, result = ? WHERE id = ?`, 165 m.DailyBudget, m.TotalIncome, m.TotalSpend, m.Result, m.ID, 166 ) 167 168 // Update Days with the following db-agnostic query: 169 // 170 // UPDATE days SET saldo = CASE id 171 // WHEN 1 THEN 100 172 // WHEN 2 THEN 200 173 // WHEN 3 THEN 300 174 // END 175 // WHERE id IN (1, 2, 3) 176 // 177 var ( 178 query = &bytes.Buffer{} 179 dayIDs = make([]int, 0, len(m.Days)) 180 ) 181 query.WriteString("UPDATE days SET saldo = CASE id\n") 182 for _, day := range m.Days { 183 fmt.Fprintf(query, "WHEN %d THEN %d\n", day.ID, int(day.Saldo)) 184 dayIDs = append(dayIDs, int(day.ID)) 185 } 186 query.WriteString("END\n") 187 query.WriteString("WHERE id IN (?)") 188 189 if _, err := tx.ExecQuery(sqlx.In(query.String(), dayIDs)); err != nil { 190 return errors.Wrap(err, "couldn't update days") 191 } 192 193 return nil 194 } 195 196 func recomputeMonth(m Month) Month { 197 // Update Total Income 198 m.TotalIncome = 0 199 for _, in := range m.Incomes { 200 m.TotalIncome = m.TotalIncome.Add(in.Income) 201 } 202 203 // Update Total Spends and Daily Budget 204 205 var monthlyPaymentsCost money.Money 206 for _, mp := range m.MonthlyPayments { 207 monthlyPaymentsCost = monthlyPaymentsCost.Sub(mp.Cost) 208 } 209 210 var spendsCost money.Money 211 for _, day := range m.Days { 212 for _, spend := range day.Spends { 213 spendsCost = spendsCost.Sub(spend.Cost) 214 } 215 } 216 217 // Use "Add" because monthlyPaymentCost and TotalSpend are negative 218 m.DailyBudget = m.TotalIncome.Add(monthlyPaymentsCost).Div(int64(len(m.Days))) 219 m.TotalSpend = monthlyPaymentsCost.Add(spendsCost) 220 m.Result = m.TotalIncome.Add(m.TotalSpend) 221 222 // Update Saldos (it is accumulated) 223 saldo := m.DailyBudget 224 for i := range m.Days { 225 m.Days[i].Saldo = saldo 226 for _, spend := range m.Days[i].Spends { 227 m.Days[i].Saldo = m.Days[i].Saldo.Sub(spend.Cost) 228 } 229 saldo = m.Days[i].Saldo + m.DailyBudget 230 } 231 232 return m 233 } 234 235 func getFullMonth(tx *sqlx.Tx, whereCond string, args ...interface{}) (m Month, err error) { 236 err = tx.Get(&m.MonthOverview, `SELECT * from months WHERE `+whereCond, args...) 237 if err != nil { 238 return Month{}, errors.Wrap(err, "couldn't select month") 239 } 240 241 err = tx.Select(&m.Incomes, `SELECT * FROM incomes WHERE month_id = ? ORDER BY id`, m.ID) 242 if err != nil { 243 return Month{}, errors.Wrap(err, "couldn't select incomes") 244 } 245 err = tx.Select( 246 &m.MonthlyPayments, ` 247 SELECT 248 monthly_payments.*, 249 spend_types.id AS "type.id", 250 spend_types.name AS "type.name", 251 spend_types.parent_id AS "type.parent_id" 252 FROM monthly_payments 253 LEFT JOIN spend_types ON spend_types.id = monthly_payments.type_id 254 WHERE monthly_payments.month_id = ? 255 ORDER BY monthly_payments.id`, m.ID, 256 ) 257 if err != nil { 258 return Month{}, errors.Wrap(err, "couldn't select monthly payments") 259 } 260 261 err = tx.Select(&m.Days, `SELECT * FROM days WHERE month_id = ? ORDER BY day`, m.ID) 262 if err != nil { 263 return Month{}, errors.Wrap(err, "couldn't select days") 264 } 265 266 dayIndexes := make(map[uint]int) // day id -> slice index 267 dayIDs := make([]int, 0, len(m.Days)) 268 for i, d := range m.Days { 269 dayIndexes[d.ID] = i 270 dayIDs = append(dayIDs, int(d.ID)) 271 } 272 var allSpends []Spend 273 err = tx.SelectQuery(&allSpends, sqlx.In(` 274 SELECT 275 spends.*, 276 spend_types.id AS "type.id", 277 spend_types.name AS "type.name", 278 spend_types.parent_id AS "type.parent_id" 279 FROM spends 280 LEFT JOIN spend_types ON spend_types.id = spends.type_id 281 WHERE spends.day_id IN (?) 282 ORDER BY spends.id`, dayIDs, 283 )) 284 if err != nil { 285 return Month{}, errors.Wrap(err, "couldn't select spends") 286 } 287 for _, s := range allSpends { 288 dayIndex, ok := dayIndexes[s.DayID] 289 if !ok { 290 // Just in case 291 return Month{}, errors.Errorf("spend with id %d has unexpected day id: %d", s.ID, s.DayID) 292 } 293 m.Days[dayIndex].Spends = append(m.Days[dayIndex].Spends, s) 294 } 295 296 return m, nil 297 }