#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for coders that must be consistent across all Beam SDKs.
"""
# pytype: skip-file

import json
import logging
import math
import os.path
import sys
import unittest
from copy import deepcopy
from typing import Dict
from typing import Tuple

import numpy as np
import yaml
from numpy.testing import assert_array_equal

from apache_beam.coders import coder_impl
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.portability.api import schema_pb2
from apache_beam.runners import pipeline_context
from apache_beam.transforms import userstate
from apache_beam.transforms import window
from apache_beam.transforms.window import IntervalWindow
from apache_beam.typehints import schemas
from apache_beam.utils import windowed_value
from apache_beam.utils.sharded_key import ShardedKey
from apache_beam.utils.timestamp import Timestamp
from apache_beam.utils.windowed_value import PaneInfo
from apache_beam.utils.windowed_value import PaneInfoTiming

# Location of the cross-SDK coder spec, resolved relative to this file.
STANDARD_CODERS_YAML = os.path.normpath(
    os.path.join(
        os.path.dirname(__file__), '../portability/api/standard_coders.yaml'))


def _load_test_cases(test_yaml):
  """Load test data from yaml file and return an iterable of test cases.

  See ``standard_coders.yaml`` for more details.

  Each YAML document in the file is one test case. The document's position
  in the file is stored back into the spec under the ``'index'`` key so that
  ``--fix`` can later locate the document to rewrite.

  Args:
    test_yaml: path of the multi-document YAML spec file.

  Yields:
    ``[name, spec]`` lists, where ``name`` is the spec's ``'name'`` entry or,
    when absent, the second-to-last ``:``-separated component of the coder
    URN (e.g. ``bytes`` for ``beam:coder:bytes:v1``).

  Raises:
    ValueError: if ``test_yaml`` does not exist (raised on first iteration,
      since this is a generator).
  """
  if not os.path.exists(test_yaml):
    raise ValueError('Could not find the test spec: %s' % test_yaml)
  with open(test_yaml, 'rb') as coder_spec:
    for ix, spec in enumerate(
        yaml.load_all(coder_spec, Loader=yaml.SafeLoader)):
      spec['index'] = ix
      name = spec.get('name', spec['coder']['urn'].split(':')[-2])
      yield [name, spec]


def parse_float(s):
  """Convert ``s`` to a float, forcing any NaN result to a positive sign.

  Args:
    s: anything accepted by ``float()``.

  Returns:
    The parsed float; NaN values are passed through ``abs()``.
  """
  x = float(s)
  if math.isnan(x):
    # In Windows, float('NaN') has opposite sign from other platforms.
    # For the purpose of this test, we just need consistency.
    x = abs(x)
  return x


def value_parser_from_schema(schema):
  """Build a parser turning a YAML mapping into a row of the given schema.

  Args:
    schema: a schema proto; its ``fields`` (each with ``name`` and ``type``)
      drive which per-attribute parser is used.

  Returns:
    A function taking a dict keyed by field name and returning an instance
    of the named tuple type produced by ``schemas.named_tuple_from_schema``.
    That function raises ``KeyError`` when a schema field is missing from
    the dict and ``ValueError`` when the dict has keys not in the schema.
  """
  def attribute_parser_from_type(type_):
    """Return a parser for ``type_`` that passes ``None`` through if nullable."""
    parser = nonnull_attribute_parser_from_type(type_)
    if type_.nullable:
      return lambda x: None if x is None else parser(x)
    return parser

  def nonnull_attribute_parser_from_type(type_):
    """Return a parser for non-null values of ``type_``.

    Returns ``None`` for a ``type_info`` kind not handled below.
    """
    # TODO: This should be exhaustive
    type_info = type_.WhichOneof("type_info")
    if type_info == "atomic_type":
      if type_.atomic_type == schema_pb2.BYTES:
        # Bytes values are written as text in the YAML examples.
        return lambda x: x.encode("utf-8")
      return schemas.ATOMIC_TYPE_TO_PRIMITIVE[type_.atomic_type]
    if type_info == "array_type":
      element_parser = attribute_parser_from_type(type_.array_type.element_type)
      return lambda x: [element_parser(element) for element in x]
    if type_info == "map_type":
      key_parser = attribute_parser_from_type(type_.map_type.key_type)
      value_parser = attribute_parser_from_type(type_.map_type.value_type)
      return lambda x: {key_parser(k): value_parser(v) for k, v in x.items()}
    if type_info == "row_type":
      return value_parser_from_schema(type_.row_type.schema)
    if type_info == "logical_type":
      # In YAML logical types are represented with their representation types.
      to_language_type = schemas.LogicalType.from_runner_api(
          type_.logical_type).to_language_type
      parse_representation = attribute_parser_from_type(
          type_.logical_type.representation)
      return lambda x: to_language_type(parse_representation(x))
    # Unhandled kinds keep the original behavior: no parser is produced, so
    # the first non-null value of such a field fails when it is "called".
    return None

  parsers = [(field.name, attribute_parser_from_type(field.type))
             for field in schema.fields]

  constructor = schemas.named_tuple_from_schema(schema)

  def value_parser(x):
    """Parse one YAML mapping into a row, consuming every key exactly once."""
    result = []
    # Work on a copy: keys are popped to detect leftovers, and the caller's
    # example data is parsed more than once.
    x = deepcopy(x)
    for name, parser in parsers:
      value = x.pop(name)
      result.append(None if value is None else parser(value))

    if x:
      raise ValueError(
          "Test data contains attributes that don't exist in the schema: {}".
          format(', '.join(x.keys())))

    return constructor(*result)

  return value_parser


def _pane_info_from_json(pane):
  """Build a ``PaneInfo`` from the ``pane`` mapping of a YAML example."""
  return PaneInfo(
      pane['is_first'],
      pane['is_last'],
      PaneInfoTiming.from_string(pane['timing']),
      pane['index'],
      pane['on_time_index'])


def _timer_from_json(x, value_parser, window_parser):
  """Build a ``userstate.Timer`` from a YAML example.

  When ``clearBit`` is truthy the timestamps and pane info are not read from
  the example and are set to ``None``.
  """
  cleared = x['clearBit']
  return userstate.Timer(
      user_key=value_parser(x['userKey']),
      dynamic_timer_tag=x['dynamicTimerTag'],
      clear_bit=cleared,
      windows=tuple(window_parser(w) for w in x['windows']),
      fire_timestamp=(
          None if cleared else Timestamp(micros=x['fireTimestamp'] * 1000)),
      hold_timestamp=(
          None if cleared else Timestamp(micros=x['holdTimestamp'] * 1000)),
      paneinfo=None if cleared else _pane_info_from_json(x['pane']))


class StandardCodersTest(unittest.TestCase):
  """Checks each coder against the examples in ``standard_coders.yaml``.

  Run this module with ``--fix`` to rewrite mismatching expected encodings
  in the YAML file instead of failing.
  """

  # Maps a coder URN to a function converting a YAML example value into the
  # Python value the coder is expected to handle. Parsers for composite
  # coders receive the parsers of their components as extra positional
  # arguments (see json_value_parser). Time values from the examples are
  # multiplied by 1000 before use as micros (presumably the spec stores
  # milliseconds -- confirm against standard_coders.yaml).
  _urn_to_json_value_parser = {
      'beam:coder:bytes:v1': lambda x: x.encode('utf-8'),
      'beam:coder:bool:v1': lambda x: x,
      'beam:coder:string_utf8:v1': lambda x: x,
      'beam:coder:varint:v1': lambda x: x,
      'beam:coder:kv:v1': lambda x, key_parser, value_parser: (
          key_parser(x['key']), value_parser(x['value'])),
      'beam:coder:interval_window:v1': lambda x: IntervalWindow(
          start=Timestamp(micros=(x['end'] - x['span']) * 1000),
          end=Timestamp(micros=x['end'] * 1000)),
      'beam:coder:iterable:v1': lambda x, parser: [parser(e) for e in x],
      'beam:coder:state_backed_iterable:v1': lambda x, parser: [
          parser(e) for e in x
      ],
      'beam:coder:global_window:v1': lambda x: window.GlobalWindow(),
      'beam:coder:windowed_value:v1': lambda x, value_parser, window_parser: (
          windowed_value.create(
              value_parser(x['value']),
              x['timestamp'] * 1000,
              tuple(window_parser(w) for w in x['windows']))),
      'beam:coder:param_windowed_value:v1': lambda x, value_parser,
      window_parser: windowed_value.create(
          value_parser(x['value']),
          x['timestamp'] * 1000,
          tuple(window_parser(w) for w in x['windows']),
          _pane_info_from_json(x['pane'])),
      'beam:coder:timer:v1': _timer_from_json,
      'beam:coder:double:v1': parse_float,
      'beam:coder:sharded_key:v1': lambda x, value_parser: ShardedKey(
          key=value_parser(x['key']), shard_id=x['shardId'].encode('utf-8')),
      'beam:coder:custom_window:v1': lambda x, window_parser: window_parser(
          x['window']),
      # NOTE(review): ignores value_parser and maps every falsy example
      # (including '') to None; kept as-is, confirm against the spec file.
      'beam:coder:nullable:v1': lambda x, value_parser: (
          x.encode('utf-8') if x else None),
  }

  def test_standard_coders(self):
    """Run every test case found in ``STANDARD_CODERS_YAML``."""
    for name, spec in _load_test_cases(STANDARD_CODERS_YAML):
      logging.info('Executing %s test.', name)
      self._run_standard_coder(name, spec)

  def _run_standard_coder(self, name, spec):
    """Verify encoding and decoding of all examples of one spec.

    Args:
      name: test case name (not used by the checks themselves).
      spec: one parsed YAML document with ``'coder'`` and ``'examples'``
        entries, an ``'index'`` entry, and optionally ``'nested'``.
    """
    def assert_equal(actual, expected):
      """Handle nan values which self.assertEqual fails on."""
      if (isinstance(actual, float) and isinstance(expected, float) and
          math.isnan(actual) and math.isnan(expected)):
        return
      self.assertEqual(actual, expected)

    coder = self.parse_coder(spec['coder'])
    parse_value = self.json_value_parser(spec['coder'])
    # Without an explicit 'nested' entry both contexts are exercised.
    nested_list = [spec['nested']] if 'nested' in spec else [True, False]
    for nested in nested_list:
      for expected_encoded, json_value in spec['examples'].items():
        value = parse_value(json_value)
        expected_encoded = expected_encoded.encode('latin1')
        if not spec['coder'].get('non_deterministic', False):
          actual_encoded = encode_nested(coder, value, nested)
          if self.fix and actual_encoded != expected_encoded:
            # Record the mismatch; tearDownClass rewrites the spec file.
            self.to_fix[spec['index'], expected_encoded] = actual_encoded
          else:
            self.assertEqual(expected_encoded, actual_encoded)
            decoded = decode_nested(coder, expected_encoded, nested)
            assert_equal(decoded, value)
        else:
          # Only verify decoding for a non-deterministic coder
          assert_equal(decode_nested(coder, expected_encoded, nested), value)

    if spec['coder']['urn'] == 'beam:coder:row:v1':
      # Test batch encoding/decoding as well.
      values = [
          parse_value(json_value) for json_value in spec['examples'].values()
      ]
      columnar = {
          field.name: np.array([getattr(value, field.name) for value in values])
          for field in coder.schema.fields
      }
      dest = {
          field: np.empty_like(column)
          for field, column in columnar.items()
      }
      # Overwrite the uninitialized destination buffers with a known filler
      # before decoding into them.
      for column in dest.values():
        column[:] = 0 if 'int' in column.dtype.name else None
      expected_encoded = ''.join(spec['examples'].keys()).encode('latin1')
      actual_encoded = encode_batch(coder, columnar)
      assert_equal(expected_encoded, actual_encoded)
      decoded_count = decode_batch(coder, expected_encoded, dest)
      assert_equal(len(spec['examples']), decoded_count)
      for field, decoded_column in dest.items():
        assert_array_equal(columnar[field], decoded_column)

  def parse_coder(self, spec):
    """Instantiate the coder described by a ``'coder'`` spec mapping.

    Component coders are built recursively and registered in a fresh
    ``PipelineContext``. If the spec has a truthy ``'state'`` entry, a reader
    serving that state is installed on the context as
    ``iterable_state_read``.

    Args:
      spec: mapping with ``'urn'`` and optional ``'payload'``,
        ``'components'`` and ``'state'`` entries.

    Returns:
      The coder object looked up from the context by its generated id.
    """
    context = pipeline_context.PipelineContext()
    coder_id = str(hash(str(spec)))
    component_ids = [
        context.coders.get_id(self.parse_coder(c))
        for c in spec.get('components', ())
    ]
    if spec.get('state'):

      def iterable_state_read(state_token, elem_coder):
        """Yield the elements stored under ``state_token`` in the spec."""
        state = spec.get('state').get(state_token.decode('latin1'))
        if state is None:
          state = ''
        input_stream = coder_impl.create_InputStream(state.encode('latin1'))
        while input_stream.size() > 0:
          yield elem_coder.decode_from_stream(input_stream, True)

      context.iterable_state_read = iterable_state_read

    context.coders.put_proto(
        coder_id,
        beam_runner_api_pb2.Coder(
            spec=beam_runner_api_pb2.FunctionSpec(
                urn=spec['urn'],
                payload=spec.get('payload', '').encode('latin1')),
            component_coder_ids=component_ids))
    return context.coders.get_by_id(coder_id)

  def json_value_parser(self, coder_spec):
    """Return a function parsing YAML example values for ``coder_spec``.

    Row coders get a schema-driven parser built from the payload. All other
    coders use ``_urn_to_json_value_parser``, with the parsers of their
    components (built recursively) passed as extra arguments.

    Args:
      coder_spec: mapping with ``'urn'`` and optional ``'payload'`` and
        ``'components'`` entries.

    Returns:
      A one-argument callable. For non-row coders an unknown URN raises
      ``KeyError`` only when that callable is invoked.
    """
    # TODO: integrate this with the logic for the other parsers
    if coder_spec['urn'] == 'beam:coder:row:v1':
      schema = schema_pb2.Schema.FromString(
          coder_spec['payload'].encode('latin1'))
      return value_parser_from_schema(schema)

    component_parsers = [
        self.json_value_parser(c) for c in coder_spec.get('components', ())
    ]
    return lambda x: self._urn_to_json_value_parser[coder_spec['urn']](
        x, *component_parsers)

  # Used when --fix is passed.

  fix = False
  to_fix = {}  # type: Dict[Tuple[int, bytes], bytes]

  @classmethod
  def tearDownClass(cls):
    """Rewrite mismatching expected encodings in the spec file under --fix.

    Does nothing unless ``fix`` is set and mismatches were recorded in
    ``to_fix``. Each recorded ``(document index, expected bytes)`` key has
    its quoted form replaced by the quoted actual encoding within that YAML
    document, and the file is written back in place.
    """
    if cls.fix and cls.to_fix:
      print("FIXING", len(cls.to_fix), "TESTS")
      doc_sep = '\n---\n'
      with open(STANDARD_CODERS_YAML) as spec_file:
        docs = spec_file.read().split(doc_sep)

      def quote(s):
        """Render bytes as a JSON string, writing NUL escapes as ``\\0``."""
        return json.dumps(s.decode('latin1')).replace(r'\u0000', r'\0')

      for (doc_ix, expected_encoded), actual_encoded in cls.to_fix.items():
        print(quote(expected_encoded), "->", quote(actual_encoded))
        docs[doc_ix] = docs[doc_ix].replace(
            quote(expected_encoded) + ':', quote(actual_encoded) + ':')
      with open(STANDARD_CODERS_YAML, 'w') as spec_file:
        spec_file.write(doc_sep.join(docs))


def encode_nested(coder, value, nested=True):
  """Encode ``value`` with ``coder``'s impl and return the stream contents.

  Args:
    coder: object whose ``get_impl()`` provides ``encode_to_stream``.
    value: the value to encode.
    nested: passed through to ``encode_to_stream``.
  """
  out = coder_impl.create_OutputStream()
  coder.get_impl().encode_to_stream(value, out, nested)
  return out.get()


def decode_nested(coder, encoded, nested=True):
  """Decode ``encoded`` with ``coder``'s impl and return the result.

  Args:
    coder: object whose ``get_impl()`` provides ``decode_from_stream``.
    encoded: the encoded input wrapped into an input stream.
    nested: passed through to ``decode_from_stream``.
  """
  return coder.get_impl().decode_from_stream(
      coder_impl.create_InputStream(encoded), nested)


def encode_batch(row_coder, values):
  """Batch-encode ``values`` with ``row_coder`` and return the stream contents.

  Args:
    row_coder: object whose ``get_impl()`` provides
      ``encode_batch_to_stream``.
    values: the batch to encode; the caller in this file passes a dict of
      field name to numpy array.
  """
  out = coder_impl.create_OutputStream()
  row_coder.get_impl().encode_batch_to_stream(values, out)
  return out.get()


def decode_batch(row_coder, encoded, dest):
  """Batch-decode ``encoded`` into ``dest`` using ``row_coder``.

  Args:
    row_coder: object whose ``get_impl()`` provides
      ``decode_batch_from_stream``.
    encoded: the encoded input wrapped into an input stream.
    dest: destination passed to ``decode_batch_from_stream``; the caller in
      this file passes a dict of field name to preallocated numpy array.

  Returns:
    Whatever ``decode_batch_from_stream`` returns; the caller in this file
    compares it to the number of examples.
  """
  return row_coder.get_impl().decode_batch_from_stream(
      dest, coder_impl.create_InputStream(encoded))


if __name__ == '__main__':
  if '--fix' in sys.argv:
    StandardCodersTest.fix = True
    # Strip the flag so unittest's own argument parsing does not see it.
    sys.argv.remove('--fix')
  unittest.main()