github.com/grafana/pyroscope@v1.18.0/pkg/segmentwriter/client/distributor/placement/adaptiveplacement/store_test.go (about)

     1  package adaptiveplacement
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"io"
     8  	"testing"
     9  
    10  	"github.com/stretchr/testify/mock"
    11  	"github.com/stretchr/testify/suite"
    12  
    13  	"github.com/grafana/pyroscope/pkg/segmentwriter/client/distributor/placement/adaptiveplacement/adaptive_placementpb"
    14  	"github.com/grafana/pyroscope/pkg/test/mocks/mockobjstore"
    15  )
    16  
    17  type storeSuite struct {
    18  	suite.Suite
    19  
    20  	bucket *mockobjstore.MockBucket
    21  	store  *BucketStore
    22  }
    23  
    24  func (s *storeSuite) SetupTest() {
    25  	s.bucket = new(mockobjstore.MockBucket)
    26  	s.store = NewStore(s.bucket)
    27  }
    28  
    29  func Test_StoreSuite(t *testing.T) { suite.Run(t, new(storeSuite)) }
    30  
    31  func (s *storeSuite) Test_LoadRules() {
    32  	rules := &adaptive_placementpb.PlacementRules{CreatedAt: 1}
    33  	s.bucket.On("Get", mock.Anything, rulesFilePath).
    34  		Return(s.marshal(rules), nil)
    35  	loaded, err := s.store.LoadRules(context.Background())
    36  	s.NoError(err)
    37  	s.Equal(rules, loaded)
    38  	s.bucket.AssertExpectations(s.T())
    39  }
    40  
    41  func (s *storeSuite) Test_LoadRules_not_found() {
    42  	notFound := fmt.Errorf("not found")
    43  	s.bucket.On("Get", mock.Anything, rulesFilePath).
    44  		Return(nil, notFound)
    45  	s.bucket.On("IsObjNotFoundErr", notFound).Return(true)
    46  	_, err := s.store.LoadRules(context.Background())
    47  	s.ErrorIs(err, ErrRulesNotFound)
    48  	s.bucket.AssertExpectations(s.T())
    49  }
    50  
    51  func (s *storeSuite) Test_StoreRules() {
    52  	rules := &adaptive_placementpb.PlacementRules{CreatedAt: 1}
    53  	s.bucket.On("Upload", mock.Anything, rulesFilePath, mock.Anything).
    54  		Run(func(args mock.Arguments) {
    55  			var stored adaptive_placementpb.PlacementRules
    56  			s.unmarshal(args[2].(io.Reader), &stored)
    57  			s.Equal(rules, &stored)
    58  		}).
    59  		Return(nil).
    60  		Once()
    61  	s.NoError(s.store.StoreRules(context.Background(), rules))
    62  	s.bucket.AssertExpectations(s.T())
    63  }
    64  
    65  func (s *storeSuite) Test_LoadStats() {
    66  	stats := &adaptive_placementpb.DistributionStats{CreatedAt: 1}
    67  	s.bucket.On("Get", mock.Anything, statsFilePath).
    68  		Return(s.marshal(stats), nil)
    69  	loaded, err := s.store.LoadStats(context.Background())
    70  	s.NoError(err)
    71  	s.Equal(stats, loaded)
    72  	s.bucket.AssertExpectations(s.T())
    73  }
    74  
    75  func (s *storeSuite) Test_LoadStats_not_found() {
    76  	notFound := fmt.Errorf("not found")
    77  	s.bucket.On("Get", mock.Anything, statsFilePath).
    78  		Return(nil, notFound)
    79  	s.bucket.On("IsObjNotFoundErr", notFound).Return(true)
    80  	_, err := s.store.LoadStats(context.Background())
    81  	s.ErrorIs(err, ErrStatsNotFound)
    82  	s.bucket.AssertExpectations(s.T())
    83  }
    84  
    85  func (s *storeSuite) Test_StoreStats() {
    86  	stats := &adaptive_placementpb.DistributionStats{CreatedAt: 1}
    87  	s.bucket.On("Upload", mock.Anything, statsFilePath, mock.Anything).
    88  		Run(func(args mock.Arguments) {
    89  			var stored adaptive_placementpb.DistributionStats
    90  			s.unmarshal(args[2].(io.Reader), &stored)
    91  			s.Equal(stats, &stored)
    92  		}).
    93  		Return(nil).
    94  		Once()
    95  	s.NoError(s.store.StoreStats(context.Background(), stats))
    96  	s.bucket.AssertExpectations(s.T())
    97  }
    98  
    99  func (s *storeSuite) marshal(m vtProtoMessage) io.ReadCloser {
   100  	b, err := m.MarshalVT()
   101  	s.Require().NoError(err)
   102  	return io.NopCloser(bytes.NewBuffer(b))
   103  }
   104  
   105  func (s *storeSuite) unmarshal(r io.Reader, m vtProtoMessage) {
   106  	b, err := io.ReadAll(r)
   107  	s.Require().NoError(err)
   108  	s.Require().NoError(m.UnmarshalVT(b))
   109  }