kythe.io@v0.0.68-0.20240422202219-7225dbc01741/kythe/go/serving/pipeline/beamio/shards_test.go (about) 1 /* 2 * Copyright 2021 The Kythe Authors. All rights reserved. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package beamio 18 19 import ( 20 "testing" 21 22 "github.com/apache/beam/sdks/go/pkg/beam" 23 "github.com/apache/beam/sdks/go/pkg/beam/testing/passert" 24 "github.com/apache/beam/sdks/go/pkg/beam/testing/ptest" 25 "github.com/apache/beam/sdks/go/pkg/beam/transforms/stats" 26 ) 27 28 type shardValue struct { 29 Num int 30 Values []int 31 } 32 33 type collectShards struct{} 34 35 func (c collectShards) CreateAccumulator() []int { return []int{} } 36 func (c collectShards) AddInput(accum []int, input KeyValue) []int { 37 return append(accum, int(input.Key[0])) 38 } 39 func (c collectShards) MergeAccumulators(accum, other []int) []int { 40 return append(accum, other...) 41 } 42 func makeShardValue(shard int, accum []int) shardValue { 43 return shardValue{Num: shard, Values: accum} 44 } 45 46 func TestComputeShards(t *testing.T) { 47 const testElements = 8 48 kvs := make([]KeyValue, 0, testElements) 49 for i := 0; i < testElements; i++ { 50 kvs = append(kvs, KeyValue{ 51 Key: []byte{byte(i)}, 52 Value: []byte{}, 53 }) 54 } 55 expected := []shardValue{ 56 {Num: 0, Values: []int{0, 1}}, 57 {Num: 1, Values: []int{2, 3, 4}}, 58 {Num: 2, Values: []int{5, 6, 7}}, 59 } 60 p, s, col, pExpected := ptest.CreateList2(kvs, expected) 61 shardedElementsKv := ComputeShards(s, col, stats.Opts{K: 10, NumQuantiles: 3}) 62 collectedShards := beam.ParDo(s, makeShardValue, beam.CombinePerKey(s, collectShards{}, shardedElementsKv)) 63 64 passert.Equals(s, pExpected, collectedShards) 65 if err := ptest.Run(p); err != nil { 66 t.Errorf("ComputeShards failed: %v", err) 67 } 68 }