code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/risk_factor_test.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package sqlstore_test
    17  
    18  import (
    19  	"testing"
    20  	"time"
    21  
    22  	"code.vegaprotocol.io/vega/datanode/entities"
    23  	"code.vegaprotocol.io/vega/datanode/sqlstore"
    24  	"code.vegaprotocol.io/vega/protos/vega"
    25  
    26  	"github.com/shopspring/decimal"
    27  	"github.com/stretchr/testify/assert"
    28  	"github.com/stretchr/testify/require"
    29  )
    30  
    31  func TestRiskFactors(t *testing.T) {
    32  	t.Run("Make sure you can update risk factors for a market and get latest values", testUpdateMarketRiskFactors)
    33  	t.Run("Upsert should insert risk factor", testAddRiskFactor)
    34  	t.Run("Upsert should update the risk factor if the market already exists in the same block", testUpsertDuplicateMarketInSameBlock)
    35  	t.Run("GetMarketRiskFactors returns the risk factors for the given market id", testGetMarketRiskFactors)
    36  }
    37  
    38  func setupRiskFactorTests(t *testing.T) (*sqlstore.Blocks, *sqlstore.RiskFactors) {
    39  	t.Helper()
    40  	bs := sqlstore.NewBlocks(connectionSource)
    41  	rfStore := sqlstore.NewRiskFactors(connectionSource)
    42  	return bs, rfStore
    43  }
    44  
    45  func testUpdateMarketRiskFactors(t *testing.T) {
    46  	ctx := tempTransaction(t)
    47  
    48  	bs, rfStore := setupRiskFactorTests(t)
    49  
    50  	// Add a risk factor for market 'aa' in one block
    51  
    52  	source := &testBlockSource{bs, time.Now()}
    53  	block := source.getNextBlock(t, ctx)
    54  	marketID := entities.MarketID("aa")
    55  	rf := entities.RiskFactor{
    56  		MarketID: marketID,
    57  		Short:    decimal.NewFromInt(100),
    58  		Long:     decimal.NewFromInt(200),
    59  		TxHash:   generateTxHash(),
    60  		VegaTime: block.VegaTime,
    61  	}
    62  	rfStore.Upsert(ctx, &rf)
    63  
    64  	// Check we get the same data back as we put in
    65  	fetched, err := rfStore.GetMarketRiskFactors(ctx, string(marketID))
    66  	require.NoError(t, err)
    67  	assert.Equal(t, fetched, rf)
    68  
    69  	// Upsert a new risk factor for the same in a different block
    70  	block2 := source.getNextBlock(t, ctx)
    71  	rf2 := entities.RiskFactor{
    72  		MarketID: marketID,
    73  		Short:    decimal.NewFromInt(101),
    74  		Long:     decimal.NewFromInt(202),
    75  		TxHash:   generateTxHash(),
    76  		VegaTime: block2.VegaTime,
    77  	}
    78  	rfStore.Upsert(ctx, &rf2)
    79  
    80  	// Check we get back the updated version
    81  	fetched, err = rfStore.GetMarketRiskFactors(ctx, string(marketID))
    82  	require.NoError(t, err)
    83  	assert.Equal(t, fetched, rf2)
    84  }
    85  
    86  func testAddRiskFactor(t *testing.T) {
    87  	ctx := tempTransaction(t)
    88  
    89  	bs, rfStore := setupRiskFactorTests(t)
    90  
    91  	var rowCount int
    92  	err := connectionSource.QueryRow(ctx, `select count(*) from risk_factors`).Scan(&rowCount)
    93  	assert.NoError(t, err)
    94  
    95  	block := addTestBlock(t, ctx, bs)
    96  	riskFactorProto := getRiskFactorProto()
    97  	riskFactor, err := entities.RiskFactorFromProto(riskFactorProto, generateTxHash(), block.VegaTime)
    98  	require.NoError(t, err, "Converting risk factor proto to database entity")
    99  
   100  	err = rfStore.Upsert(ctx, riskFactor)
   101  	require.NoError(t, err)
   102  
   103  	err = connectionSource.QueryRow(ctx, `select count(*) from risk_factors`).Scan(&rowCount)
   104  	assert.NoError(t, err)
   105  	assert.Equal(t, 1, rowCount)
   106  }
   107  
   108  func testUpsertDuplicateMarketInSameBlock(t *testing.T) {
   109  	ctx := tempTransaction(t)
   110  
   111  	bs, rfStore := setupRiskFactorTests(t)
   112  
   113  	var rowCount int
   114  	err := connectionSource.QueryRow(ctx, `select count(*) from risk_factors`).Scan(&rowCount)
   115  	assert.NoError(t, err)
   116  
   117  	block := addTestBlock(t, ctx, bs)
   118  	riskFactorProto := getRiskFactorProto()
   119  	riskFactor, err := entities.RiskFactorFromProto(riskFactorProto, generateTxHash(), block.VegaTime)
   120  	require.NoError(t, err, "Converting risk factor proto to database entity")
   121  
   122  	err = rfStore.Upsert(ctx, riskFactor)
   123  	require.NoError(t, err)
   124  
   125  	err = connectionSource.QueryRow(ctx, `select count(*) from risk_factors`).Scan(&rowCount)
   126  	assert.NoError(t, err)
   127  	assert.Equal(t, 1, rowCount)
   128  
   129  	err = rfStore.Upsert(ctx, riskFactor)
   130  	require.NoError(t, err)
   131  
   132  	err = connectionSource.QueryRow(ctx, `select count(*) from risk_factors`).Scan(&rowCount)
   133  	assert.NoError(t, err)
   134  	assert.Equal(t, 1, rowCount)
   135  }
   136  
   137  func getRiskFactorProto() *vega.RiskFactor {
   138  	return &vega.RiskFactor{
   139  		Market: "deadbeef",
   140  		Short:  "1000",
   141  		Long:   "1000",
   142  	}
   143  }
   144  
   145  func testGetMarketRiskFactors(t *testing.T) {
   146  	ctx := tempTransaction(t)
   147  
   148  	bs, rfStore := setupRiskFactorTests(t)
   149  
   150  	var rowCount int
   151  	err := connectionSource.QueryRow(ctx, `select count(*) from risk_factors`).Scan(&rowCount)
   152  	assert.NoError(t, err)
   153  
   154  	block := addTestBlock(t, ctx, bs)
   155  	riskFactorProto := getRiskFactorProto()
   156  	riskFactor, err := entities.RiskFactorFromProto(riskFactorProto, generateTxHash(), block.VegaTime)
   157  	require.NoError(t, err, "Converting risk factor proto to database entity")
   158  
   159  	err = rfStore.Upsert(ctx, riskFactor)
   160  	require.NoError(t, err)
   161  
   162  	err = connectionSource.QueryRow(ctx, `select count(*) from risk_factors`).Scan(&rowCount)
   163  	assert.NoError(t, err)
   164  	assert.Equal(t, 1, rowCount)
   165  
   166  	got, err := rfStore.GetMarketRiskFactors(ctx, "DEADBEEF")
   167  	assert.NoError(t, err)
   168  
   169  	want := *riskFactor
   170  
   171  	assert.Equal(t, want, got)
   172  }