#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Integration test for the custom coders cookbook example."""

# pytype: skip-file

import logging
import unittest
import uuid

import pytest

from apache_beam.examples.cookbook import group_with_coder
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.test_utils import read_files_from_pattern

# The gcsio library is optional: fall back to None when it cannot be imported
# so that this module still loads and the test class can be skipped.
try:
  from apache_beam.io.gcp import gcsio
except ImportError:
  gcsio = None

# Replace group_with_coder.PlayerCoder.decode() with a version that leaves the
# prepended 'x:' string in place when decoding a Player object. Seeing that
# prefix in the output is how the test verifies that PlayerCoder was used.
group_with_coder.PlayerCoder.decode = (  # type: ignore[assignment]
    lambda self, s: group_with_coder.Player(s.decode('utf-8')))


def create_content_input_file(path, records):
  """Writes each record as one UTF-8 encoded line to ``path`` through GcsIO.

  Args:
    path: Destination handed to ``GcsIO.open`` in write mode.
    records: Iterable of strings; every record becomes one line of the file.

  Returns:
    The ``path`` argument, unchanged.
  """
  logging.info('Creating file: %s', path)
  client = gcsio.GcsIO()
  with client.open(path, 'w') as sink:
    for entry in records:
      sink.write(entry.encode('utf-8') + b'\n')
  return path


@unittest.skipIf(gcsio is None, 'GCP dependencies are not installed')
@pytest.mark.examples_postcommit
class GroupWithCoderTest(unittest.TestCase):
  """Runs the group_with_coder example with and without pipeline type checks."""

  # Input lines in '<name>,<points>' form.
  SAMPLE_RECORDS = [
      'joe,10',
      'fred,3',
      'mary,7',
      'joe,20',
      'fred,6',
      'ann,5',
      'joe,30',
      'ann,10',
      'mary,1',
  ]

  def setUp(self):
    """Creates the test pipeline and uploads the sample input file."""
    self.test_pipeline = TestPipeline(is_integration_test=True)
    self.temp_location = self.test_pipeline.get_option('temp_location')
    # A fresh UUID directory keeps the input of each test run separate.
    input_path = '/'.join(
        [self.temp_location, str(uuid.uuid4()), 'input.txt'])
    self.input_file = create_content_input_file(
        input_path, self.SAMPLE_RECORDS)

  # TODO(https://github.com/apache/beam/issues/23608) Fix and enable
  @pytest.mark.sickbay_dataflow
  def test_basics_with_type_check(self):
    # With the pipeline_type_check option active, the typehints attached to
    # the transforms have non-default values, so custom coders are picked up.
    # Here that means the coder registered for the Player class must be used.
    output = '/'.join([self.temp_location, str(uuid.uuid4()), 'result'])
    extra_opts = {'input': self.input_file, 'output': output}
    pipeline_args = self.test_pipeline.get_full_options_as_args(**extra_opts)
    group_with_coder.run(pipeline_args, save_main_session=False)

    # Read back every output shard and turn each line into (name, points).
    written_lines = read_files_from_pattern(output + '*').splitlines()
    results = [(name, int(points))
               for name, points in (row.split(',') for row in written_lines)]
    logging.info('result: %s', results)

    # The 'x:' prefix survives only because of the patched decode().
    expected = [('x:ann', 15), ('x:fred', 9), ('x:joe', 60), ('x:mary', 8)]
    self.assertEqual(sorted(results), sorted(expected))

  def test_basics_without_type_check(self):
    # Without the pipeline_type_check option, the typehints attached to the
    # transforms keep their default values, so custom coders are ignored and
    # the default coder (pickler) is used instead.
    output = '/'.join([self.temp_location, str(uuid.uuid4()), 'result'])
    extra_opts = {'input': self.input_file, 'output': output}
    with self.assertRaises(Exception) as context:
      # yapf: disable
      group_with_coder.run(
          self.test_pipeline.get_full_options_as_args(**extra_opts) +
          ['--no_pipeline_type_check'],
          save_main_session=False)
    # Both fragments are expected in the message of the raised exception.
    message = str(context.exception)
    self.assertIn('Unable to deterministically encode', message)
    self.assertIn('CombinePerKey(sum)/GroupByKey', message)


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()