github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/worker/data_sampler_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  # pytype: skip-file
    19  
    20  import unittest
    21  
    22  from apache_beam.coders import FastPrimitivesCoder
    23  from apache_beam.coders import WindowedValueCoder
    24  from apache_beam.coders.coders import Coder
    25  from apache_beam.runners.worker.data_sampler import DataSampler
    26  from apache_beam.runners.worker.data_sampler import OutputSampler
    27  from apache_beam.transforms.window import GlobalWindow
    28  from apache_beam.utils.windowed_value import WindowedValue
    29  
    30  
    class DataSamplerTest(unittest.TestCase):
      """Tests for DataSampler's per-PCollection samples and id filtering."""
      def test_single_output(self):
        """Tests that one sampled element is returned under its PCollection id."""
        coder = FastPrimitivesCoder()
        data_sampler = DataSampler()

        data_sampler.sample_output('1', coder).sample('a')

        actual = data_sampler.samples()
        self.assertEqual(actual, {'1': [coder.encode_nested('a')]})

      def test_multiple_outputs(self):
        """Tests that each PCollection id keeps its own list of samples."""
        data_sampler = DataSampler()
        coder = FastPrimitivesCoder()

        # Sample the same element into two different PCollections.
        for pcollection_id in ('1', '2'):
          data_sampler.sample_output(pcollection_id, coder).sample('a')

        actual = data_sampler.samples()
        self.assertEqual(
            actual,
            {
                pcollection_id: [coder.encode_nested('a')]
                for pcollection_id in ('1', '2')
            })

      def gen_samples(self, data_sampler: DataSampler, coder: Coder):
        """Samples two elements into each of the PCollections 'a', 'b' and 'c'.

        Args:
          data_sampler: the DataSampler that receives the samples.
          coder: the coder handed to `sample_output` for every PCollection.
        """
        elements_by_id = (
            ('a', ('1', '2')),
            ('b', ('3', '4')),
            ('c', ('5', '6')),
        )
        for pcollection_id, elements in elements_by_id:
          for element in elements:
            data_sampler.sample_output(pcollection_id, coder).sample(element)

      def test_sample_filters_single_pcollection_ids(self):
        """Tests that samples can be filtered down to a single PCollection id."""
        data_sampler = DataSampler()
        coder = FastPrimitivesCoder()
        self.gen_samples(data_sampler, coder)

        # Query 'a' first and then 'b', one id per call.
        for pcollection_id, elements in (('a', ('1', '2')), ('b', ('3', '4'))):
          actual = data_sampler.samples(pcollection_ids=[pcollection_id])
          self.assertEqual(
              actual,
              {pcollection_id: [coder.encode_nested(e) for e in elements]})

      def test_sample_filters_multiple_pcollection_ids(self):
        """Tests that samples can be filtered down to several PCollection ids."""
        data_sampler = DataSampler()
        coder = FastPrimitivesCoder()
        self.gen_samples(data_sampler, coder)

        actual = data_sampler.samples(pcollection_ids=['a', 'c'])

        # 'b' was sampled as well but must not appear in the filtered result.
        expected = {
            pcollection_id: [coder.encode_nested(e) for e in elements]
            for pcollection_id, elements in (('a', ('1', '2')), ('c', ('5', '6')))
        }
        self.assertEqual(actual, expected)
    89  
    90  
    class FakeClock:
      """Manually driven clock whose `time()` reports the `clock` attribute."""
      def __init__(self):
        # Time starts at zero and only changes when a caller assigns `clock`.
        self.clock = 0

      def time(self):
        """Returns whatever value is currently stored in `self.clock`."""
        current = self.clock
        return current
    97  
    98  
    class OutputSamplerTest(unittest.TestCase):
      """Tests for OutputSampler buffering, time throttling and windowed values."""
      def setUp(self):
        """Creates a fresh fake clock for every test."""
        self.fake_clock = FakeClock()

      def control_time(self, new_time):
        """Sets the time reported by the fake clock to `new_time`."""
        self.fake_clock.clock = new_time

      def test_samples_first_n(self):
        """Tests that the first elements are always sampled."""
        coder = FastPrimitivesCoder()
        sampler = OutputSampler(coder)

        for element in range(15):
          sampler.sample(element)

        # Fifteen elements were offered; only the first ten are expected back.
        flushed = sampler.flush()
        self.assertEqual(flushed, [coder.encode_nested(i) for i in range(10)])

      def test_acts_like_circular_buffer(self):
        """Tests that the buffer overwrites old samples."""
        coder = FastPrimitivesCoder()
        sampler = OutputSampler(coder, max_samples=2)

        for element in range(10):
          sampler.sample(element)

        # With room for two samples, only the two newest elements remain.
        flushed = sampler.flush()
        self.assertEqual(flushed, [coder.encode_nested(i) for i in (8, 9)])

      def test_samples_every_n_secs(self):
        """Tests that, after the first elements, sampling is throttled by time.

        The fake clock is injected so the test decides exactly when the
        `sample_every_sec` threshold is crossed.
        """
        coder = FastPrimitivesCoder()
        sampler = OutputSampler(
            coder, max_samples=1, sample_every_sec=10, clock=self.fake_clock)

        def sample_range(count):
          # Offers the elements 0..count-1 to the sampler, in order.
          for element in range(count):
            sampler.sample(element)

        # Always samples the first ten; max_samples=1 keeps only the last one.
        sample_range(10)
        self.assertEqual(sampler.flush(), [coder.encode_nested(9)])

        # Start at t=0.
        sampler.sample(10)
        flushed = sampler.flush()
        self.assertEqual(len(flushed), 0)

        # Still not over the threshold yet.
        self.control_time(9)
        sample_range(100)
        flushed = sampler.flush()
        self.assertEqual(len(flushed), 0)

        # First sample after 10s.
        self.control_time(10)
        sampler.sample(10)
        self.assertEqual(sampler.flush(), [coder.encode_nested(10)])

        # No samples between thresholds.
        self.control_time(15)
        sample_range(100)
        flushed = sampler.flush()
        self.assertEqual(len(flushed), 0)

        # Second sample after 20s.
        self.control_time(20)
        sampler.sample(11)
        self.assertEqual(sampler.flush(), [coder.encode_nested(11)])

      def test_can_sample_windowed_value(self):
        """Tests that values with WindowedValueCoders are sampled wholesale."""
        coder = WindowedValueCoder(FastPrimitivesCoder())
        data_sampler = DataSampler()
        windowed = WindowedValue('Hello, World!', 0, [GlobalWindow()])

        output_sampler = data_sampler.sample_output('1', coder)
        output_sampler.sample(windowed)

        # The whole WindowedValue is expected to be encoded, not just its value.
        actual = data_sampler.samples()
        self.assertEqual(actual, {'1': [coder.encode_nested(windowed)]})

      def test_can_sample_non_windowed_value(self):
        """Tests that WindowedValues sampled with a non-windowed coder encode only
        the wrapped value.

        This is important because the Python SDK wraps all values in a
        WindowedValue even if the coder is not a WindowedValueCoder. In this case,
        the value must be retrieved from the WindowedValue to match the correct
        coder.
        """
        coder = FastPrimitivesCoder()
        data_sampler = DataSampler()
        windowed = WindowedValue('Hello, World!', 0, [GlobalWindow()])

        data_sampler.sample_output('1', coder).sample(windowed)

        actual = data_sampler.samples()
        self.assertEqual(actual, {'1': [coder.encode_nested('Hello, World!')]})
   189  
   190  
   if __name__ == '__main__':
     # Run every test case in this module when it is executed as a script.
     unittest.main()