github.com/apache/beam/sdks/v2@v2.48.2/go/test/integration/xlang/xlang_test.go (about)

     1  // Licensed to the Apache Software Foundation (ASF) under one or more
     2  // contributor license agreements.  See the NOTICE file distributed with
     3  // this work for additional information regarding copyright ownership.
     4  // The ASF licenses this file to You under the Apache License, Version 2.0
     5  // (the "License"); you may not use this file except in compliance with
     6  // the License.  You may obtain a copy of the License at
     7  //
     8  //    http://www.apache.org/licenses/LICENSE-2.0
     9  //
    10  // Unless required by applicable law or agreed to in writing, software
    11  // distributed under the License is distributed on an "AS IS" BASIS,
    12  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  // See the License for the specific language governing permissions and
    14  // limitations under the License.
    15  
    16  package xlang
    17  
    18  import (
    19  	"flag"
    20  	"fmt"
    21  	"log"
    22  	"reflect"
    23  	"sort"
    24  	"testing"
    25  
    26  	"github.com/apache/beam/sdks/v2/go/examples/xlang"
    27  	"github.com/apache/beam/sdks/v2/go/pkg/beam"
    28  	_ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/dataflow"
    29  	_ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/flink"
    30  	_ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/samza"
    31  	_ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/spark"
    32  	"github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert"
    33  	"github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest"
    34  	"github.com/apache/beam/sdks/v2/go/test/integration"
    35  )
    36  
// expansionAddr holds the address of the expansion service labelled "test".
// TestMain assigns it when the address lookup succeeds; it stays empty
// otherwise, in which case checkFlags skips the calling test.
var expansionAddr string // Populated with the expansion address labelled "test".
    38  
    39  func init() {
    40  	beam.RegisterType(reflect.TypeOf((*IntString)(nil)).Elem())
    41  	beam.RegisterType(reflect.TypeOf((*StringInt)(nil)).Elem())
    42  	beam.RegisterFunction(formatIntStringsFn)
    43  	beam.RegisterFunction(formatStringIntFn)
    44  	beam.RegisterFunction(formatStringIntsFn)
    45  	beam.RegisterFunction(formatIntFn)
    46  	beam.RegisterFunction(getIntString)
    47  	beam.RegisterFunction(getStringInt)
    48  	beam.RegisterFunction(sumCounts)
    49  	beam.RegisterFunction(collectValues)
    50  }
    51  
    52  func checkFlags(t *testing.T) {
    53  	if expansionAddr == "" {
    54  		t.Skip("No Test expansion address provided.")
    55  	}
    56  }
    57  
    58  // formatIntStringsFn is a DoFn that formats an int64 and a list of strings.
    59  func formatIntStringsFn(i int64, s []string) string {
    60  	sort.Strings(s)
    61  	return fmt.Sprintf("%v:%v", i, s)
    62  }
    63  
    64  // formatStringIntFn is a DoFn that formats a string and an int64.
    65  func formatStringIntFn(s string, i int64) string {
    66  	return fmt.Sprintf("%s:%v", s, i)
    67  }
    68  
    69  // formatStringIntsFn is a DoFn that formats a string and a list of ints.
    70  func formatStringIntsFn(s string, i []int) string {
    71  	sort.Ints(i)
    72  	return fmt.Sprintf("%v:%v", s, i)
    73  }
    74  
    75  // formatIntFn is a DoFn that formats an int64 as a string.
    76  func formatIntFn(i int64) string {
    77  	return fmt.Sprintf("%v", i)
    78  }
    79  
// IntString is used to represent KV PCollection values of int64, string.
// getIntString emits X as the key and Y as the value.
type IntString struct {
	X int64  // Key component.
	Y string // Value component.
}
    85  
    86  func getIntString(kv IntString, emit func(int64, string)) {
    87  	emit(kv.X, kv.Y)
    88  }
    89  
// StringInt is used to represent KV PCollection values of string, int64.
// getStringInt emits X as the key and Y as the value.
type StringInt struct {
	X string // Key component.
	Y int64  // Value component.
}
    95  
    96  func getStringInt(kv StringInt, emit func(string, int64)) {
    97  	emit(kv.X, kv.Y)
    98  }
    99  
   100  func sumCounts(key int64, iter1 func(*string) bool) (int64, []string) {
   101  	var val string
   102  	var values []string
   103  
   104  	for iter1(&val) {
   105  		values = append(values, val)
   106  	}
   107  	return key, values
   108  }
   109  
   110  func collectValues(key string, iter func(*int64) bool) (string, []int) {
   111  	var count int64
   112  	var values []int
   113  	for iter(&count) {
   114  		values = append(values, int(count))
   115  	}
   116  	return key, values
   117  }
   118  
   119  func TestXLang_Prefix(t *testing.T) {
   120  	integration.CheckFilters(t)
   121  	checkFlags(t)
   122  
   123  	p := beam.NewPipeline()
   124  	s := p.Root()
   125  
   126  	// Using the cross-language transform
   127  	strings := beam.Create(s, "a", "b", "c")
   128  	prefixed := xlang.Prefix(s, "prefix_", expansionAddr, strings)
   129  	passert.Equals(s, prefixed, "prefix_a", "prefix_b", "prefix_c")
   130  
   131  	ptest.RunAndValidate(t, p)
   132  }
   133  
   134  func TestXLang_CoGroupBy(t *testing.T) {
   135  	integration.CheckFilters(t)
   136  	checkFlags(t)
   137  
   138  	p := beam.NewPipeline()
   139  	s := p.Root()
   140  
   141  	// Using the cross-language transform
   142  	col1 := beam.ParDo(s, getIntString, beam.Create(s, IntString{X: 0, Y: "1"}, IntString{X: 0, Y: "2"}, IntString{X: 1, Y: "3"}))
   143  	col2 := beam.ParDo(s, getIntString, beam.Create(s, IntString{X: 0, Y: "4"}, IntString{X: 1, Y: "5"}, IntString{X: 1, Y: "6"}))
   144  	c := xlang.CoGroupByKey(s, expansionAddr, col1, col2)
   145  	sums := beam.ParDo(s, sumCounts, c)
   146  	formatted := beam.ParDo(s, formatIntStringsFn, sums)
   147  	passert.Equals(s, formatted, "0:[1 2 4]", "1:[3 5 6]")
   148  
   149  	ptest.RunAndValidate(t, p)
   150  }
   151  
   152  func TestXLang_Combine(t *testing.T) {
   153  	integration.CheckFilters(t)
   154  	checkFlags(t)
   155  
   156  	p := beam.NewPipeline()
   157  	s := p.Root()
   158  
   159  	// Using the cross-language transform
   160  	kvs := beam.Create(s, StringInt{X: "a", Y: 1}, StringInt{X: "a", Y: 2}, StringInt{X: "b", Y: 3})
   161  	ins := beam.ParDo(s, getStringInt, kvs)
   162  	c := xlang.CombinePerKey(s, expansionAddr, ins)
   163  
   164  	formatted := beam.ParDo(s, formatStringIntFn, c)
   165  	passert.Equals(s, formatted, "a:3", "b:3")
   166  
   167  	ptest.RunAndValidate(t, p)
   168  }
   169  
   170  func TestXLang_CombineGlobally(t *testing.T) {
   171  	integration.CheckFilters(t)
   172  	checkFlags(t)
   173  
   174  	p := beam.NewPipeline()
   175  	s := p.Root()
   176  
   177  	in := beam.CreateList(s, []int64{1, 2, 3})
   178  
   179  	// Using the cross-language transform
   180  	c := xlang.CombineGlobally(s, expansionAddr, in)
   181  
   182  	formatted := beam.ParDo(s, formatIntFn, c)
   183  	passert.Equals(s, formatted, "6")
   184  
   185  	ptest.RunAndValidate(t, p)
   186  }
   187  
   188  func TestXLang_Flatten(t *testing.T) {
   189  	integration.CheckFilters(t)
   190  	checkFlags(t)
   191  
   192  	p := beam.NewPipeline()
   193  	s := p.Root()
   194  
   195  	col1 := beam.CreateList(s, []int64{1, 2, 3})
   196  	col2 := beam.CreateList(s, []int64{4, 5, 6})
   197  
   198  	// Using the cross-language transform
   199  	c := xlang.Flatten(s, expansionAddr, col1, col2)
   200  
   201  	formatted := beam.ParDo(s, formatIntFn, c)
   202  	passert.Equals(s, formatted, "1", "2", "3", "4", "5", "6")
   203  
   204  	ptest.RunAndValidate(t, p)
   205  }
   206  
   207  func TestXLang_GroupBy(t *testing.T) {
   208  	integration.CheckFilters(t)
   209  	checkFlags(t)
   210  
   211  	p := beam.NewPipeline()
   212  	s := p.Root()
   213  
   214  	// Using the cross-language transform
   215  	kvs := beam.Create(s, StringInt{X: "0", Y: 1}, StringInt{X: "0", Y: 2}, StringInt{X: "1", Y: 3})
   216  	in := beam.ParDo(s, getStringInt, kvs)
   217  	out := xlang.GroupByKey(s, expansionAddr, in)
   218  
   219  	vals := beam.ParDo(s, collectValues, out)
   220  	formatted := beam.ParDo(s, formatStringIntsFn, vals)
   221  	passert.Equals(s, formatted, "0:[1 2]", "1:[3]")
   222  
   223  	ptest.RunAndValidate(t, p)
   224  }
   225  
   226  func TestXLang_Multi(t *testing.T) {
   227  	integration.CheckFilters(t)
   228  	checkFlags(t)
   229  
   230  	p := beam.NewPipeline()
   231  	s := p.Root()
   232  
   233  	main1 := beam.CreateList(s, []string{"a", "bb"})
   234  	main2 := beam.CreateList(s, []string{"x", "yy", "zzz"})
   235  	side := beam.CreateList(s, []string{"s"})
   236  
   237  	// Using the cross-language transform
   238  	mainOut, sideOut := xlang.Multi(s, expansionAddr, main1, main2, side)
   239  
   240  	passert.Equals(s, mainOut, "as", "bbs", "xs", "yys", "zzzs")
   241  	passert.Equals(s, sideOut, "ss")
   242  
   243  	ptest.RunAndValidate(t, p)
   244  }
   245  
   246  func TestXLang_Partition(t *testing.T) {
   247  	integration.CheckFilters(t)
   248  	checkFlags(t)
   249  
   250  	p := beam.NewPipeline()
   251  	s := p.Root()
   252  
   253  	col := beam.CreateList(s, []int64{1, 2, 3, 4, 5, 6})
   254  
   255  	// Using the cross-language transform
   256  	out0, out1 := xlang.Partition(s, expansionAddr, col)
   257  	formatted0 := beam.ParDo(s, formatIntFn, out0)
   258  	formatted1 := beam.ParDo(s, formatIntFn, out1)
   259  
   260  	passert.Equals(s, formatted0, "2", "4", "6")
   261  	passert.Equals(s, formatted1, "1", "3", "5")
   262  
   263  	ptest.RunAndValidate(t, p)
   264  }
   265  
   266  func TestMain(m *testing.M) {
   267  	flag.Parse()
   268  	beam.Init()
   269  
   270  	services := integration.NewExpansionServices()
   271  	defer func() { services.Shutdown() }()
   272  	addr, err := services.GetAddr("test")
   273  	if err != nil {
   274  		log.Printf("skipping missing expansion service: %v", err)
   275  	} else {
   276  		expansionAddr = addr
   277  	}
   278  
   279  	ptest.MainRet(m)
   280  }