kythe.io@v0.0.68-0.20240422202219-7225dbc01741/kythe/go/serving/pipeline/beamtest/beamtest.go (about)

     1  /*
     2   * Copyright 2018 The Kythe Authors. All rights reserved.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *   http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  // Package beamtest contains utilities to test Apache Beam pipelines.
    18  package beamtest // import "kythe.io/kythe/go/serving/pipeline/beamtest"
    19  
    20  import (
    21  	"fmt"
    22  	"reflect"
    23  	"testing"
    24  
    25  	"github.com/apache/beam/sdks/go/pkg/beam"
    26  	"github.com/apache/beam/sdks/go/pkg/beam/core/runtime"
    27  	"github.com/apache/beam/sdks/go/pkg/beam/core/util/reflectx"
    28  	"google.golang.org/protobuf/proto"
    29  )
    30  
    31  // CheckRegistrations returns an error if p uses any non-registered types.
    32  func CheckRegistrations(t *testing.T, p *beam.Pipeline) {
    33  	edges, nodes, err := p.Build()
    34  	if err != nil {
    35  		t.Fatalf("invalid pipeline: %v", err)
    36  	}
    37  	for _, n := range nodes {
    38  		if err := checkFullType(n.Type()); err != nil {
    39  			t.Error(n)
    40  		}
    41  	}
    42  	for _, e := range edges {
    43  		if fn := e.DoFn; fn != nil {
    44  			if fn.Recv != nil {
    45  				if err := checkFnType(reflect.TypeOf(fn.Recv)); err != nil {
    46  					t.Error(err)
    47  				}
    48  			} else if _, err := runtime.ResolveFunction(fn.Name(), fn.Fn.Fn.Type()); err != nil {
    49  				t.Errorf("error resolving function: %v", err)
    50  			}
    51  		}
    52  	}
    53  }
    54  
    55  var protoMessageType = reflect.TypeOf((*proto.Message)(nil)).Elem()
    56  
    57  func checkType(t reflect.Type) error {
    58  	key, keyValid := runtime.TypeKey(reflectx.SkipPtr(t))
    59  	if !keyValid {
    60  		return nil
    61  	} else if _, ok := runtime.LookupType(key); !ok && t.Implements(protoMessageType) {
    62  		return fmt.Errorf("unregistered proto.Message type: %v", t)
    63  	}
    64  	return nil
    65  }
    66  
    67  func checkFnType(t reflect.Type) error {
    68  	key, keyValid := runtime.TypeKey(reflectx.SkipPtr(t))
    69  	if !keyValid {
    70  		return nil
    71  	} else if _, ok := runtime.LookupType(key); !ok {
    72  		return fmt.Errorf("unregistered function type: %v", t)
    73  	}
    74  	return nil
    75  }
    76  
    77  func checkFullType(t beam.FullType) error {
    78  	if err := checkType(t.Type()); err != nil {
    79  		return err
    80  	}
    81  	for _, c := range t.Components() {
    82  		if err := checkFullType(c); err != nil {
    83  			return err
    84  		}
    85  	}
    86  	return nil
    87  }