github.com/apache/beam/sdks/v2@v2.48.2/go/test/regression/coders/fromyaml/fromyaml.go (about)

     1  // Licensed to the Apache Software Foundation (ASF) under one or more
     2  // contributor license agreements.  See the NOTICE file distributed with
     3  // this work for additional information regarding copyright ownership.
     4  // The ASF licenses this file to You under the Apache License, Version 2.0
     5  // (the "License"); you may not use this file except in compliance with
     6  // the License.  You may obtain a copy of the License at
     7  //
     8  //    http://www.apache.org/licenses/LICENSE-2.0
     9  //
    10  // Unless required by applicable law or agreed to in writing, software
    11  // distributed under the License is distributed on an "AS IS" BASIS,
    12  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  // See the License for the specific language governing permissions and
    14  // limitations under the License.
    15  
// fromyaml reads the standard_coders.yaml file and checks its examples
// against the Go SDK's coder implementations, for use in these coder
// regression tests.
//
// It expects to be run in its test directory, or via its go test.
    20  package main
    21  
    22  import (
    23  	"bytes"
    24  	"fmt"
    25  	"log"
    26  	"math"
    27  	"os"
    28  	"reflect"
    29  	"runtime/debug"
    30  	"strconv"
    31  	"strings"
    32  
    33  	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime"
    34  	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window"
    35  	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
    36  	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/graphx"
    37  	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
    38  	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
    39  	pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
    40  	"github.com/google/go-cmp/cmp"
    41  	"golang.org/x/text/encoding/charmap"
    42  	yaml "gopkg.in/yaml.v2"
    43  )
    44  
    45  var unimplementedCoders = map[string]bool{
    46  	"beam:coder:param_windowed_value:v1": true,
    47  	"beam:coder:sharded_key:v1":          true,
    48  	"beam:coder:custom_window:v1":        true,
    49  }
    50  
    51  var filteredCases = []struct{ filter, reason string }{
    52  	{"logical", "BEAM-9615: Support logical types"},
    53  	{"30ea5a25-dcd8-4cdb-abeb-5332d15ab4b9", "https://github.com/apache/beam/issues/21206: Support encoding position."},
    54  	{"80be749a-5700-4ede-89d8-dd9a4433a3f8", "https://github.com/apache/beam/issues/19817: Support millis_instant."},
    55  	{"800c44ae-a1b7-4def-bbf6-6217cca89ec4", "https://github.com/apache/beam/issues/19817: Support decimal."},
    56  	{"f0ffb3a4-f46f-41ca-a942-85e3e939452a", "https://github.com/apache/beam/issues/23526: Support char/varchar, binary/varbinary."},
    57  }
    58  
// Coder is a representation of a serialized beam coder, as it appears under
// the "coder" key of a standard_coders.yaml spec (and, recursively, under its
// "components").
type Coder struct {
	// Urn identifies the coder implementation, e.g. "beam:coder:kv:v1".
	Urn              string  `yaml:"urn,omitempty"`
	// Payload is the coder's serialized payload, if any. parseCoder passes its
	// bytes through as the FunctionSpec payload.
	Payload          string  `yaml:"payload,omitempty"`
	// Components are the component coders, in order.
	Components       []Coder `yaml:"components,omitempty"`
	// NonDeterministic coders have their decoding checked, but their
	// re-encoding is not compared against the example bytes.
	NonDeterministic bool    `yaml:"non_deterministic,omitempty"`
}
    66  
    67  type logger interface {
    68  	Errorf(string, ...any)
    69  	Logf(string, ...any)
    70  }
    71  
// Spec is a set of conditions that a coder must pass: one document of
// standard_coders.yaml, pairing a coder with encoded/decoded examples.
type Spec struct {
	// Coder is the coder under test.
	Coder    Coder         `yaml:"coder,omitempty"`
	// Nested is the spec's "nested" flag, or nil when the spec omits it.
	// testStandardCoder skips specs with nested: false, except for the row coder.
	Nested   *bool         `yaml:"nested,omitempty"`
	// Examples holds ordered key/value items. Each key is an encoded value
	// and each value is the corresponding expected decoded value.
	Examples yaml.MapSlice `yaml:"examples,omitempty"`
	// Log is the logger supplied by the caller.
	// NOTE(review): not read by the code visible in this file, which logs via
	// the standard log package; confirm whether it is still needed.
	Log      logger

	id       int // for generating coder ids.
	// coderPBs holds the coder protos built by parseCoder, keyed by coder id.
	coderPBs map[string]*pipepb.Coder
}
    82  
    83  func (s *Spec) nextID() string {
    84  	ret := fmt.Sprintf("%d", s.id)
    85  	s.id++
    86  	return ret
    87  }
    88  
    89  func (s *Spec) testStandardCoder() (err error) {
    90  	if unimplementedCoders[s.Coder.Urn] {
    91  		log.Printf("skipping unimplemented coder urn: %v", s.Coder.Urn)
    92  		return nil
    93  	}
    94  	if s.Coder.Urn == "beam:coder:state_backed_iterable:v1" {
    95  		log.Printf("skipping unimplemented test coverage for beam:coder:state_backed_iterable:v1. https://github.com/apache/beam/issues/21324")
    96  		return nil
    97  	}
    98  	for _, c := range filteredCases {
    99  		if strings.Contains(s.Coder.Payload, c.filter) {
   100  			log.Printf("skipping coder case. Unsupported in the Go SDK for now: %v Payload: %v", c.reason, s.Coder.Payload)
   101  			return nil
   102  		}
   103  	}
   104  
   105  	// Construct the coder proto equivalents.
   106  
   107  	// Only nested tests need to be run, since nestedness is a pre-portability
   108  	// concept.
   109  	// For legacy Java reasons, the row coder examples are all marked nested: false
   110  	// so we need to check that before skipping unnested tests.
   111  	if s.Coder.Urn != "beam:coder:row:v1" && s.Nested != nil && !*s.Nested {
   112  		log.Printf("skipping unnested coder spec: %v\n", s.Coder)
   113  		return nil
   114  	}
   115  
   116  	s.coderPBs = make(map[string]*pipepb.Coder)
   117  	id := s.parseCoder(s.Coder)
   118  	b := graphx.NewCoderUnmarshaller(s.coderPBs)
   119  	underTest, err := b.Coder(id)
   120  	if err != nil {
   121  		return fmt.Errorf("unable to create coder: %v", err)
   122  	}
   123  
   124  	defer func() {
   125  		if e := recover(); e != nil {
   126  			err = fmt.Errorf("panicked on coder %v || %v:\n\t%v :\n%s", underTest, s.Coder, e, debug.Stack())
   127  		}
   128  	}()
   129  
   130  	var decFails, encFails int
   131  	for _, eg := range s.Examples {
   132  
   133  		// Test Decoding
   134  		// Ideally we'd use the beam package coders, but KVs make that complicated.
   135  		// This can be cleaned up once a type parametered beam.KV type exists.
   136  		dec := exec.MakeElementDecoder(underTest)
   137  		encoded := eg.Key.(string)
   138  		var elem exec.FullValue
   139  
   140  		// What I would have expected.
   141  		//		r := charmap.ISO8859_1.NewDecoder().Reader(strings.NewReader(encoded))
   142  		recoded, err := charmap.ISO8859_1.NewEncoder().String(encoded)
   143  		if err != nil {
   144  			return err
   145  		}
   146  		r := strings.NewReader(recoded)
   147  		if err := dec.DecodeTo(r, &elem); err != nil {
   148  			return fmt.Errorf("err decoding %q: %v", encoded, err)
   149  		}
   150  		if !diff(s.Coder, &elem, eg) {
   151  			decFails++
   152  			continue
   153  		}
   154  
   155  		// Test Encoding
   156  		if s.Coder.NonDeterministic {
   157  			// Skip verifying nondeterministic encodings.
   158  			continue
   159  		}
   160  		enc := exec.MakeElementEncoder(underTest)
   161  		var out bytes.Buffer
   162  		if err := enc.Encode(&elem, &out); err != nil {
   163  			return err
   164  		}
   165  		if d := cmp.Diff(recoded, string(out.Bytes())); d != "" {
   166  			log.Printf("Encoding error: diff(-want,+got): %v\n", d)
   167  		}
   168  	}
   169  	if decFails+encFails > 0 {
   170  		return fmt.Errorf("failed to decode %v times, and encode %v times", decFails, encFails)
   171  	}
   172  
   173  	return nil
   174  }
   175  
   176  var cmpOpts = []cmp.Option{
   177  	cmp.Transformer("bytes2string", func(in []byte) (out string) {
   178  		return string(in)
   179  	}),
   180  }
   181  
   182  func diff(c Coder, elem *exec.FullValue, eg yaml.MapItem) bool {
   183  	var got, want any
   184  	switch c.Urn {
   185  	case "beam:coder:bytes:v1":
   186  		got = string(elem.Elm.([]byte))
   187  		switch egv := eg.Value.(type) {
   188  		case string:
   189  			want = egv
   190  		case []byte:
   191  			want = string(egv)
   192  		}
   193  	case "beam:coder:varint:v1":
   194  		got, want = elem.Elm.(int64), int64(eg.Value.(int))
   195  	case "beam:coder:double:v1":
   196  		got = elem.Elm.(float64)
   197  		switch v := eg.Value.(string); v {
   198  		case "NaN":
   199  			// Do the NaN comparison here since NaN by definition != NaN.
   200  			if math.IsNaN(got.(float64)) {
   201  				want, got = 1, 1
   202  			} else {
   203  				want = math.NaN()
   204  			}
   205  		case "-Infinity":
   206  			want = math.Inf(-1)
   207  		case "Infinity":
   208  			want = math.Inf(1)
   209  		default:
   210  			want, _ = strconv.ParseFloat(v, 64)
   211  		}
   212  
   213  	case "beam:coder:kv:v1":
   214  		v := eg.Value.(yaml.MapSlice)
   215  		pass := true
   216  		if !diff(c.Components[0], &exec.FullValue{Elm: elem.Elm}, v[0]) {
   217  			pass = false
   218  		}
   219  		if !diff(c.Components[1], &exec.FullValue{Elm: elem.Elm2}, v[1]) {
   220  			pass = false
   221  		}
   222  		return pass
   223  
   224  	case "beam:coder:nullable:v1":
   225  		if elem.Elm == nil || eg.Value == nil {
   226  			got, want = elem.Elm, eg.Value
   227  		} else {
   228  			got = string(elem.Elm.([]byte))
   229  			switch egv := eg.Value.(type) {
   230  			case string:
   231  				want = egv
   232  			case []byte:
   233  				want = string(egv)
   234  			}
   235  		}
   236  
   237  	case "beam:coder:iterable:v1":
   238  		pass := true
   239  		gotrv := reflect.ValueOf(elem.Elm)
   240  		wantrv := reflect.ValueOf(eg.Value)
   241  		if gotrv.Len() != wantrv.Len() {
   242  			log.Printf("Lengths don't match. got %v, want %v;  %v, %v", gotrv.Len(), wantrv.Len(), gotrv, wantrv)
   243  			return false
   244  		}
   245  		for i := 0; i < wantrv.Len(); i++ {
   246  			if !diff(c.Components[0],
   247  				&exec.FullValue{Elm: gotrv.Index(i).Interface()},
   248  				yaml.MapItem{Value: wantrv.Index(i).Interface()}) {
   249  				pass = false
   250  			}
   251  
   252  		}
   253  		return pass
   254  	case "beam:coder:interval_window:v1":
   255  		var a, b int
   256  		val := eg.Value
   257  		if is, ok := eg.Value.([]any); ok {
   258  			val = is[0]
   259  		}
   260  		v := val.(yaml.MapSlice)
   261  
   262  		a = v[0].Value.(int)
   263  		b = v[1].Value.(int)
   264  		end := mtime.FromMilliseconds(int64(a))
   265  		start := end - mtime.Time(int64(b))
   266  		want = window.IntervalWindow{Start: start, End: end}
   267  		// If this is nested in an iterable, windows won't be populated.
   268  		if len(elem.Windows) == 0 {
   269  			got = elem.Elm
   270  		} else {
   271  			got = elem.Windows[0]
   272  		}
   273  
   274  	case "beam:coder:global_window:v1":
   275  		want = window.GlobalWindow{}
   276  		// If this is nested in an iterable, windows won't be populated.
   277  		if len(elem.Windows) == 0 {
   278  			got = window.GlobalWindow(elem.Elm.(struct{}))
   279  		} else {
   280  			got = elem.Windows[0]
   281  		}
   282  	case "beam:coder:windowed_value:v1", "beam:coder:param_windowed_value:v1":
   283  		// elem contains all the information, but we need to compare the element+timestamp
   284  		// separately from the windows, to avoid repeated expected value parsing logic.
   285  		pass := true
   286  		vs := eg.Value.(yaml.MapSlice)
   287  		if !diff(c.Components[0], elem, vs[0]) {
   288  			pass = false
   289  		}
   290  		if d := cmp.Diff(
   291  			mtime.FromMilliseconds(int64(vs[1].Value.(int))),
   292  			elem.Timestamp, cmpOpts...); d != "" {
   293  
   294  			pass = false
   295  		}
   296  		if !diff(c.Components[1], elem, vs[3]) {
   297  			pass = false
   298  		}
   299  		if !diffPane(vs[2].Value, elem.Pane) {
   300  			pass = false
   301  		}
   302  		return pass
   303  	case "beam:coder:row:v1":
   304  		fs := eg.Value.(yaml.MapSlice)
   305  		var rfs []reflect.StructField
   306  		// There are only 2 pointer examples, but they reuse field names,
   307  		// so we key off the proto hash to know which example we're handling.
   308  		ptrEg := strings.Contains(c.Payload, "51ace21c7393")
   309  		for _, rf := range fs {
   310  			name := rf.Key.(string)
   311  			t := nameToType[name]
   312  			if ptrEg {
   313  				t = reflect.PtrTo(t)
   314  			}
   315  			rfs = append(rfs, reflect.StructField{
   316  				Name: strings.ToUpper(name[:1]) + name[1:],
   317  				Type: t,
   318  				Tag:  reflect.StructTag(fmt.Sprintf("beam:\"%v\"", name)),
   319  			})
   320  		}
   321  		rv := reflect.New(reflect.StructOf(rfs)).Elem()
   322  		for i, rf := range fs {
   323  			setField(rv, i, rf.Value)
   324  		}
   325  
   326  		got, want = elem.Elm, rv.Interface()
   327  	case "beam:coder:timer:v1":
   328  		pass := true
   329  		tm := elem.Elm.(exec.TimerRecv)
   330  		fs := eg.Value.(yaml.MapSlice)
   331  		for _, item := range fs {
   332  
   333  			switch item.Key.(string) {
   334  			case "userKey":
   335  				if want := item.Value.(string); want != tm.Key.Elm.(string) {
   336  					pass = false
   337  				}
   338  			case "dynamicTimerTag":
   339  				if want := item.Value.(string); want != tm.Tag {
   340  					pass = false
   341  				}
   342  			case "windows":
   343  				if v, ok := item.Value.([]any); ok {
   344  					for i, val := range v {
   345  						if val.(string) == "global" && fmt.Sprintf("%s", tm.Windows[i]) == "[*]" {
   346  							continue
   347  						} else if val.(string) != fmt.Sprintf("%s", tm.Windows[i]) {
   348  							pass = false
   349  						}
   350  					}
   351  				}
   352  			case "clearBit":
   353  				if want := item.Value.(bool); want != tm.Clear {
   354  					pass = false
   355  				}
   356  			case "fireTimestamp":
   357  				if want := item.Value.(int); want != int(tm.FireTimestamp) {
   358  					pass = false
   359  				}
   360  			case "holdTimestamp":
   361  				if want := item.Value.(int); want != int(tm.HoldTimestamp) {
   362  					pass = false
   363  				}
   364  			case "pane":
   365  				pass = diffPane(item.Value, tm.Pane)
   366  			}
   367  		}
   368  		return pass
   369  	default:
   370  		got, want = elem.Elm, eg.Value
   371  	}
   372  	if d := cmp.Diff(want, got, cmpOpts...); d != "" {
   373  		log.Printf("Decoding error: diff(-want,+got): %v\n", d)
   374  		return false
   375  	}
   376  	return true
   377  }
   378  
   379  func diffPane(eg any, got typex.PaneInfo) bool {
   380  	pass := true
   381  	paneTiming := map[typex.PaneTiming]string{
   382  		typex.PaneUnknown: "UNKNOWN",
   383  		typex.PaneEarly:   "EARLY",
   384  		typex.PaneLate:    "LATE",
   385  		typex.PaneOnTime:  "ONTIME",
   386  	}
   387  	for _, item := range eg.(yaml.MapSlice) {
   388  		switch item.Key.(string) {
   389  		case "is_first":
   390  			if want := item.Value.(bool); want != got.IsFirst {
   391  				pass = false
   392  			}
   393  		case "is_last":
   394  			if want := item.Value.(bool); want != got.IsLast {
   395  				pass = false
   396  			}
   397  		case "timing":
   398  			if want := item.Value.(string); want != paneTiming[got.Timing] {
   399  				pass = false
   400  			}
   401  		case "index":
   402  			if want := item.Value.(int); want != int(got.Index) {
   403  				pass = false
   404  			}
   405  		case "on_time_index":
   406  			if want := item.Value.(int); want != int(got.NonSpeculativeIndex) {
   407  				pass = false
   408  			}
   409  		}
   410  	}
   411  	return pass
   412  }
   413  
   414  // standard_coders.yaml uses the name for type indication, except for nullability.
   415  var nameToType = map[string]reflect.Type{
   416  	"str":     reflectx.String,
   417  	"i32":     reflectx.Int32,
   418  	"f64":     reflectx.Float64,
   419  	"arr":     reflect.SliceOf(reflectx.String),
   420  	"f_bool":  reflectx.Bool,
   421  	"f_bytes": reflect.PtrTo(reflectx.ByteSlice),
   422  	"f_map":   reflect.MapOf(reflectx.String, reflect.PtrTo(reflectx.Int64)),
   423  	"f_float": reflectx.Float32,
   424  }
   425  
   426  func setField(rv reflect.Value, i int, v any) {
   427  	if v == nil {
   428  		return
   429  	}
   430  	rf := rv.Field(i)
   431  	if rf.Kind() == reflect.Ptr {
   432  		// Ensure it's initialized.
   433  		rf.Set(reflect.New(rf.Type().Elem()))
   434  		rf = rf.Elem()
   435  	}
   436  	switch rf.Kind() {
   437  	case reflect.String:
   438  		rf.SetString(v.(string))
   439  	case reflect.Int32:
   440  		rf.SetInt(int64(v.(int)))
   441  	case reflect.Float32:
   442  		c, err := strconv.ParseFloat(v.(string), 32)
   443  		if err != nil {
   444  			panic(err)
   445  		}
   446  		rf.SetFloat(c)
   447  	case reflect.Float64:
   448  		c, err := strconv.ParseFloat(v.(string), 64)
   449  		if err != nil {
   450  			panic(err)
   451  		}
   452  		rf.SetFloat(c)
   453  	case reflect.Slice:
   454  		if rf.Type() == reflectx.ByteSlice {
   455  			rf.Set(reflect.ValueOf([]byte(v.(string))))
   456  			break
   457  		}
   458  		// Value is a []any with string values.
   459  		var arr []string
   460  		for _, a := range v.([]any) {
   461  			arr = append(arr, a.(string))
   462  		}
   463  		rf.Set(reflect.ValueOf(arr))
   464  	case reflect.Bool:
   465  		rf.SetBool(v.(bool))
   466  	case reflect.Map:
   467  		// only f_map presently, which is always map[string]*int64
   468  		rm := reflect.MakeMap(rf.Type())
   469  		for _, a := range v.(yaml.MapSlice) {
   470  			rk := reflect.ValueOf(a.Key.(string))
   471  			rv := reflect.Zero(rf.Type().Elem())
   472  			if a.Value != nil {
   473  				rv = reflect.New(reflectx.Int64)
   474  				rv.Elem().SetInt(int64(a.Value.(int)))
   475  			}
   476  			rm.SetMapIndex(rk, rv)
   477  		}
   478  		rf.Set(rm)
   479  
   480  	}
   481  }
   482  
   483  func (s *Spec) parseCoder(c Coder) string {
   484  	id := s.nextID()
   485  	var compIDs []string
   486  	for _, comp := range c.Components {
   487  		compIDs = append(compIDs, s.parseCoder(comp))
   488  	}
   489  	s.coderPBs[id] = &pipepb.Coder{
   490  		Spec: &pipepb.FunctionSpec{
   491  			Urn:     c.Urn,
   492  			Payload: []byte(c.Payload),
   493  		},
   494  		ComponentCoderIds: compIDs,
   495  	}
   496  	return id
   497  }
   498  
   499  // Simple logger to run as main program.
   500  type logLogger struct{}
   501  
   502  func (*logLogger) Errorf(format string, v ...any) {
   503  	log.Printf(format, v...)
   504  }
   505  
   506  func (*logLogger) Logf(format string, v ...any) {
   507  	log.Printf(format, v...)
   508  }
   509  
   510  const yamlPath = "../../../../../../model/fn-execution/src/main/resources/org/apache/beam/model/fnexecution/v1/standard_coders.yaml"
   511  
   512  func main() {
   513  	data, err := os.ReadFile(yamlPath)
   514  	if err != nil {
   515  		log.Fatalf("Couldn't read %v: %v", yamlPath, err)
   516  	}
   517  	specs := bytes.Split(data, []byte("\n---\n"))
   518  	var failures bool
   519  	var l logLogger
   520  	for _, data := range specs {
   521  		cs := Spec{Log: &l}
   522  		if err := yaml.Unmarshal(data, &cs); err != nil {
   523  			failures = true
   524  			l.Logf("unable to parse yaml: %v %q", err, data)
   525  			continue
   526  		}
   527  		if err := cs.testStandardCoder(); err != nil {
   528  			failures = true
   529  			l.Logf("Failed \"%v\": %v", cs.Coder, err)
   530  		}
   531  	}
   532  	if !failures {
   533  		log.Println("PASS")
   534  	}
   535  }