# Source: github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/pipeline_context_test.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for the pipeline_context module."""

# pytype: skip-file

import unittest

from apache_beam import coders
from apache_beam.runners import pipeline_context
from apache_beam.transforms import environments


class PipelineContextTest(unittest.TestCase):
  """Checks id assignment, deduplication and serialization of a
  PipelineContext's coder and environment registries."""
  def test_deduplication(self):
    # Registering two equal coders in one context must yield the same id.
    context = pipeline_context.PipelineContext()
    bytes_coder_ref = context.coders.get_id(coders.BytesCoder())
    bytes_coder_ref2 = context.coders.get_id(coders.BytesCoder())
    self.assertEqual(bytes_coder_ref, bytes_coder_ref2)

  def test_deduplication_by_proto(self):
    # Looking up the same environment proto twice, the second time with
    # deduplicate=True, must return the id assigned by the first lookup.
    context = pipeline_context.PipelineContext()
    env_proto = environments.SubprocessSDKEnvironment(
        command_string="foo").to_runner_api(None)
    env_ref_1 = context.environments.get_by_proto(env_proto)
    env_ref_2 = context.environments.get_by_proto(env_proto, deduplicate=True)
    self.assertEqual(env_ref_1, env_ref_2)

  def test_equal_environments_are_deduplicated_when_fetched_by_obj_or_proto(
      self):
    # An environment object and its proto form must map to one shared id,
    # whichever of the two is registered first.
    context = pipeline_context.PipelineContext()

    # Case 1: proto registered first, then the equivalent object.
    env = environments.SubprocessSDKEnvironment(command_string="foo")
    env_proto = env.to_runner_api(None)
    id_from_proto = context.environments.get_by_proto(env_proto)
    id_from_obj = context.environments.get_id(env)
    self.assertEqual(id_from_obj, id_from_proto)
    self.assertEqual(
        context.environments.get_by_id(id_from_obj).command_string, "foo")

    # Case 2: object registered first, then the equivalent proto looked up
    # with deduplicate=True.
    env = environments.SubprocessSDKEnvironment(command_string="bar")
    env_proto = env.to_runner_api(None)
    id_from_obj = context.environments.get_id(env)
    id_from_proto = context.environments.get_by_proto(
        env_proto, deduplicate=True)
    self.assertEqual(id_from_obj, id_from_proto)
    self.assertEqual(
        context.environments.get_by_id(id_from_obj).command_string, "bar")

  def test_serialization(self):
    # Ids handed out by a context must resolve to equal coders in a context
    # rebuilt from that context's runner API proto.
    context = pipeline_context.PipelineContext()
    float_coder_ref = context.coders.get_id(coders.FloatCoder())
    bytes_coder_ref = context.coders.get_id(coders.BytesCoder())
    proto = context.to_runner_api()
    context2 = pipeline_context.PipelineContext.from_runner_api(proto)
    self.assertEqual(
        coders.FloatCoder(), context2.coders.get_by_id(float_coder_ref))
    self.assertEqual(
        coders.BytesCoder(), context2.coders.get_by_id(bytes_coder_ref))

  def test_common_id_assignment(self):
    # Two contexts built on the same component_id_map must give equal coders
    # equal ids.
    context = pipeline_context.PipelineContext()
    float_coder_ref = context.coders.get_id(coders.FloatCoder())
    bytes_coder_ref = context.coders.get_id(coders.BytesCoder())
    context2 = pipeline_context.PipelineContext(
        component_id_map=context.component_id_map)

    # Deliberately requested in the opposite order from the first context.
    bytes_coder_ref2 = context2.coders.get_id(coders.BytesCoder())
    float_coder_ref2 = context2.coders.get_id(coders.FloatCoder())

    self.assertEqual(bytes_coder_ref, bytes_coder_ref2)
    self.assertEqual(float_coder_ref, float_coder_ref2)


if __name__ == '__main__':
  unittest.main()