github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/testing/extra_assertions.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  # pytype: skip-file
    18  
    19  import numpy as np
    20  
    21  
    22  class ExtraAssertionsMixin(object):
    23    def assertUnhashableCountEqual(self, data1, data2):
    24      """Assert that two containers have the same items, with special treatment
    25      for numpy arrays.
    26      """
    27      try:
    28        self.assertCountEqual(data1, data2)
    29      except (TypeError, ValueError):
    30        data1 = [self._to_hashable(d) for d in data1]
    31        data2 = [self._to_hashable(d) for d in data2]
    32        self.assertCountEqual(data1, data2)
    33  
    34    def _to_hashable(self, element):
    35      try:
    36        hash(element)
    37        return element
    38      except TypeError:
    39        pass
    40  
    41      if isinstance(element, list):
    42        return tuple(self._to_hashable(e) for e in element)
    43  
    44      if isinstance(element, dict):
    45        hashable_elements = []
    46        for key, value in sorted(element.items(), key=lambda t: hash(t[0])):
    47          hashable_elements.append((key, self._to_hashable(value)))
    48        return tuple(hashable_elements)
    49  
    50      if isinstance(element, np.ndarray):
    51        return element.tobytes()
    52  
    53      raise AssertionError("Encountered unhashable element: {}.".format(element))