github.com/apache/beam/sdks/v2@v2.48.2/go/test/integration/transforms/xlang/inference/inference.go (about) 1 // Licensed to the Apache Software Foundation (ASF) under one or more 2 // contributor license agreements. See the NOTICE file distributed with 3 // this work for additional information regarding copyright ownership. 4 // The ASF licenses this file to You under the Apache License, Version 2.0 5 // (the "License"); you may not use this file except in compliance with 6 // the License. You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 16 package inference 17 18 import ( 19 "github.com/apache/beam/sdks/v2/go/pkg/beam" 20 _ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/dataflow" 21 _ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/flink" 22 _ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/universal" 23 "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert" 24 "github.com/apache/beam/sdks/v2/go/pkg/beam/transforms/xlang/inference" 25 ) 26 27 func SklearnInference(expansionAddr string) *beam.Pipeline { 28 p, s := beam.NewPipelineWithRoot() 29 30 inputRow := [][]int64{{0, 0}, {1, 1}} 31 input := beam.CreateList(s, inputRow) 32 output := []inference.PredictionResult{ 33 { 34 Example: []int64{0, 0}, 35 Inference: 0, 36 }, 37 { 38 Example: []int64{1, 1}, 39 Inference: 1, 40 }, 41 } 42 outCol := inference.SklearnModel("/tmp/staged/sklearn_model").RunInference(s, input, inference.WithExpansionAddr(expansionAddr)) 43 passert.Equals(s, outCol, output[0], output[1]) 44 return p 45 }