github.com/apache/beam/sdks/v2@v2.48.2/go/test/integration/primitives/state.go (about)

     1  // Licensed to the Apache Software Foundation (ASF) under one or more
     2  // contributor license agreements.  See the NOTICE file distributed with
     3  // this work for additional information regarding copyright ownership.
     4  // The ASF licenses this file to You under the Apache License, Version 2.0
     5  // (the "License"); you may not use this file except in compliance with
     6  // the License.  You may obtain a copy of the License at
     7  //
     8  //    http://www.apache.org/licenses/LICENSE-2.0
     9  //
    10  // Unless required by applicable law or agreed to in writing, software
    11  // distributed under the License is distributed on an "AS IS" BASIS,
    12  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  // See the License for the specific language governing permissions and
    14  // limitations under the License.
    15  
    16  package primitives
    17  
    18  import (
    19  	"fmt"
    20  	"sort"
    21  	"strconv"
    22  	"strings"
    23  	"time"
    24  
    25  	"github.com/apache/beam/sdks/v2/go/pkg/beam"
    26  	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window"
    27  	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/state"
    28  	"github.com/apache/beam/sdks/v2/go/pkg/beam/register"
    29  	"github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert"
    30  )
    31  
    32  func init() {
    33  	register.DoFn3x1[state.Provider, string, int, string](&valueStateFn{})
    34  	register.DoFn3x1[state.Provider, string, int, string](&valueStateClearFn{})
    35  	register.DoFn3x1[state.Provider, string, int, string](&bagStateFn{})
    36  	register.DoFn3x1[state.Provider, string, int, string](&bagStateClearFn{})
    37  	register.DoFn3x1[state.Provider, string, int, string](&combiningStateFn{})
    38  	register.DoFn3x1[state.Provider, string, int, string](&mapStateFn{})
    39  	register.DoFn3x1[state.Provider, string, int, string](&mapStateClearFn{})
    40  	register.DoFn3x1[state.Provider, string, int, string](&setStateFn{})
    41  	register.DoFn3x1[state.Provider, string, int, string](&setStateClearFn{})
    42  	register.Emitter2[string, int]()
    43  	register.Combiner1[int](&combine1{})
    44  	register.Combiner2[string, int](&combine2{})
    45  	register.Combiner2[string, int](&combine3{})
    46  	register.Combiner1[int](&combine4{})
    47  }
    48  
    49  type valueStateFn struct {
    50  	State1 state.Value[int]
    51  	State2 state.Value[string]
    52  }
    53  
    54  func (f *valueStateFn) ProcessElement(s state.Provider, w string, c int) string {
    55  	i, ok, err := f.State1.Read(s)
    56  	if err != nil {
    57  		panic(err)
    58  	}
    59  	if !ok {
    60  		i = 1
    61  	}
    62  	err = f.State1.Write(s, i+1)
    63  	if err != nil {
    64  		panic(err)
    65  	}
    66  
    67  	j, ok, err := f.State2.Read(s)
    68  	if err != nil {
    69  		panic(err)
    70  	}
    71  	if !ok {
    72  		j = "I"
    73  	}
    74  	err = f.State2.Write(s, j+"I")
    75  	if err != nil {
    76  		panic(err)
    77  	}
    78  	return fmt.Sprintf("%s: %v, %s", w, i, j)
    79  }
    80  
    81  // ValueStateParDo tests a DoFn that uses value state.
    82  func ValueStateParDo() *beam.Pipeline {
    83  	p, s := beam.NewPipelineWithRoot()
    84  
    85  	in := beam.Create(s, "apple", "pear", "peach", "apple", "apple", "pear")
    86  	keyed := beam.ParDo(s, func(w string, emit func(string, int)) {
    87  		emit(w, 1)
    88  	}, in)
    89  	counts := beam.ParDo(s, &valueStateFn{}, keyed)
    90  	passert.Equals(s, counts, "apple: 1, I", "pear: 1, I", "peach: 1, I", "apple: 2, II", "apple: 3, III", "pear: 2, II")
    91  
    92  	return p
    93  }
    94  
    95  // ValueStateParDoWindowed tests a DoFn that uses windowed value state.
    96  func ValueStateParDoWindowed() *beam.Pipeline {
    97  	p, s := beam.NewPipelineWithRoot()
    98  
    99  	timestampedData := beam.ParDo(s, &createTimestampedData{Data: []int{1, 1, 1, 2, 2, 3, 4, 4, 4, 4}}, beam.Impulse(s))
   100  	wData := beam.WindowInto(s, window.NewFixedWindows(3*time.Second), timestampedData)
   101  	counts := beam.ParDo(s, &valueStateFn{State1: state.MakeValueState[int]("key1"), State2: state.MakeValueState[string]("key2")}, wData)
   102  	globalCounts := beam.WindowInto(s, window.NewGlobalWindows(), counts)
   103  	passert.Equals(s, globalCounts, "magic: 1, I", "magic: 2, II", "magic: 3, III", "magic: 1, I", "magic: 2, II", "magic: 3, III", "magic: 1, I", "magic: 2, II", "magic: 3, III", "magic: 1, I")
   104  
   105  	return p
   106  }
   107  
   108  type valueStateClearFn struct {
   109  	State1 state.Value[int]
   110  }
   111  
   112  func (f *valueStateClearFn) ProcessElement(s state.Provider, w string, c int) string {
   113  	i, ok, err := f.State1.Read(s)
   114  	if err != nil {
   115  		panic(err)
   116  	}
   117  	if ok {
   118  		err = f.State1.Clear(s)
   119  		if err != nil {
   120  			panic(err)
   121  		}
   122  	} else {
   123  		err = f.State1.Write(s, 1)
   124  		if err != nil {
   125  			panic(err)
   126  		}
   127  	}
   128  
   129  	return fmt.Sprintf("%s: %v,%v", w, i, ok)
   130  }
   131  
   132  // ValueStateParDoClear tests that a DoFn that uses value state can be cleared.
   133  func ValueStateParDoClear() *beam.Pipeline {
   134  	p, s := beam.NewPipelineWithRoot()
   135  
   136  	in := beam.Create(s, "apple", "pear", "peach", "apple", "apple", "pear", "pear", "apple")
   137  	keyed := beam.ParDo(s, func(w string, emit func(string, int)) {
   138  		emit(w, 1)
   139  	}, in)
   140  	counts := beam.ParDo(s, &valueStateClearFn{State1: state.MakeValueState[int]("key1")}, keyed)
   141  	passert.Equals(s, counts, "apple: 0,false", "pear: 0,false", "peach: 0,false", "apple: 1,true", "apple: 0,false", "pear: 1,true", "pear: 0,false", "apple: 1,true")
   142  
   143  	return p
   144  }
   145  
   146  type bagStateFn struct {
   147  	State1 state.Bag[int]
   148  	State2 state.Bag[string]
   149  }
   150  
   151  func (f *bagStateFn) ProcessElement(s state.Provider, w string, c int) string {
   152  	i, ok, err := f.State1.Read(s)
   153  	if err != nil {
   154  		panic(err)
   155  	}
   156  	if !ok {
   157  		i = []int{}
   158  	}
   159  	err = f.State1.Add(s, 1)
   160  	if err != nil {
   161  		panic(err)
   162  	}
   163  
   164  	j, ok, err := f.State2.Read(s)
   165  	if err != nil {
   166  		panic(err)
   167  	}
   168  	if !ok {
   169  		j = []string{}
   170  	}
   171  	err = f.State2.Add(s, "I")
   172  	if err != nil {
   173  		panic(err)
   174  	}
   175  	sum := 0
   176  	for _, val := range i {
   177  		sum += val
   178  	}
   179  	return fmt.Sprintf("%s: %v, %s", w, sum, strings.Join(j, ","))
   180  }
   181  
   182  // BagStateParDo tests a DoFn that uses bag state.
   183  func BagStateParDo() *beam.Pipeline {
   184  	p, s := beam.NewPipelineWithRoot()
   185  
   186  	in := beam.Create(s, "apple", "pear", "peach", "apple", "apple", "pear")
   187  	keyed := beam.ParDo(s, func(w string, emit func(string, int)) {
   188  		emit(w, 1)
   189  	}, in)
   190  	counts := beam.ParDo(s, &bagStateFn{}, keyed)
   191  	passert.Equals(s, counts, "apple: 0, ", "pear: 0, ", "peach: 0, ", "apple: 1, I", "apple: 2, I,I", "pear: 1, I")
   192  
   193  	return p
   194  }
   195  
   196  type bagStateClearFn struct {
   197  	State1 state.Bag[int]
   198  }
   199  
   200  func (f *bagStateClearFn) ProcessElement(s state.Provider, w string, c int) string {
   201  	i, ok, err := f.State1.Read(s)
   202  	if err != nil {
   203  		panic(err)
   204  	}
   205  	if !ok {
   206  		i = []int{}
   207  	}
   208  	err = f.State1.Add(s, 1)
   209  	if err != nil {
   210  		panic(err)
   211  	}
   212  
   213  	sum := 0
   214  	for _, val := range i {
   215  		sum += val
   216  	}
   217  	if sum == 3 {
   218  		f.State1.Clear(s)
   219  	}
   220  	return fmt.Sprintf("%s: %v", w, sum)
   221  }
   222  
   223  // BagStateParDoClear tests a DoFn that uses bag state.
   224  func BagStateParDoClear() *beam.Pipeline {
   225  	p, s := beam.NewPipelineWithRoot()
   226  
   227  	in := beam.Create(s, "apple", "pear", "apple", "apple", "pear", "apple", "apple", "pear", "pear", "pear", "apple", "pear")
   228  	keyed := beam.ParDo(s, func(w string, emit func(string, int)) {
   229  		emit(w, 1)
   230  	}, in)
   231  	counts := beam.ParDo(s, &bagStateClearFn{State1: state.MakeBagState[int]("key1")}, keyed)
   232  	passert.Equals(s, counts, "apple: 0", "pear: 0", "apple: 1", "apple: 2", "pear: 1", "apple: 3", "apple: 0", "pear: 2", "pear: 3", "pear: 0", "apple: 1", "pear: 1")
   233  
   234  	return p
   235  }
   236  
   237  type combiningStateFn struct {
   238  	State0 state.Combining[int, int, int]
   239  	State1 state.Combining[int, int, int]
   240  	State2 state.Combining[string, string, int]
   241  	State3 state.Combining[string, string, int]
   242  	State4 state.Combining[int, int, int]
   243  }
   244  
   245  type combine1 struct{}
   246  
   247  func (ac *combine1) MergeAccumulators(a, b int) int {
   248  	return a + b
   249  }
   250  
   251  type combine2 struct{}
   252  
   253  func (ac *combine2) MergeAccumulators(a, b string) string {
   254  	ai, _ := strconv.Atoi(a)
   255  	bi, _ := strconv.Atoi(b)
   256  	return strconv.Itoa(ai + bi)
   257  }
   258  
   259  func (ac *combine2) ExtractOutput(a string) int {
   260  	ai, _ := strconv.Atoi(a)
   261  	return ai
   262  }
   263  
   264  type combine3 struct{}
   265  
   266  func (ac *combine3) CreateAccumulator() string {
   267  	return "0"
   268  }
   269  
   270  func (ac *combine3) MergeAccumulators(a string, b string) string {
   271  	ai, _ := strconv.Atoi(a)
   272  	bi, _ := strconv.Atoi(b)
   273  	return strconv.Itoa(ai + bi)
   274  }
   275  
   276  func (ac *combine3) ExtractOutput(a string) int {
   277  	ai, _ := strconv.Atoi(a)
   278  	return ai
   279  }
   280  
   281  type combine4 struct{}
   282  
   283  func (ac *combine4) AddInput(a, b int) int {
   284  	return a + b
   285  }
   286  
   287  func (ac *combine4) MergeAccumulators(a, b int) int {
   288  	return a + b
   289  }
   290  
   291  func (f *combiningStateFn) ProcessElement(s state.Provider, w string, c int) string {
   292  	i, _, err := f.State0.Read(s)
   293  	if err != nil {
   294  		panic(err)
   295  	}
   296  	err = f.State0.Add(s, 1)
   297  	if err != nil {
   298  		panic(err)
   299  	}
   300  	i1, _, err := f.State1.Read(s)
   301  	if err != nil {
   302  		panic(err)
   303  	}
   304  	err = f.State1.Add(s, 1)
   305  	if err != nil {
   306  		panic(err)
   307  	}
   308  	i2, _, err := f.State2.Read(s)
   309  	if err != nil {
   310  		panic(err)
   311  	}
   312  	err = f.State2.Add(s, "1")
   313  	if err != nil {
   314  		panic(err)
   315  	}
   316  	i3, _, err := f.State3.Read(s)
   317  	if err != nil {
   318  		panic(err)
   319  	}
   320  	err = f.State3.Add(s, "1")
   321  	if err != nil {
   322  		panic(err)
   323  	}
   324  	i4, _, err := f.State4.Read(s)
   325  	if err != nil {
   326  		panic(err)
   327  	}
   328  	err = f.State4.Add(s, 1)
   329  	if err != nil {
   330  		panic(err)
   331  	}
   332  	return fmt.Sprintf("%s: %v %v %v %v %v", w, i, i1, i2, i3, i4)
   333  }
   334  
   335  // CombiningStateParDo tests a DoFn that uses value state.
   336  func CombiningStateParDo() *beam.Pipeline {
   337  	p, s := beam.NewPipelineWithRoot()
   338  
   339  	in := beam.Create(s, "apple", "pear", "peach", "apple", "apple", "pear")
   340  	keyed := beam.ParDo(s, func(w string, emit func(string, int)) {
   341  		emit(w, 1)
   342  	}, in)
   343  	counts := beam.ParDo(s, &combiningStateFn{
   344  		State0: state.MakeCombiningState[int, int, int]("key0", func(a, b int) int {
   345  			return a + b
   346  		}),
   347  		State1: state.Combining[int, int, int](state.MakeCombiningState[int, int, int]("key1", &combine1{})),
   348  		State2: state.Combining[string, string, int](state.MakeCombiningState[string, string, int]("key2", &combine2{})),
   349  		State3: state.Combining[string, string, int](state.MakeCombiningState[string, string, int]("key3", &combine3{})),
   350  		State4: state.Combining[int, int, int](state.MakeCombiningState[int, int, int]("key4", &combine4{}))},
   351  		keyed)
   352  	passert.Equals(s, counts, "apple: 0 0 0 0 0", "pear: 0 0 0 0 0", "peach: 0 0 0 0 0", "apple: 1 1 1 1 1", "apple: 2 2 2 2 2", "pear: 1 1 1 1 1")
   353  
   354  	return p
   355  }
   356  
   357  type mapStateFn struct {
   358  	State1 state.Map[string, int]
   359  }
   360  
   361  func (f *mapStateFn) ProcessElement(s state.Provider, w string, c int) string {
   362  	i, _, err := f.State1.Get(s, w)
   363  	if err != nil {
   364  		panic(err)
   365  	}
   366  	i++
   367  	err = f.State1.Put(s, w, i)
   368  	if err != nil {
   369  		panic(err)
   370  	}
   371  	err = f.State1.Put(s, fmt.Sprintf("%v%v", w, i), i)
   372  	if err != nil {
   373  		panic(err)
   374  	}
   375  	j, _, err := f.State1.Get(s, w)
   376  	if err != nil {
   377  		panic(err)
   378  	}
   379  	if i != j {
   380  		panic(fmt.Sprintf("Reading state multiple times for %v produced different results: %v != %v", w, i, j))
   381  	}
   382  
   383  	keys, _, err := f.State1.Keys(s)
   384  	if err != nil {
   385  		panic(err)
   386  	}
   387  
   388  	sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] })
   389  
   390  	return fmt.Sprintf("%v: %v, keys: %v", w, i, keys)
   391  }
   392  
   393  // MapStateParDo tests a DoFn that uses value state.
   394  func MapStateParDo() *beam.Pipeline {
   395  	p, s := beam.NewPipelineWithRoot()
   396  
   397  	in := beam.Create(s, "apple", "pear", "peach", "apple", "apple", "pear")
   398  	keyed := beam.ParDo(s, func(w string, emit func(string, int)) {
   399  		emit(w, 1)
   400  	}, in)
   401  	counts := beam.ParDo(s, &mapStateFn{State1: state.MakeMapState[string, int]("key1")}, keyed)
   402  	passert.Equals(s, counts, "apple: 1, keys: [apple apple1]", "pear: 1, keys: [pear pear1]", "peach: 1, keys: [peach peach1]", "apple: 2, keys: [apple apple1 apple2]", "apple: 3, keys: [apple apple1 apple2 apple3]", "pear: 2, keys: [pear pear1 pear2]")
   403  
   404  	return p
   405  }
   406  
   407  type mapStateClearFn struct {
   408  	State1 state.Map[string, int]
   409  }
   410  
   411  func (f *mapStateClearFn) ProcessElement(s state.Provider, w string, c int) string {
   412  	_, ok, err := f.State1.Get(s, w)
   413  	if err != nil {
   414  		panic(err)
   415  	}
   416  	if ok {
   417  		f.State1.Remove(s, w)
   418  		f.State1.Put(s, fmt.Sprintf("%v%v", w, 1), 1)
   419  		f.State1.Put(s, fmt.Sprintf("%v%v", w, 2), 1)
   420  		f.State1.Put(s, fmt.Sprintf("%v%v", w, 3), 1)
   421  	} else {
   422  		_, ok, err := f.State1.Get(s, fmt.Sprintf("%v%v", w, 1))
   423  		if err != nil {
   424  			panic(err)
   425  		}
   426  		if ok {
   427  			f.State1.Clear(s)
   428  		} else {
   429  			f.State1.Put(s, w, 1)
   430  		}
   431  	}
   432  
   433  	keys, _, err := f.State1.Keys(s)
   434  	if err != nil {
   435  		panic(err)
   436  	}
   437  
   438  	sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] })
   439  
   440  	for _, k := range keys {
   441  		_, ok, err = f.State1.Get(s, k)
   442  		if err != nil {
   443  			panic(err)
   444  		}
   445  		if !ok {
   446  			panic(fmt.Sprintf("%v is present in keys, but not in the map", k))
   447  		}
   448  	}
   449  
   450  	return fmt.Sprintf("%v: %v", w, keys)
   451  }
   452  
   453  // MapStateParDoClear tests clearing and removing from a DoFn that uses map state.
   454  func MapStateParDoClear() *beam.Pipeline {
   455  	p, s := beam.NewPipelineWithRoot()
   456  
   457  	in := beam.Create(s, "apple", "pear", "peach", "apple", "apple", "pear")
   458  	keyed := beam.ParDo(s, func(w string, emit func(string, int)) {
   459  		emit(w, 1)
   460  	}, in)
   461  	counts := beam.ParDo(s, &mapStateClearFn{State1: state.MakeMapState[string, int]("key1")}, keyed)
   462  	passert.Equals(s, counts, "apple: [apple]", "pear: [pear]", "peach: [peach]", "apple: [apple1 apple2 apple3]", "apple: []", "pear: [pear1 pear2 pear3]")
   463  
   464  	return p
   465  }
   466  
   467  type setStateFn struct {
   468  	State1 state.Set[string]
   469  }
   470  
   471  func (f *setStateFn) ProcessElement(s state.Provider, w string, c int) string {
   472  	ok, err := f.State1.Contains(s, w)
   473  	if err != nil {
   474  		panic(err)
   475  	}
   476  	err = f.State1.Add(s, w)
   477  	if err != nil {
   478  		panic(err)
   479  	}
   480  	if ok {
   481  		err = f.State1.Add(s, fmt.Sprintf("%v%v", w, 1))
   482  		if err != nil {
   483  			panic(err)
   484  		}
   485  	}
   486  
   487  	keys, _, err := f.State1.Keys(s)
   488  	if err != nil {
   489  		panic(err)
   490  	}
   491  
   492  	sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] })
   493  
   494  	return fmt.Sprintf("%v: %v, keys: %v", w, ok, keys)
   495  }
   496  
   497  // SetStateParDo tests a DoFn that uses set state.
   498  func SetStateParDo() *beam.Pipeline {
   499  	p, s := beam.NewPipelineWithRoot()
   500  
   501  	in := beam.Create(s, "apple", "pear", "peach", "apple", "apple", "pear")
   502  	keyed := beam.ParDo(s, func(w string, emit func(string, int)) {
   503  		emit(w, 1)
   504  	}, in)
   505  	counts := beam.ParDo(s, &setStateFn{State1: state.MakeSetState[string]("key1")}, keyed)
   506  	passert.Equals(s, counts, "apple: false, keys: [apple]", "pear: false, keys: [pear]", "peach: false, keys: [peach]", "apple: true, keys: [apple apple1]", "apple: true, keys: [apple apple1]", "pear: true, keys: [pear pear1]")
   507  
   508  	return p
   509  }
   510  
   511  type setStateClearFn struct {
   512  	State1 state.Set[string]
   513  }
   514  
   515  func (f *setStateClearFn) ProcessElement(s state.Provider, w string, c int) string {
   516  	ok, err := f.State1.Contains(s, w)
   517  	if err != nil {
   518  		panic(err)
   519  	}
   520  	if ok {
   521  		f.State1.Remove(s, w)
   522  		f.State1.Add(s, fmt.Sprintf("%v%v", w, 1))
   523  		f.State1.Add(s, fmt.Sprintf("%v%v", w, 2))
   524  		f.State1.Add(s, fmt.Sprintf("%v%v", w, 3))
   525  	} else {
   526  		ok, err := f.State1.Contains(s, fmt.Sprintf("%v%v", w, 1))
   527  		if err != nil {
   528  			panic(err)
   529  		}
   530  		if ok {
   531  			f.State1.Clear(s)
   532  		} else {
   533  			f.State1.Add(s, w)
   534  		}
   535  	}
   536  
   537  	keys, _, err := f.State1.Keys(s)
   538  	if err != nil {
   539  		panic(err)
   540  	}
   541  
   542  	sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] })
   543  
   544  	for _, k := range keys {
   545  		ok, err = f.State1.Contains(s, k)
   546  		if err != nil {
   547  			panic(err)
   548  		}
   549  		if !ok {
   550  			panic(fmt.Sprintf("%v is present in keys, but not in the map", k))
   551  		}
   552  	}
   553  
   554  	return fmt.Sprintf("%v: %v", w, keys)
   555  }
   556  
   557  // SetStateParDoClear tests clearing and removing from a DoFn that uses set state.
   558  func SetStateParDoClear() *beam.Pipeline {
   559  	p, s := beam.NewPipelineWithRoot()
   560  
   561  	in := beam.Create(s, "apple", "pear", "peach", "apple", "apple", "pear")
   562  	keyed := beam.ParDo(s, func(w string, emit func(string, int)) {
   563  		emit(w, 1)
   564  	}, in)
   565  	counts := beam.ParDo(s, &setStateClearFn{State1: state.MakeSetState[string]("key1")}, keyed)
   566  	passert.Equals(s, counts, "apple: [apple]", "pear: [pear]", "peach: [peach]", "apple: [apple1 apple2 apple3]", "apple: []", "pear: [pear1 pear2 pear3]")
   567  
   568  	return p
   569  }