# Origin: github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/worker/data_sampler_test.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pytype: skip-file

import unittest

from apache_beam.coders import FastPrimitivesCoder
from apache_beam.coders import WindowedValueCoder
from apache_beam.coders.coders import Coder
from apache_beam.runners.worker.data_sampler import DataSampler
from apache_beam.runners.worker.data_sampler import OutputSampler
from apache_beam.transforms.window import GlobalWindow
from apache_beam.utils.windowed_value import WindowedValue


class DataSamplerTest(unittest.TestCase):
  """Tests for DataSampler: per-PCollection samplers and sample filtering."""
  def test_single_output(self):
    """Simple test for a single sample."""
    data_sampler = DataSampler()
    coder = FastPrimitivesCoder()

    output_sampler = data_sampler.sample_output('1', coder)
    output_sampler.sample('a')

    self.assertEqual(data_sampler.samples(), {'1': [coder.encode_nested('a')]})

  def test_multiple_outputs(self):
    """Tests that multiple PCollections have their own sampler."""
    data_sampler = DataSampler()
    coder = FastPrimitivesCoder()

    data_sampler.sample_output('1', coder).sample('a')
    data_sampler.sample_output('2', coder).sample('a')

    self.assertEqual(
        data_sampler.samples(), {
            '1': [coder.encode_nested('a')], '2': [coder.encode_nested('a')]
        })

  def gen_samples(self, data_sampler: DataSampler, coder: Coder) -> None:
    """Samples two elements into each of the PCollection ids 'a', 'b', 'c'.

    Args:
      data_sampler: the DataSampler that receives the samples.
      coder: the coder passed to every `sample_output` call.
    """
    # (pcollection id, element) pairs, in the order they are sampled.
    samples = (
        ('a', '1'),
        ('a', '2'),
        ('b', '3'),
        ('b', '4'),
        ('c', '5'),
        ('c', '6'),
    )
    for pcollection_id, element in samples:
      data_sampler.sample_output(pcollection_id, coder).sample(element)

  def test_sample_filters_single_pcollection_ids(self):
    """Tests the samples can be filtered based on a single pcollection id."""
    data_sampler = DataSampler()
    coder = FastPrimitivesCoder()

    self.gen_samples(data_sampler, coder)
    self.assertEqual(
        data_sampler.samples(pcollection_ids=['a']),
        {'a': [coder.encode_nested('1'), coder.encode_nested('2')]})

    self.assertEqual(
        data_sampler.samples(pcollection_ids=['b']),
        {'b': [coder.encode_nested('3'), coder.encode_nested('4')]})

  def test_sample_filters_multiple_pcollection_ids(self):
    """Tests the samples can be filtered based on multiple pcollection ids."""
    data_sampler = DataSampler()
    coder = FastPrimitivesCoder()

    self.gen_samples(data_sampler, coder)
    self.assertEqual(
        data_sampler.samples(pcollection_ids=['a', 'c']),
        {
            'a': [coder.encode_nested('1'), coder.encode_nested('2')],
            'c': [coder.encode_nested('5'), coder.encode_nested('6')]
        })


class FakeClock:
  """Manually controlled clock; `time()` returns whatever `clock` is set to."""
  def __init__(self):
    # Current fake time; tests overwrite this attribute directly.
    self.clock = 0

  def time(self):
    """Returns the current value of the fake clock."""
    return self.clock


class OutputSamplerTest(unittest.TestCase):
  """Tests for OutputSampler buffering and time-based sampling."""
  def setUp(self):
    self.fake_clock = FakeClock()

  def control_time(self, new_time):
    """Sets the fake clock to `new_time`."""
    self.fake_clock.clock = new_time

  def test_samples_first_n(self):
    """Tests that the first elements are always sampled."""
    coder = FastPrimitivesCoder()
    sampler = OutputSampler(coder)

    for i in range(15):
      sampler.sample(i)

    # Only the first ten of the fifteen elements are expected back.
    self.assertEqual(
        sampler.flush(), [coder.encode_nested(i) for i in range(10)])

  def test_acts_like_circular_buffer(self):
    """Tests that the buffer overwrites old samples."""
    coder = FastPrimitivesCoder()
    sampler = OutputSampler(coder, max_samples=2)

    for i in range(10):
      sampler.sample(i)

    self.assertEqual(sampler.flush(), [coder.encode_nested(i) for i in (8, 9)])

  def test_samples_every_n_secs(self):
    """Tests that, after the first elements, one sample is taken per period.

    The sampler is driven by the fake clock with `sample_every_sec=10`, so a
    new sample is only expected once the clock reaches the next multiple of
    ten seconds.
    """
    coder = FastPrimitivesCoder()
    sampler = OutputSampler(
        coder, max_samples=1, sample_every_sec=10, clock=self.fake_clock)

    # Always samples the first ten.
    for i in range(10):
      sampler.sample(i)
    self.assertEqual(sampler.flush(), [coder.encode_nested(9)])

    # Start at t=0
    sampler.sample(10)
    self.assertEqual(len(sampler.flush()), 0)

    # Still not over threshold yet.
    self.control_time(9)
    for i in range(100):
      sampler.sample(i)
    self.assertEqual(len(sampler.flush()), 0)

    # First sample after 10s.
    self.control_time(10)
    sampler.sample(10)
    self.assertEqual(sampler.flush(), [coder.encode_nested(10)])

    # No samples between thresholds.
    self.control_time(15)
    for i in range(100):
      sampler.sample(i)
    self.assertEqual(len(sampler.flush()), 0)

    # Second sample after 20s.
    self.control_time(20)
    sampler.sample(11)
    self.assertEqual(sampler.flush(), [coder.encode_nested(11)])

  def test_can_sample_windowed_value(self):
    """Tests that values with WindowedValueCoders are sampled wholesale."""
    data_sampler = DataSampler()
    coder = WindowedValueCoder(FastPrimitivesCoder())
    value = WindowedValue('Hello, World!', 0, [GlobalWindow()])
    data_sampler.sample_output('1', coder).sample(value)

    self.assertEqual(
        data_sampler.samples(), {'1': [coder.encode_nested(value)]})

  def test_can_sample_non_windowed_value(self):
    """Tests that windowed values without WindowedValueCoders sample only the
    value.

    This is important because the Python SDK wraps all values in a WindowedValue
    even if the coder is not a WindowedValueCoder. In this case, the value must
    be retrieved from the WindowedValue to match the correct coder.
    """
    data_sampler = DataSampler()
    coder = FastPrimitivesCoder()
    data_sampler.sample_output('1', coder).sample(
        WindowedValue('Hello, World!', 0, [GlobalWindow()]))

    self.assertEqual(
        data_sampler.samples(), {'1': [coder.encode_nested('Hello, World!')]})


if __name__ == '__main__':
  unittest.main()