github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/interactive/sql/sql_chain_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Tests for sql_chain module."""
    19  
    20  # pytype: skip-file
    21  
    22  import unittest
    23  from unittest.mock import patch
    24  
    25  import pytest
    26  
    27  import apache_beam as beam
    28  from apache_beam.runners.interactive import interactive_environment as ie
    29  from apache_beam.runners.interactive.sql.sql_chain import SqlChain
    30  from apache_beam.runners.interactive.sql.sql_chain import SqlNode
    31  from apache_beam.runners.interactive.testing.mock_ipython import mock_get_ipython
    32  
    33  
    34  class SqlChainTest(unittest.TestCase):
    35    def test_init(self):
    36      chain = SqlChain()
    37      self.assertEqual({}, chain.nodes)
    38      self.assertIsNone(chain.root)
    39      self.assertIsNone(chain.current)
    40      self.assertIsNone(chain.user_pipeline)
    41  
    42    def test_append_first_node(self):
    43      node = SqlNode(output_name='first', source='a', query='q1')
    44      chain = SqlChain().append(node)
    45      self.assertIs(node, chain.get(node.output_name))
    46      self.assertIs(node, chain.root)
    47      self.assertIs(node, chain.current)
    48  
    49    def test_append_non_root_node(self):
    50      chain = SqlChain().append(
    51          SqlNode(output_name='root', source='root', query='q1'))
    52      self.assertIsNone(chain.root.next)
    53      node = SqlNode(output_name='next_node', source='root', query='q2')
    54      chain.append(node)
    55      self.assertIs(node, chain.root.next)
    56      self.assertIs(node, chain.get(node.output_name))
    57  
    58    @patch(
    59        'apache_beam.runners.interactive.sql.sql_chain.SchemaLoadedSqlTransform.'
    60        '__rrshift__')
    61    def test_to_pipeline_only_evaluate_once_per_pipeline_and_node(
    62        self, mocked_sql_transform):
    63      p = beam.Pipeline()
    64      ie.current_env().watch({'p': p})
    65      pcoll_1 = p | 'create pcoll_1' >> beam.Create([1, 2, 3])
    66      pcoll_2 = p | 'create pcoll_2' >> beam.Create([4, 5, 6])
    67      ie.current_env().watch({'pcoll_1': pcoll_1, 'pcoll_2': pcoll_2})
    68      node = SqlNode(
    69          output_name='root', source={'pcoll_1', 'pcoll_2'}, query='q1')
    70      chain = SqlChain(user_pipeline=p).append(node)
    71      _ = chain.to_pipeline()
    72      mocked_sql_transform.assert_called_once()
    73      _ = chain.to_pipeline()
    74      mocked_sql_transform.assert_called_once()
    75  
    76    @unittest.skipIf(
    77        not ie.current_env().is_interactive_ready,
    78        '[interactive] dependency is not installed.')
    79    @pytest.mark.skipif(
    80        not ie.current_env().is_interactive_ready,
    81        reason='[interactive] dependency is not installed.')
    82    @patch(
    83        'apache_beam.runners.interactive.sql.sql_chain.SchemaLoadedSqlTransform.'
    84        '__rrshift__')
    85    def test_nodes_with_same_outputs(self, mocked_sql_transform):
    86      p = beam.Pipeline()
    87      ie.current_env().watch({'p_nodes_with_same_output': p})
    88      pcoll = p | 'create pcoll' >> beam.Create([1, 2, 3])
    89      ie.current_env().watch({'pcoll': pcoll})
    90      chain = SqlChain(user_pipeline=p)
    91      output_name = 'output'
    92  
    93      with patch('IPython.get_ipython', new_callable=mock_get_ipython) as cell:
    94        with cell:
    95          node_cell_1 = SqlNode(output_name, source='pcoll', query='q1')
    96          chain.append(node_cell_1)
    97          _ = chain.to_pipeline()
    98          mocked_sql_transform.assert_called_with(
    99              'schema_loaded_beam_sql_output_1')
   100        with cell:
   101          node_cell_2 = SqlNode(output_name, source='pcoll', query='q2')
   102          chain.append(node_cell_2)
   103          _ = chain.to_pipeline()
   104          mocked_sql_transform.assert_called_with(
   105              'schema_loaded_beam_sql_output_2')
   106  
   107  
if __name__ == '__main__':
  # Allow running this test module directly (python sql_chain_test.py).
  unittest.main()