github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/interactive/sql/sql_chain.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Module for tracking a chain of beam_sql magics applied.

For internal use only; no backwards-compatibility guarantees.
"""

# pytype: skip-file

import importlib
import logging
from dataclasses import dataclass
from typing import Any
from typing import Dict
from typing import Optional
from typing import Set
from typing import Union

import apache_beam as beam
from apache_beam.internal import pickler
from apache_beam.runners.interactive.sql.utils import register_coder_for_schema
from apache_beam.runners.interactive.utils import create_var_in_main
from apache_beam.runners.interactive.utils import pcoll_by_name
from apache_beam.runners.interactive.utils import progress_indicated
from apache_beam.transforms.sql import SqlTransform
from apache_beam.utils.interactive_utils import is_in_ipython

_LOGGER = logging.getLogger(__name__)


@dataclass
class SqlNode:
  """Each SqlNode represents a beam_sql magic applied.

  Attributes:
    output_name: the watched unique name of the beam_sql output. Can be used
      as an identifier.
    source: the inputs consumed by this node. Can be a pipeline or a set of
      PCollections represented by their watched variable names. When it's a
      pipeline, the node computes from raw values in the query, so the output
      can be consumed by any SqlNode in any SqlChain.
    query: the SQL query applied by this node.
    schemas: the schemas (NamedTuple classes) used by this node.
    evaluated: the pipelines this node has been evaluated for.
    next: the next SqlNode applied chronologically.
    execution_count: the execution count if in an IPython env.
  """
  output_name: str
  source: Union[beam.Pipeline, Set[str]]
  query: str
  schemas: Set[Any] = None
  evaluated: Set[beam.Pipeline] = None
  next: Optional['SqlNode'] = None
  execution_count: int = 0

  def __post_init__(self):
    if not self.schemas:
      self.schemas = set()
    if not self.evaluated:
      self.evaluated = set()
    if is_in_ipython():
      from IPython import get_ipython
      self.execution_count = get_ipython().execution_count

  def __hash__(self):
    # A set is not hashable, so hash a frozenset of the source names when the
    # source is not a pipeline.
    source = (
        self.source
        if isinstance(self.source, beam.Pipeline) else frozenset(self.source))
    return hash((self.output_name, source, self.query, self.execution_count))

  def to_pipeline(self, pipeline: Optional[beam.Pipeline]) -> beam.Pipeline:
    """Converts the chain into an executable pipeline."""
    if pipeline not in self.evaluated:
      # The whole chain should form a single pipeline.
      source = self.source
      if isinstance(self.source, beam.Pipeline):
        if pipeline:  # Use the known pipeline.
          source = pipeline
        else:  # Use the source pipeline.
          pipeline = self.source
      else:
        name_to_pcoll = pcoll_by_name()
        if len(self.source) == 1:
          source = name_to_pcoll.get(next(iter(self.source)))
        else:
          source = {s: name_to_pcoll.get(s) for s in self.source}
      if isinstance(source, beam.Pipeline):
        output = source | 'beam_sql_{}_{}'.format(
            self.output_name,
            self.execution_count) >> SqlTransform(self.query)
      else:
        output = source | 'schema_loaded_beam_sql_{}_{}'.format(
            self.output_name,
            self.execution_count) >> SchemaLoadedSqlTransform(
                self.output_name, self.query, self.schemas,
                self.execution_count)
      _ = create_var_in_main(self.output_name, output)
      self.evaluated.add(pipeline)
    if self.next:
      return self.next.to_pipeline(pipeline)
    return pipeline
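

# NOTE: the following is an illustrative sketch, not part of the original
# module. It shows how SqlNodes are linked and lowered into a pipeline. The
# names 'init_pcoll', 'out_1' and 'out_2' and the queries are hypothetical:
# in practice each beam_sql magic constructs one SqlNode over watched
# variables in an interactive (IPython) session.
def _example_sql_node_chain():
  # Two nodes, chained chronologically: node_2 consumes node_1's output.
  node_1 = SqlNode(
      output_name='out_1',
      source={'init_pcoll'},
      query='SELECT a FROM init_pcoll')
  node_2 = SqlNode(
      output_name='out_2', source={'out_1'}, query='SELECT a FROM out_1')
  node_1.next = node_2
  # Walking the chain from its head evaluates every node not yet evaluated
  # for the pipeline and returns the single pipeline shared by the chain.
  # This assumes 'init_pcoll' is a watched PCollection so pcoll_by_name() can
  # resolve it; outside an interactive session the lookup would return None.
  return node_1.to_pipeline(pipeline=None)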


class SchemaLoadedSqlTransform(beam.PTransform):
  """PTransform that loads schemas before executing SQL.

  When submitting a pipeline to a remote runner for execution, schemas defined
  in the main module are not available without save_main_session. However,
  save_main_session might fail when there is anything unpicklable. This
  transform makes sure only the schemas needed are pickled locally and
  restored later on workers.
  """
  def __init__(self, output_name, query, schemas, execution_count):
    self.output_name = output_name
    self.query = query
    self.schemas = schemas
    self.execution_count = execution_count
    # TODO(BEAM-8123): clean up this attribute or the whole wrapper PTransform.
    # Dill does not preserve everything. On the other hand, save_main_session
    # is not stable. Until cloudpickle replaces dill in Beam, we work around
    # it by explicitly pickling annotations and loading schemas in remote main
    # sessions.
    self.schema_annotations = [s.__annotations__ for s in self.schemas]

  class _SqlTransformDoFn(beam.DoFn):
    """A DoFn that yields all its input unchanged; its setup configures the
    main session."""
    def __init__(self, schemas, annotations):
      self.pickled_schemas = [pickler.dumps(s) for s in schemas]
      self.pickled_annotations = [pickler.dumps(a) for a in annotations]

    def setup(self):
      main_session = importlib.import_module('__main__')
      for pickled_schema, pickled_annotation in zip(
          self.pickled_schemas, self.pickled_annotations):
        schema = pickler.loads(pickled_schema)
        schema.__annotations__ = pickler.loads(pickled_annotation)
        if not hasattr(main_session, schema.__name__) or not hasattr(
            getattr(main_session, schema.__name__), '__annotations__'):
          # Restore the schema in the main session on the [remote] worker.
          setattr(main_session, schema.__name__, schema)
        register_coder_for_schema(schema)

    def process(self, e):
      yield e

  def expand(self, source):
    """Applies the SQL transform.

    If a PCollection uses a schema defined in the main session, use the
    additional DoFn to restore it on the worker.
    """
    if isinstance(source, dict):
      schema_loaded = {
          tag: pcoll | 'load_schemas_{}_tag_{}_{}'.format(
              self.output_name, tag, self.execution_count) >> beam.ParDo(
                  self._SqlTransformDoFn(
                      self.schemas, self.schema_annotations))
          if pcoll.element_type in self.schemas else pcoll
          for tag, pcoll in source.items()
      }
    elif isinstance(source, beam.pvalue.PCollection):
      schema_loaded = source | 'load_schemas_{}_{}'.format(
          self.output_name, self.execution_count) >> beam.ParDo(
              self._SqlTransformDoFn(self.schemas, self.schema_annotations)
          ) if source.element_type in self.schemas else source
    else:
      raise ValueError(
          '{} should be either a single PCollection or a dict of named '
          'PCollections.'.format(source))
    return schema_loaded | 'beam_sql_{}_{}'.format(
        self.output_name, self.execution_count) >> SqlTransform(self.query)
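

# NOTE: the following is an illustrative sketch, not part of the original
# module. It spells out the round trip that _SqlTransformDoFn performs:
# pickle a schema and its __annotations__ on the client, then rebuild both
# under __main__ as setup() would on a [remote] worker. The 'Quote' schema is
# hypothetical.
def _example_schema_round_trip():
  from typing import NamedTuple

  class Quote(NamedTuple):
    symbol: str
    price: float

  # Client side: pickle the class and its annotations separately, since dill
  # does not reliably preserve __annotations__.
  pickled_schema = pickler.dumps(Quote)
  pickled_annotations = pickler.dumps(Quote.__annotations__)

  # Worker side: restore the schema into the main session and register a
  # coder for it, mirroring _SqlTransformDoFn.setup().
  schema = pickler.loads(pickled_schema)
  schema.__annotations__ = pickler.loads(pickled_annotations)
  main_session = importlib.import_module('__main__')
  setattr(main_session, schema.__name__, schema)
  register_coder_for_schema(schema)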


@dataclass
class SqlChain:
  """A chain of SqlNodes.

  Attributes:
    nodes: all nodes by their output_names.
    root: the first SqlNode applied chronologically.
    current: the last node applied.
    user_pipeline: the user-defined pipeline this chain originates from. If
      None, the whole chain just computes from raw values in queries.
      Otherwise, at least some of the nodes in the chain have queried against
      PCollections.
  """
  nodes: Dict[str, SqlNode] = None
  root: Optional[SqlNode] = None
  current: Optional[SqlNode] = None
  user_pipeline: Optional[beam.Pipeline] = None

  def __post_init__(self):
    if not self.nodes:
      self.nodes = {}

  @progress_indicated
  def to_pipeline(self) -> beam.Pipeline:
    """Converts the chain into a beam pipeline."""
    pipeline_to_execute = self.root.to_pipeline(self.user_pipeline)
    # The pipeline definitely contains an external transform: SqlTransform.
    pipeline_to_execute.contains_external_transforms = True
    return pipeline_to_execute

  def append(self, node: SqlNode) -> 'SqlChain':
    """Appends a node to the chain."""
    if self.current:
      self.current.next = node
    else:
      self.root = node
    self.current = node
    self.nodes[node.output_name] = node
    return self

  def get(self, output_name: str) -> Optional[SqlNode]:
    """Gets a node from the chain based on the given output_name."""
    return self.nodes.get(output_name, None)
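

# NOTE: the following is an illustrative sketch, not part of the original
# module. It shows end-to-end SqlChain usage: append a node, look it up by
# output name, and lower the whole chain into one executable pipeline. The
# output name and query are hypothetical; actually running the result also
# needs a Java expansion service, since SqlTransform is a cross-language
# transform.
def _example_sql_chain():
  chain = SqlChain()
  chain.append(
      SqlNode(
          output_name='out_1',
          source=beam.Pipeline(),
          query='SELECT 1 AS a'))
  assert chain.get('out_1') is chain.root
  # With a beam.Pipeline source, the node computes from raw values in the
  # query, so no watched PCollections are required.
  return chain.to_pipeline()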