code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/margin_modes_test.go (about) 1 // Copyright (C) 2023 Gobalsky Labs Limited 2 // 3 // This program is free software: you can redistribute it and/or modify 4 // it under the terms of the GNU Affero General Public License as 5 // published by the Free Software Foundation, either version 3 of the 6 // License, or (at your option) any later version. 7 // 8 // This program is distributed in the hope that it will be useful, 9 // but WITHOUT ANY WARRANTY; without even the implied warranty of 10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 // GNU Affero General Public License for more details. 12 // 13 // You should have received a copy of the GNU Affero General Public License 14 // along with this program. If not, see <http://www.gnu.org/licenses/>. 15 16 package sqlstore_test 17 18 import ( 19 "encoding/json" 20 "strings" 21 "testing" 22 23 "code.vegaprotocol.io/vega/datanode/entities" 24 "code.vegaprotocol.io/vega/datanode/sqlstore" 25 "code.vegaprotocol.io/vega/libs/num" 26 "code.vegaprotocol.io/vega/libs/ptr" 27 vegapb "code.vegaprotocol.io/vega/protos/vega" 28 29 "github.com/stretchr/testify/assert" 30 "github.com/stretchr/testify/require" 31 "golang.org/x/exp/slices" 32 ) 33 34 func TestMarginModesStore(t *testing.T) { 35 ctx := tempTransaction(t) 36 37 marginModesStore := sqlstore.NewMarginModes(connectionSource) 38 39 market1 := entities.MarketID(GenerateID()) 40 market2 := entities.MarketID(GenerateID()) 41 party1 := entities.PartyID(GenerateID()) 42 party2 := entities.PartyID(GenerateID()) 43 44 marginMode11 := entities.PartyMarginMode{ 45 MarketID: market1, 46 PartyID: party1, 47 MarginMode: vegapb.MarginMode_MARGIN_MODE_CROSS_MARGIN, 48 AtEpoch: 5, 49 } 50 marginMode12 := entities.PartyMarginMode{ 51 MarketID: market1, 52 PartyID: party2, 53 MarginMode: vegapb.MarginMode_MARGIN_MODE_ISOLATED_MARGIN, 54 MarginFactor: ptr.From(num.DecimalFromFloat(1.20)), 55 MinTheoreticalMarginFactor: ptr.From(num.DecimalFromFloat(1.21)), 56 MaxTheoreticalLeverage: ptr.From(num.DecimalFromFloat(1.22)), 57 AtEpoch: 6, 58 } 59 marginMode21 := entities.PartyMarginMode{ 60 MarketID: market2, 61 PartyID: party1, 62 MarginMode: vegapb.MarginMode_MARGIN_MODE_ISOLATED_MARGIN, 63 MarginFactor: ptr.From(num.DecimalFromFloat(2.10)), 64 MinTheoreticalMarginFactor: ptr.From(num.DecimalFromFloat(2.11)), 65 MaxTheoreticalLeverage: ptr.From(num.DecimalFromFloat(2.12)), 66 AtEpoch: 10, 67 } 68 marginMode22 := entities.PartyMarginMode{ 69 MarketID: market2, 70 PartyID: party2, 71 MarginMode: vegapb.MarginMode_MARGIN_MODE_CROSS_MARGIN, 72 AtEpoch: 12, 73 } 74 75 t.Run("Inserting brand new market/party combination", func(t *testing.T) { 76 expectedMarginModes := []entities.PartyMarginMode{marginMode11, marginMode12, marginMode21, marginMode22} 77 sortMarginModes(&expectedMarginModes) 78 79 for _, mode := range expectedMarginModes { 80 require.NoError(t, marginModesStore.UpdatePartyMarginMode(ctx, mode)) 81 } 82 83 foundMarginModes, _, err := marginModesStore.ListPartyMarginModes(ctx, entities.DefaultCursorPagination(false), sqlstore.ListPartyMarginModesFilters{}) 84 require.NoError(t, err) 85 expectedStatsJson, _ := json.Marshal(expectedMarginModes) 86 statsJson, _ := json.Marshal(foundMarginModes) 87 assert.JSONEq(t, string(expectedStatsJson), string(statsJson)) 88 }) 89 90 marginMode11 = entities.PartyMarginMode{ 91 MarketID: market1, 92 PartyID: party1, 93 MarginMode: vegapb.MarginMode_MARGIN_MODE_ISOLATED_MARGIN, 94 MarginFactor: ptr.From(num.DecimalFromFloat(3.10)), 95 MinTheoreticalMarginFactor: ptr.From(num.DecimalFromFloat(3.11)), 96 MaxTheoreticalLeverage: ptr.From(num.DecimalFromFloat(3.12)), 97 AtEpoch: 6, 98 } 99 100 t.Run("Inserting brand new market/party combination", func(t *testing.T) { 101 require.NoError(t, marginModesStore.UpdatePartyMarginMode(ctx, marginMode11)) 102 103 expectedMarginModes := []entities.PartyMarginMode{marginMode11, marginMode12, marginMode21, marginMode22} 104 sortMarginModes(&expectedMarginModes) 105 106 foundMarginModes, _, err := marginModesStore.ListPartyMarginModes(ctx, 107 entities.DefaultCursorPagination(false), 108 sqlstore.ListPartyMarginModesFilters{}, 109 ) 110 require.NoError(t, err) 111 expectedStatsJson, _ := json.Marshal(expectedMarginModes) 112 statsJson, _ := json.Marshal(foundMarginModes) 113 assert.JSONEq(t, string(expectedStatsJson), string(statsJson)) 114 }) 115 116 t.Run("Inserting an update on an existing combination", func(t *testing.T) { 117 expectedMarginModes := []entities.PartyMarginMode{marginMode11, marginMode12, marginMode21, marginMode22} 118 sortMarginModes(&expectedMarginModes) 119 120 for _, mode := range expectedMarginModes { 121 require.NoError(t, marginModesStore.UpdatePartyMarginMode(ctx, mode)) 122 } 123 124 foundMarginModes, _, err := marginModesStore.ListPartyMarginModes(ctx, 125 entities.DefaultCursorPagination(false), 126 sqlstore.ListPartyMarginModesFilters{}, 127 ) 128 require.NoError(t, err) 129 expectedStatsJson, _ := json.Marshal(expectedMarginModes) 130 statsJson, _ := json.Marshal(foundMarginModes) 131 assert.JSONEq(t, string(expectedStatsJson), string(statsJson)) 132 }) 133 134 t.Run("Listing a margin mode for party", func(t *testing.T) { 135 expectedMarginModes := []entities.PartyMarginMode{marginMode11, marginMode21} 136 sortMarginModes(&expectedMarginModes) 137 138 foundMarginModes, _, err := marginModesStore.ListPartyMarginModes(ctx, 139 entities.DefaultCursorPagination(false), 140 sqlstore.ListPartyMarginModesFilters{ 141 PartyID: ptr.From(party1), 142 }, 143 ) 144 require.NoError(t, err) 145 expectedStatsJson, _ := json.Marshal(expectedMarginModes) 146 statsJson, _ := json.Marshal(foundMarginModes) 147 assert.JSONEq(t, string(expectedStatsJson), string(statsJson)) 148 }) 149 150 t.Run("Listing a margin mode for market", func(t *testing.T) { 151 expectedMarginModes := []entities.PartyMarginMode{marginMode11, marginMode12} 152 sortMarginModes(&expectedMarginModes) 153 154 foundMarginModes, _, err := marginModesStore.ListPartyMarginModes(ctx, 155 entities.DefaultCursorPagination(false), 156 sqlstore.ListPartyMarginModesFilters{ 157 MarketID: ptr.From(market1), 158 }, 159 ) 160 require.NoError(t, err) 161 expectedStatsJson, _ := json.Marshal(expectedMarginModes) 162 statsJson, _ := json.Marshal(foundMarginModes) 163 assert.JSONEq(t, string(expectedStatsJson), string(statsJson)) 164 }) 165 166 t.Run("Listing a margin mode for market and party", func(t *testing.T) { 167 expectedMarginModes := []entities.PartyMarginMode{marginMode11} 168 sortMarginModes(&expectedMarginModes) 169 170 foundMarginModes, _, err := marginModesStore.ListPartyMarginModes(ctx, 171 entities.DefaultCursorPagination(false), 172 sqlstore.ListPartyMarginModesFilters{ 173 PartyID: ptr.From(party1), 174 MarketID: ptr.From(market1), 175 }, 176 ) 177 require.NoError(t, err) 178 expectedStatsJson, _ := json.Marshal(expectedMarginModes) 179 statsJson, _ := json.Marshal(foundMarginModes) 180 assert.JSONEq(t, string(expectedStatsJson), string(statsJson)) 181 }) 182 } 183 184 func sortMarginModes(modes *[]entities.PartyMarginMode) { 185 slices.SortStableFunc(*modes, func(a, b entities.PartyMarginMode) int { 186 if a.MarketID == b.MarketID { 187 return strings.Compare(a.PartyID.String(), b.PartyID.String()) 188 } 189 return strings.Compare(a.MarketID.String(), b.MarketID.String()) 190 }) 191 }