github.com/erda-project/erda-infra@v1.0.9/providers/grpcclient/provider.go (about)

     1  // Copyright (c) 2021 Terminus, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package grpcclient
    16  
    17  import (
    18  	"context"
    19  	"crypto/tls"
    20  	"fmt"
    21  	"reflect"
    22  
    23  	"google.golang.org/grpc"
    24  	"google.golang.org/grpc/credentials"
    25  
    26  	"github.com/erda-project/erda-infra/base/logs"
    27  	"github.com/erda-project/erda-infra/base/servicehub"
    28  	grpccontext "github.com/erda-project/erda-infra/pkg/trace/inject/context/grpc"
    29  	transgrpc "github.com/erda-project/erda-infra/pkg/transport/grpc"
    30  )
    31  
    32  // Interface .
    33  type Interface interface {
    34  	Get() *grpc.ClientConn
    35  	NewConnect(opts ...grpc.DialOption) (*grpc.ClientConn, error)
    36  }
    37  
    38  var (
    39  	clientConnType          = reflect.TypeOf((*grpc.ClientConn)(nil))
    40  	clientConnInterfaceType = reflect.TypeOf((*transgrpc.ClientConnInterface)(nil)).Elem()
    41  	interfaceType           = reflect.TypeOf((*Interface)(nil)).Elem()
    42  )
    43  
// config holds the provider settings. An empty instance comes from the
// ConfigFunc registered in init. It is presumably populated by servicehub from
// the keys in the `file` tags, falling back to the `default` tags.
type config struct {
	// Addr is the target passed to grpc.Dial by NewConnect.
	Addr string `file:"addr" default:":7070" desc:"the server address in the format of host:port"`
	// TLS selects the transport security used by Init.
	TLS struct {
		// ServerNameOverride is passed to credentials.NewClientTLSFromFile and is
		// only used when CAFile is set.
		// NOTE(review): its config key is "cert_file", which does not match the
		// field's purpose (likely a copy-paste slip). It is kept as-is because
		// renaming the key would break existing configuration files.
		ServerNameOverride string `file:"cert_file" desc:"the server name used to verify the hostname returned by the TLS handshake"`
		// CAFile, when non-empty, enables TLS verified against this CA file.
		CAFile string `file:"ca_file" desc:"the file containing the CA root cert file"`
		// InsecureSkipVerify is only consulted when CAFile is empty. If true,
		// TLS is used without certificate verification; if false, connections
		// are plaintext.
		InsecureSkipVerify bool `file:"insecure_skip_verify" desc:"skip verify"`
	} `file:"tls"`
	// Singleton makes Init dial one shared connection. Get and Provide return
	// it, and Run closes it on shutdown.
	Singleton bool `file:"singleton" default:"true" desc:"one client instance"`
	// Block adds grpc.WithBlock to the singleton dial only.
	Block bool `file:"block" default:"true" desc:"block until the connection is up"`
	// TraceEnable adds the trace-context unary/stream client interceptors to
	// every connection.
	TraceEnable bool `file:"trace_enable" default:"true"`
}
    55  
// provider backs the "grpc-client" and "grpc-client-connector" services and
// implements Interface (Get, NewConnect).
type provider struct {
	// Cfg is presumably injected by servicehub from ConfigFunc's result.
	Cfg *config
	// Log is presumably injected by servicehub.
	Log logs.Logger
	// conn is the shared connection. Init sets it only when Cfg.Singleton is
	// true; otherwise it stays nil.
	conn *grpc.ClientConn
	// opts are the base dial options built by Init (transport credentials plus
	// optional trace interceptors). NewConnect appends them to every dial.
	opts []grpc.DialOption
}
    62  
    63  func (p *provider) Init(ctx servicehub.Context) error {
    64  	var opts []grpc.DialOption
    65  	if len(p.Cfg.TLS.CAFile) > 0 {
    66  		creds, err := credentials.NewClientTLSFromFile(p.Cfg.TLS.CAFile, p.Cfg.TLS.ServerNameOverride)
    67  		if err != nil {
    68  			return fmt.Errorf("fail to create tls credentials %s", err)
    69  		}
    70  		opts = append(opts, grpc.WithTransportCredentials(creds))
    71  	} else {
    72  		// distinguish `no tls` or `tls: insecure skip verify`
    73  		notls := true // default no tls, compatible with old config
    74  		if p.Cfg.TLS.InsecureSkipVerify {
    75  			notls = false
    76  		}
    77  		if notls {
    78  			opts = append(opts, grpc.WithInsecure())
    79  		} else {
    80  			insecureSkipVerifyTLS := credentials.NewTLS(&tls.Config{InsecureSkipVerify: true})
    81  			opts = append(opts, grpc.WithTransportCredentials(insecureSkipVerifyTLS))
    82  		}
    83  	}
    84  	if p.Cfg.TraceEnable {
    85  		opts = append(opts,
    86  			grpc.WithUnaryInterceptor(grpccontext.UnaryClientInterceptor()),
    87  			grpc.WithStreamInterceptor(grpccontext.StreamClientInterceptor()),
    88  		)
    89  	}
    90  	p.opts = opts
    91  	if p.Cfg.Singleton {
    92  		opts = nil
    93  		if p.Cfg.Block {
    94  			opts = append(opts, grpc.WithBlock())
    95  		}
    96  		conn, err := p.NewConnect(opts...)
    97  		if err != nil {
    98  			return fmt.Errorf("fail to dial: %s", err)
    99  		}
   100  		p.conn = conn
   101  	}
   102  	return nil
   103  }
   104  
   105  func (p *provider) Get() *grpc.ClientConn { return p.conn }
   106  func (p *provider) NewConnect(opts ...grpc.DialOption) (*grpc.ClientConn, error) {
   107  	return grpc.Dial(p.Cfg.Addr, append(opts, p.opts...)...)
   108  }
   109  
   110  func (p *provider) Run(ctx context.Context) error {
   111  	if p.Cfg.Singleton {
   112  		select {
   113  		case <-ctx.Done():
   114  			p.conn.Close()
   115  			return nil
   116  		}
   117  	}
   118  	return nil
   119  }
   120  
   121  func (p *provider) Provide(ctx servicehub.DependencyContext, args ...interface{}) interface{} {
   122  	if ctx.Service() == "grpc-client-connector" || ctx.Type() == interfaceType {
   123  		return p
   124  	}
   125  	return p.conn
   126  }
   127  
   128  func init() {
   129  	servicehub.Register("grpc-client", &servicehub.Spec{
   130  		Services: []string{"grpc-client", "grpc-client-connector"},
   131  		Types: []reflect.Type{
   132  			clientConnType,
   133  			clientConnInterfaceType,
   134  			interfaceType,
   135  		},
   136  		Description: "grpc client",
   137  		ConfigFunc: func() interface{} {
   138  			return &config{}
   139  		},
   140  		Creator: func() servicehub.Provider {
   141  			return &provider{}
   142  		},
   143  	})
   144  }