code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/margin_modes_test.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package sqlstore_test
    17  
    18  import (
    19  	"encoding/json"
    20  	"strings"
    21  	"testing"
    22  
    23  	"code.vegaprotocol.io/vega/datanode/entities"
    24  	"code.vegaprotocol.io/vega/datanode/sqlstore"
    25  	"code.vegaprotocol.io/vega/libs/num"
    26  	"code.vegaprotocol.io/vega/libs/ptr"
    27  	vegapb "code.vegaprotocol.io/vega/protos/vega"
    28  
    29  	"github.com/stretchr/testify/assert"
    30  	"github.com/stretchr/testify/require"
    31  	"golang.org/x/exp/slices"
    32  )
    33  
    34  func TestMarginModesStore(t *testing.T) {
    35  	ctx := tempTransaction(t)
    36  
    37  	marginModesStore := sqlstore.NewMarginModes(connectionSource)
    38  
    39  	market1 := entities.MarketID(GenerateID())
    40  	market2 := entities.MarketID(GenerateID())
    41  	party1 := entities.PartyID(GenerateID())
    42  	party2 := entities.PartyID(GenerateID())
    43  
    44  	marginMode11 := entities.PartyMarginMode{
    45  		MarketID:   market1,
    46  		PartyID:    party1,
    47  		MarginMode: vegapb.MarginMode_MARGIN_MODE_CROSS_MARGIN,
    48  		AtEpoch:    5,
    49  	}
    50  	marginMode12 := entities.PartyMarginMode{
    51  		MarketID:                   market1,
    52  		PartyID:                    party2,
    53  		MarginMode:                 vegapb.MarginMode_MARGIN_MODE_ISOLATED_MARGIN,
    54  		MarginFactor:               ptr.From(num.DecimalFromFloat(1.20)),
    55  		MinTheoreticalMarginFactor: ptr.From(num.DecimalFromFloat(1.21)),
    56  		MaxTheoreticalLeverage:     ptr.From(num.DecimalFromFloat(1.22)),
    57  		AtEpoch:                    6,
    58  	}
    59  	marginMode21 := entities.PartyMarginMode{
    60  		MarketID:                   market2,
    61  		PartyID:                    party1,
    62  		MarginMode:                 vegapb.MarginMode_MARGIN_MODE_ISOLATED_MARGIN,
    63  		MarginFactor:               ptr.From(num.DecimalFromFloat(2.10)),
    64  		MinTheoreticalMarginFactor: ptr.From(num.DecimalFromFloat(2.11)),
    65  		MaxTheoreticalLeverage:     ptr.From(num.DecimalFromFloat(2.12)),
    66  		AtEpoch:                    10,
    67  	}
    68  	marginMode22 := entities.PartyMarginMode{
    69  		MarketID:   market2,
    70  		PartyID:    party2,
    71  		MarginMode: vegapb.MarginMode_MARGIN_MODE_CROSS_MARGIN,
    72  		AtEpoch:    12,
    73  	}
    74  
    75  	t.Run("Inserting brand new market/party combination", func(t *testing.T) {
    76  		expectedMarginModes := []entities.PartyMarginMode{marginMode11, marginMode12, marginMode21, marginMode22}
    77  		sortMarginModes(&expectedMarginModes)
    78  
    79  		for _, mode := range expectedMarginModes {
    80  			require.NoError(t, marginModesStore.UpdatePartyMarginMode(ctx, mode))
    81  		}
    82  
    83  		foundMarginModes, _, err := marginModesStore.ListPartyMarginModes(ctx, entities.DefaultCursorPagination(false), sqlstore.ListPartyMarginModesFilters{})
    84  		require.NoError(t, err)
    85  		expectedStatsJson, _ := json.Marshal(expectedMarginModes)
    86  		statsJson, _ := json.Marshal(foundMarginModes)
    87  		assert.JSONEq(t, string(expectedStatsJson), string(statsJson))
    88  	})
    89  
    90  	marginMode11 = entities.PartyMarginMode{
    91  		MarketID:                   market1,
    92  		PartyID:                    party1,
    93  		MarginMode:                 vegapb.MarginMode_MARGIN_MODE_ISOLATED_MARGIN,
    94  		MarginFactor:               ptr.From(num.DecimalFromFloat(3.10)),
    95  		MinTheoreticalMarginFactor: ptr.From(num.DecimalFromFloat(3.11)),
    96  		MaxTheoreticalLeverage:     ptr.From(num.DecimalFromFloat(3.12)),
    97  		AtEpoch:                    6,
    98  	}
    99  
   100  	t.Run("Inserting brand new market/party combination", func(t *testing.T) {
   101  		require.NoError(t, marginModesStore.UpdatePartyMarginMode(ctx, marginMode11))
   102  
   103  		expectedMarginModes := []entities.PartyMarginMode{marginMode11, marginMode12, marginMode21, marginMode22}
   104  		sortMarginModes(&expectedMarginModes)
   105  
   106  		foundMarginModes, _, err := marginModesStore.ListPartyMarginModes(ctx,
   107  			entities.DefaultCursorPagination(false),
   108  			sqlstore.ListPartyMarginModesFilters{},
   109  		)
   110  		require.NoError(t, err)
   111  		expectedStatsJson, _ := json.Marshal(expectedMarginModes)
   112  		statsJson, _ := json.Marshal(foundMarginModes)
   113  		assert.JSONEq(t, string(expectedStatsJson), string(statsJson))
   114  	})
   115  
   116  	t.Run("Inserting an update on an existing combination", func(t *testing.T) {
   117  		expectedMarginModes := []entities.PartyMarginMode{marginMode11, marginMode12, marginMode21, marginMode22}
   118  		sortMarginModes(&expectedMarginModes)
   119  
   120  		for _, mode := range expectedMarginModes {
   121  			require.NoError(t, marginModesStore.UpdatePartyMarginMode(ctx, mode))
   122  		}
   123  
   124  		foundMarginModes, _, err := marginModesStore.ListPartyMarginModes(ctx,
   125  			entities.DefaultCursorPagination(false),
   126  			sqlstore.ListPartyMarginModesFilters{},
   127  		)
   128  		require.NoError(t, err)
   129  		expectedStatsJson, _ := json.Marshal(expectedMarginModes)
   130  		statsJson, _ := json.Marshal(foundMarginModes)
   131  		assert.JSONEq(t, string(expectedStatsJson), string(statsJson))
   132  	})
   133  
   134  	t.Run("Listing a margin mode for party", func(t *testing.T) {
   135  		expectedMarginModes := []entities.PartyMarginMode{marginMode11, marginMode21}
   136  		sortMarginModes(&expectedMarginModes)
   137  
   138  		foundMarginModes, _, err := marginModesStore.ListPartyMarginModes(ctx,
   139  			entities.DefaultCursorPagination(false),
   140  			sqlstore.ListPartyMarginModesFilters{
   141  				PartyID: ptr.From(party1),
   142  			},
   143  		)
   144  		require.NoError(t, err)
   145  		expectedStatsJson, _ := json.Marshal(expectedMarginModes)
   146  		statsJson, _ := json.Marshal(foundMarginModes)
   147  		assert.JSONEq(t, string(expectedStatsJson), string(statsJson))
   148  	})
   149  
   150  	t.Run("Listing a margin mode for market", func(t *testing.T) {
   151  		expectedMarginModes := []entities.PartyMarginMode{marginMode11, marginMode12}
   152  		sortMarginModes(&expectedMarginModes)
   153  
   154  		foundMarginModes, _, err := marginModesStore.ListPartyMarginModes(ctx,
   155  			entities.DefaultCursorPagination(false),
   156  			sqlstore.ListPartyMarginModesFilters{
   157  				MarketID: ptr.From(market1),
   158  			},
   159  		)
   160  		require.NoError(t, err)
   161  		expectedStatsJson, _ := json.Marshal(expectedMarginModes)
   162  		statsJson, _ := json.Marshal(foundMarginModes)
   163  		assert.JSONEq(t, string(expectedStatsJson), string(statsJson))
   164  	})
   165  
   166  	t.Run("Listing a margin mode for market and party", func(t *testing.T) {
   167  		expectedMarginModes := []entities.PartyMarginMode{marginMode11}
   168  		sortMarginModes(&expectedMarginModes)
   169  
   170  		foundMarginModes, _, err := marginModesStore.ListPartyMarginModes(ctx,
   171  			entities.DefaultCursorPagination(false),
   172  			sqlstore.ListPartyMarginModesFilters{
   173  				PartyID:  ptr.From(party1),
   174  				MarketID: ptr.From(market1),
   175  			},
   176  		)
   177  		require.NoError(t, err)
   178  		expectedStatsJson, _ := json.Marshal(expectedMarginModes)
   179  		statsJson, _ := json.Marshal(foundMarginModes)
   180  		assert.JSONEq(t, string(expectedStatsJson), string(statsJson))
   181  	})
   182  }
   183  
   184  func sortMarginModes(modes *[]entities.PartyMarginMode) {
   185  	slices.SortStableFunc(*modes, func(a, b entities.PartyMarginMode) int {
   186  		if a.MarketID == b.MarketID {
   187  			return strings.Compare(a.PartyID.String(), b.PartyID.String())
   188  		}
   189  		return strings.Compare(a.MarketID.String(), b.MarketID.String())
   190  	})
   191  }