#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Tests for sql_chain module."""

# pytype: skip-file

import unittest
from unittest.mock import patch

import pytest

import apache_beam as beam
from apache_beam.runners.interactive import interactive_environment as ie
from apache_beam.runners.interactive.sql.sql_chain import SqlChain
from apache_beam.runners.interactive.sql.sql_chain import SqlNode
from apache_beam.runners.interactive.testing.mock_ipython import mock_get_ipython

# Attribute patched by the to_pipeline tests below. Replacing ``__rrshift__``
# with a mock means ``'label' >> SchemaLoadedSqlTransform(...)`` is recorded
# by the mock instead of running the real transform application, so the tests
# can count applications and inspect the operand (the label) that was passed.
_SQL_TRANSFORM_RRSHIFT = (
    'apache_beam.runners.interactive.sql.sql_chain.SchemaLoadedSqlTransform.'
    '__rrshift__')


class SqlChainTest(unittest.TestCase):
  """Unit tests for building and evaluating a SqlChain of SqlNodes."""
  def test_init(self):
    """A new chain has no nodes, no root/current node and no pipeline."""
    empty_chain = SqlChain()
    self.assertEqual({}, empty_chain.nodes)
    for attr_name in ('root', 'current', 'user_pipeline'):
      self.assertIsNone(getattr(empty_chain, attr_name))

  def test_append_first_node(self):
    """The first appended node is the root, the current node and findable."""
    first = SqlNode(output_name='first', source='a', query='q1')
    chain = SqlChain().append(first)
    self.assertIs(first, chain.get(first.output_name))
    self.assertIs(first, chain.root)
    self.assertIs(first, chain.current)

  def test_append_non_root_node(self):
    """A node appended after the root is linked as ``root.next``."""
    chain = SqlChain()
    root = SqlNode(output_name='root', source='root', query='q1')
    chain = chain.append(root)
    self.assertIsNone(chain.root.next)

    successor = SqlNode(output_name='next_node', source='root', query='q2')
    chain.append(successor)
    self.assertIs(successor, chain.root.next)
    self.assertIs(successor, chain.get(successor.output_name))

  @patch(_SQL_TRANSFORM_RRSHIFT)
  def test_to_pipeline_only_evaluate_once_per_pipeline_and_node(
      self, mocked_sql_transform):
    """Calling to_pipeline() twice applies the node's SQL transform once."""
    pipeline = beam.Pipeline()
    ie.current_env().watch({'p': pipeline})
    first_pcoll = pipeline | 'create pcoll_1' >> beam.Create([1, 2, 3])
    second_pcoll = pipeline | 'create pcoll_2' >> beam.Create([4, 5, 6])
    ie.current_env().watch({'pcoll_1': first_pcoll, 'pcoll_2': second_pcoll})

    root = SqlNode(
        output_name='root', source={'pcoll_1', 'pcoll_2'}, query='q1')
    chain = SqlChain(user_pipeline=pipeline).append(root)

    # The second pass must not trigger another transform application.
    for _attempt in range(2):
      _ = chain.to_pipeline()
      mocked_sql_transform.assert_called_once()

  @unittest.skipIf(
      not ie.current_env().is_interactive_ready,
      '[interactive] dependency is not installed.')
  @pytest.mark.skipif(
      not ie.current_env().is_interactive_ready,
      reason='[interactive] dependency is not installed.')
  @patch(_SQL_TRANSFORM_RRSHIFT)
  def test_nodes_with_same_outputs(self, mocked_sql_transform):
    """Nodes reusing one output name in separate cells get distinct labels."""
    pipeline = beam.Pipeline()
    ie.current_env().watch({'p_nodes_with_same_output': pipeline})
    source_pcoll = pipeline | 'create pcoll' >> beam.Create([1, 2, 3])
    ie.current_env().watch({'pcoll': source_pcoll})
    chain = SqlChain(user_pipeline=pipeline)
    shared_output = 'output'

    # Each round runs inside its own mocked IPython cell; the expected label
    # suffix differs per round (1, then 2).
    rounds = (
        ('q1', 'schema_loaded_beam_sql_output_1'),
        ('q2', 'schema_loaded_beam_sql_output_2'),
    )
    with patch('IPython.get_ipython', new_callable=mock_get_ipython) as cell:
      for query, expected_label in rounds:
        with cell:
          # The node must be created inside the cell context.
          node = SqlNode(shared_output, source='pcoll', query=query)
          chain.append(node)
          _ = chain.to_pipeline()
        mocked_sql_transform.assert_called_with(expected_label)


if __name__ == '__main__':
  unittest.main()