github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/matrix_power.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """An example that computes the matrix power y = A^m * v.
    19  
    20  A is square matrix and v is a given vector with appropriate dimension.
    21  
    22  In this computation, each element of the matrix is represented by ((i,j), a)
    23  where a is the element in the i-th row and j-th column. Each element of the
    24  vector is computed as a PCollection (i, v) where v is the element of the i-th
    25  row. For multiplication, the vector is converted into a dict side input.
    26  """
    27  
    28  import argparse
    29  import logging
    30  
    31  import apache_beam as beam
    32  from apache_beam.options.pipeline_options import PipelineOptions
    33  from apache_beam.testing.test_pipeline import TestPipeline
    34  
    35  
    36  def extract_matrix(line):
    37    tokens = line.split(':')
    38    row = int(tokens[0])
    39    numbers = tokens[1].strip().split()
    40    for column, number in enumerate(numbers):
    41      yield ((row, column), float(number))
    42  
    43  
    44  def extract_vector(line):
    45    return enumerate(map(float, line.split()))
    46  
    47  
    48  def multiply_elements(element, vector):
    49    ((row, col), value) = element
    50    return (row, value * vector[col])
    51  
    52  
    53  def run(argv=None):
    54    parser = argparse.ArgumentParser()
    55    parser.add_argument(
    56        '--input_matrix', required=True, help='Input file containing the matrix.')
    57    parser.add_argument(
    58        '--input_vector',
    59        required=True,
    60        help='Input file containing initial vector.')
    61    parser.add_argument(
    62        '--output', required=True, help='Output file to write results to.')
    63    parser.add_argument(
    64        '--exponent',
    65        required=True,
    66        type=int,
    67        help='Exponent of input square matrix.')
    68    known_args, pipeline_args = parser.parse_known_args(argv)
    69  
    70    p = TestPipeline(options=PipelineOptions(pipeline_args))
    71  
    72    # Read the matrix from the input file and extract into the ((i,j), a) format.
    73    matrix = (
    74        p | 'read matrix' >> beam.io.ReadFromText(known_args.input_matrix)
    75        | 'extract matrix' >> beam.FlatMap(extract_matrix))
    76  
    77    # Read and extract the vector from its input file.
    78    vector = (
    79        p | 'read vector' >> beam.io.ReadFromText(known_args.input_vector)
    80        | 'extract vector' >> beam.FlatMap(extract_vector))
    81  
    82    for i in range(known_args.exponent):
    83      # Multiply the matrix by the current vector once,
    84      # and keep the resulting vector for the next iteration.
    85      vector = (
    86          matrix
    87          # Convert vector into side-input dictionary, compute the product.
    88          | 'multiply elements %d' % i >> beam.Map(
    89              multiply_elements, beam.pvalue.AsDict(vector))
    90          | 'sum element products %d' % i >> beam.CombinePerKey(sum))
    91  
    92    # Format and output final vector.
    93    _ = (
    94        vector  # pylint: disable=expression-not-assigned
    95        | 'format' >> beam.Map(repr)
    96        | 'write' >> beam.io.WriteToText(known_args.output))
    97  
    98    p.run()
    99  
   100  
   101  if __name__ == '__main__':
   102    logging.getLogger().setLevel(logging.INFO)
   103    run()