github.com/apache/beam/sdks/v2@v2.48.2/go/test/integration/xlang/xlang_test.go (about) 1 // Licensed to the Apache Software Foundation (ASF) under one or more 2 // contributor license agreements. See the NOTICE file distributed with 3 // this work for additional information regarding copyright ownership. 4 // The ASF licenses this file to You under the Apache License, Version 2.0 5 // (the "License"); you may not use this file except in compliance with 6 // the License. You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 16 package xlang 17 18 import ( 19 "flag" 20 "fmt" 21 "log" 22 "reflect" 23 "sort" 24 "testing" 25 26 "github.com/apache/beam/sdks/v2/go/examples/xlang" 27 "github.com/apache/beam/sdks/v2/go/pkg/beam" 28 _ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/dataflow" 29 _ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/flink" 30 _ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/samza" 31 _ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/spark" 32 "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert" 33 "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest" 34 "github.com/apache/beam/sdks/v2/go/test/integration" 35 ) 36 37 var expansionAddr string // Populate with expansion address labelled "test". 38 39 func init() { 40 beam.RegisterType(reflect.TypeOf((*IntString)(nil)).Elem()) 41 beam.RegisterType(reflect.TypeOf((*StringInt)(nil)).Elem()) 42 beam.RegisterFunction(formatIntStringsFn) 43 beam.RegisterFunction(formatStringIntFn) 44 beam.RegisterFunction(formatStringIntsFn) 45 beam.RegisterFunction(formatIntFn) 46 beam.RegisterFunction(getIntString) 47 beam.RegisterFunction(getStringInt) 48 beam.RegisterFunction(sumCounts) 49 beam.RegisterFunction(collectValues) 50 } 51 52 func checkFlags(t *testing.T) { 53 if expansionAddr == "" { 54 t.Skip("No Test expansion address provided.") 55 } 56 } 57 58 // formatIntStringsFn is a DoFn that formats an int64 and a list of strings. 59 func formatIntStringsFn(i int64, s []string) string { 60 sort.Strings(s) 61 return fmt.Sprintf("%v:%v", i, s) 62 } 63 64 // formatStringIntFn is a DoFn that formats a string and an int64. 65 func formatStringIntFn(s string, i int64) string { 66 return fmt.Sprintf("%s:%v", s, i) 67 } 68 69 // formatStringIntsFn is a DoFn that formats a string and a list of ints. 70 func formatStringIntsFn(s string, i []int) string { 71 sort.Ints(i) 72 return fmt.Sprintf("%v:%v", s, i) 73 } 74 75 // formatIntFn is a DoFn that formats an int64 as a string. 76 func formatIntFn(i int64) string { 77 return fmt.Sprintf("%v", i) 78 } 79 80 // IntString used to represent KV PCollection values of int64, string. 81 type IntString struct { 82 X int64 83 Y string 84 } 85 86 func getIntString(kv IntString, emit func(int64, string)) { 87 emit(kv.X, kv.Y) 88 } 89 90 // StringInt used to represent KV PCollection values of string, int64. 91 type StringInt struct { 92 X string 93 Y int64 94 } 95 96 func getStringInt(kv StringInt, emit func(string, int64)) { 97 emit(kv.X, kv.Y) 98 } 99 100 func sumCounts(key int64, iter1 func(*string) bool) (int64, []string) { 101 var val string 102 var values []string 103 104 for iter1(&val) { 105 values = append(values, val) 106 } 107 return key, values 108 } 109 110 func collectValues(key string, iter func(*int64) bool) (string, []int) { 111 var count int64 112 var values []int 113 for iter(&count) { 114 values = append(values, int(count)) 115 } 116 return key, values 117 } 118 119 func TestXLang_Prefix(t *testing.T) { 120 integration.CheckFilters(t) 121 checkFlags(t) 122 123 p := beam.NewPipeline() 124 s := p.Root() 125 126 // Using the cross-language transform 127 strings := beam.Create(s, "a", "b", "c") 128 prefixed := xlang.Prefix(s, "prefix_", expansionAddr, strings) 129 passert.Equals(s, prefixed, "prefix_a", "prefix_b", "prefix_c") 130 131 ptest.RunAndValidate(t, p) 132 } 133 134 func TestXLang_CoGroupBy(t *testing.T) { 135 integration.CheckFilters(t) 136 checkFlags(t) 137 138 p := beam.NewPipeline() 139 s := p.Root() 140 141 // Using the cross-language transform 142 col1 := beam.ParDo(s, getIntString, beam.Create(s, IntString{X: 0, Y: "1"}, IntString{X: 0, Y: "2"}, IntString{X: 1, Y: "3"})) 143 col2 := beam.ParDo(s, getIntString, beam.Create(s, IntString{X: 0, Y: "4"}, IntString{X: 1, Y: "5"}, IntString{X: 1, Y: "6"})) 144 c := xlang.CoGroupByKey(s, expansionAddr, col1, col2) 145 sums := beam.ParDo(s, sumCounts, c) 146 formatted := beam.ParDo(s, formatIntStringsFn, sums) 147 passert.Equals(s, formatted, "0:[1 2 4]", "1:[3 5 6]") 148 149 ptest.RunAndValidate(t, p) 150 } 151 152 func TestXLang_Combine(t *testing.T) { 153 integration.CheckFilters(t) 154 checkFlags(t) 155 156 p := beam.NewPipeline() 157 s := p.Root() 158 159 // Using the cross-language transform 160 kvs := beam.Create(s, StringInt{X: "a", Y: 1}, StringInt{X: "a", Y: 2}, StringInt{X: "b", Y: 3}) 161 ins := beam.ParDo(s, getStringInt, kvs) 162 c := xlang.CombinePerKey(s, expansionAddr, ins) 163 164 formatted := beam.ParDo(s, formatStringIntFn, c) 165 passert.Equals(s, formatted, "a:3", "b:3") 166 167 ptest.RunAndValidate(t, p) 168 } 169 170 func TestXLang_CombineGlobally(t *testing.T) { 171 integration.CheckFilters(t) 172 checkFlags(t) 173 174 p := beam.NewPipeline() 175 s := p.Root() 176 177 in := beam.CreateList(s, []int64{1, 2, 3}) 178 179 // Using the cross-language transform 180 c := xlang.CombineGlobally(s, expansionAddr, in) 181 182 formatted := beam.ParDo(s, formatIntFn, c) 183 passert.Equals(s, formatted, "6") 184 185 ptest.RunAndValidate(t, p) 186 } 187 188 func TestXLang_Flatten(t *testing.T) { 189 integration.CheckFilters(t) 190 checkFlags(t) 191 192 p := beam.NewPipeline() 193 s := p.Root() 194 195 col1 := beam.CreateList(s, []int64{1, 2, 3}) 196 col2 := beam.CreateList(s, []int64{4, 5, 6}) 197 198 // Using the cross-language transform 199 c := xlang.Flatten(s, expansionAddr, col1, col2) 200 201 formatted := beam.ParDo(s, formatIntFn, c) 202 passert.Equals(s, formatted, "1", "2", "3", "4", "5", "6") 203 204 ptest.RunAndValidate(t, p) 205 } 206 207 func TestXLang_GroupBy(t *testing.T) { 208 integration.CheckFilters(t) 209 checkFlags(t) 210 211 p := beam.NewPipeline() 212 s := p.Root() 213 214 // Using the cross-language transform 215 kvs := beam.Create(s, StringInt{X: "0", Y: 1}, StringInt{X: "0", Y: 2}, StringInt{X: "1", Y: 3}) 216 in := beam.ParDo(s, getStringInt, kvs) 217 out := xlang.GroupByKey(s, expansionAddr, in) 218 219 vals := beam.ParDo(s, collectValues, out) 220 formatted := beam.ParDo(s, formatStringIntsFn, vals) 221 passert.Equals(s, formatted, "0:[1 2]", "1:[3]") 222 223 ptest.RunAndValidate(t, p) 224 } 225 226 func TestXLang_Multi(t *testing.T) { 227 integration.CheckFilters(t) 228 checkFlags(t) 229 230 p := beam.NewPipeline() 231 s := p.Root() 232 233 main1 := beam.CreateList(s, []string{"a", "bb"}) 234 main2 := beam.CreateList(s, []string{"x", "yy", "zzz"}) 235 side := beam.CreateList(s, []string{"s"}) 236 237 // Using the cross-language transform 238 mainOut, sideOut := xlang.Multi(s, expansionAddr, main1, main2, side) 239 240 passert.Equals(s, mainOut, "as", "bbs", "xs", "yys", "zzzs") 241 passert.Equals(s, sideOut, "ss") 242 243 ptest.RunAndValidate(t, p) 244 } 245 246 func TestXLang_Partition(t *testing.T) { 247 integration.CheckFilters(t) 248 checkFlags(t) 249 250 p := beam.NewPipeline() 251 s := p.Root() 252 253 col := beam.CreateList(s, []int64{1, 2, 3, 4, 5, 6}) 254 255 // Using the cross-language transform 256 out0, out1 := xlang.Partition(s, expansionAddr, col) 257 formatted0 := beam.ParDo(s, formatIntFn, out0) 258 formatted1 := beam.ParDo(s, formatIntFn, out1) 259 260 passert.Equals(s, formatted0, "2", "4", "6") 261 passert.Equals(s, formatted1, "1", "3", "5") 262 263 ptest.RunAndValidate(t, p) 264 } 265 266 func TestMain(m *testing.M) { 267 flag.Parse() 268 beam.Init() 269 270 services := integration.NewExpansionServices() 271 defer func() { services.Shutdown() }() 272 addr, err := services.GetAddr("test") 273 if err != nil { 274 log.Printf("skipping missing expansion service: %v", err) 275 } else { 276 expansionAddr = addr 277 } 278 279 ptest.MainRet(m) 280 }