github.com/erda-project/erda-infra@v1.0.9/base/servicehub/provider_context.go (about)

     1  // Copyright (c) 2021 Terminus, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package servicehub
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"os"
    21  	"reflect"
    22  	"strconv"
    23  	"strings"
    24  
    25  	"github.com/recallsong/go-utils/encoding/jsonx"
    26  	"github.com/recallsong/unmarshal"
    27  	unmarshalflag "github.com/recallsong/unmarshal/unmarshal-flag"
    28  	"github.com/spf13/pflag"
    29  
    30  	"github.com/erda-project/erda-infra/base/logs"
    31  	"github.com/erda-project/erda-infra/pkg/config"
    32  )
    33  
    34  type inheritLabelStrategy string
    35  
    36  const (
    37  	inheritLabelTrue      inheritLabelStrategy = "true"
    38  	inheritLabelFalse     inheritLabelStrategy = "false"
    39  	inheritLabelPreferred inheritLabelStrategy = "preferred"
    40  )
    41  
// providerContext holds the per-provider state used by the Hub. It is the
// context value handed to a provider's initializer (Init passes it to
// ProviderInitializer.Init) and records the provider instance, its definition,
// its configuration and the tasks it registers. It embeds context.Context, so
// it can be used wherever a context is expected.
type providerContext struct {
	context.Context
	// hub is used for service lookups (getService) and as the source of the logger.
	hub         *Hub
	// key identifies the provider; Logger derives a sub-logger from it.
	key         string
	// label is optional; when set, fullName is "name@label" and field
	// dependencies may inherit it (see adjustDependServiceLabel).
	label       string
	// name is the provider name, passed as the caller of dependency lookups.
	name        string
	// cfg holds the raw config data before BindConfig runs and the bound
	// config object (or nil) afterwards.
	cfg         interface{}
	// provider is the provider instance; Init probes it for ProviderInitializer.
	provider    Provider
	// structValue/structType describe the struct whose fields Init injects and
	// Dependencies inspects — presumably the value the provider pointer refers
	// to (TODO confirm where they are set). When structType is nil, field
	// handling is skipped.
	structValue reflect.Value
	structType  reflect.Type
	// define is the provider definition, probed for optional interfaces such as
	// ConfigCreator and the *ServiceDependencies interfaces.
	define      ProviderDefine
	// tasks collects the functions registered through AddTask.
	tasks       []task
}

// loggerType is the reflect.Type of the logs.Logger interface. Init injects the
// provider's logger into struct fields of exactly this type.
var loggerType = reflect.TypeOf((*logs.Logger)(nil)).Elem()
    57  
    58  func (c *providerContext) BindConfig(flags *pflag.FlagSet) (err error) {
    59  	if creator, ok := c.define.(ConfigCreator); ok {
    60  		cfg := creator.Config()
    61  		if cfg != nil {
    62  			err = unmarshal.BindDefault(cfg)
    63  			if err != nil {
    64  				return err
    65  			}
    66  			if c.cfg != nil {
    67  				err = config.ConvertData(c.cfg, cfg, "file")
    68  				if err != nil {
    69  					return err
    70  				}
    71  			}
    72  			err = unmarshal.BindEnv(cfg)
    73  			if err != nil {
    74  				return err
    75  			}
    76  			if flags != nil {
    77  				err = unmarshalflag.BindFlag(flags, cfg)
    78  				if err != nil {
    79  					return err
    80  				}
    81  			}
    82  			c.cfg = cfg
    83  			return nil
    84  		}
    85  	}
    86  	c.cfg = nil
    87  	return nil
    88  }
    89  
    90  func (c *providerContext) Init() (err error) {
    91  	if reflect.ValueOf(c.provider).Kind() == reflect.Ptr && c.structType != nil {
    92  		value, typ := c.structValue, c.structType
    93  		var (
    94  			cfgValue *reflect.Value
    95  			cfgType  reflect.Type
    96  		)
    97  		if c.cfg != nil {
    98  			value := reflect.ValueOf(c.cfg)
    99  			cfgValue = &value
   100  			cfgType = cfgValue.Type()
   101  		}
   102  		fields := typ.NumField()
   103  		for i := 0; i < fields; i++ {
   104  			if !value.Field(i).CanSet() {
   105  				continue
   106  			}
   107  			field := typ.Field(i)
   108  			if field.Type == loggerType {
   109  				logger := c.Logger()
   110  				value.Field(i).Set(reflect.ValueOf(logger))
   111  			}
   112  			if cfgValue != nil && field.Type == cfgType {
   113  				value.Field(i).Set(*cfgValue)
   114  			}
   115  			service := field.Tag.Get("service")
   116  			if len(service) <= 0 {
   117  				service = field.Tag.Get("autowired")
   118  			}
   119  			if service == "-" {
   120  				continue
   121  			}
   122  			service = c.adjustDependServiceLabel(service, &field)
   123  			dc := newDependencyContext(
   124  				service,
   125  				c.name,
   126  				field.Type,
   127  				field.Tag,
   128  			)
   129  			instance := c.hub.getService(dc)
   130  			if len(service) > 0 && instance == nil {
   131  				opt, err := boolTagValue(field.Tag, "optional", false)
   132  				if err != nil {
   133  					return fmt.Errorf("invalid optional tag value in %s.%s: %s", typ.String(), field.Name, err)
   134  				}
   135  				if opt {
   136  					continue
   137  				}
   138  				return fmt.Errorf("not found service %q", service)
   139  			}
   140  			if instance == nil {
   141  				continue
   142  			}
   143  			if !reflect.TypeOf(instance).AssignableTo(field.Type) {
   144  				return fmt.Errorf("service %q not implement %s", service, field.Type)
   145  			}
   146  			value.Field(i).Set(reflect.ValueOf(instance))
   147  		}
   148  	}
   149  	if c.cfg != nil {
   150  		key := c.key
   151  		if key != c.name {
   152  			key = fmt.Sprintf("%s (%s)", key, c.name)
   153  		}
   154  		if os.Getenv("LOG_LEVEL") == "debug" {
   155  			fmt.Printf("provider %s config: \n%s\n", key, jsonx.MarshalAndIndent(c.cfg))
   156  		}
   157  		// c.hub.logger.Debugf("provider %s config: \n%s", key, jsonx.MarshalAndIndent(c.cfg))
   158  	}
   159  
   160  	if initializer, ok := c.provider.(ProviderInitializer); ok {
   161  		err = initializer.Init(c)
   162  		if err != nil {
   163  			return fmt.Errorf("fail to Init provider %s: %s", c.name, err)
   164  		}
   165  	}
   166  	return nil
   167  }
   168  
   169  // Define .
   170  func (c *providerContext) Define() ProviderDefine {
   171  	return c.define
   172  }
   173  
   174  func (c *providerContext) dependencies() string {
   175  	services, providers := c.Dependencies()
   176  	if len(services) > 0 && len(providers) > 0 {
   177  		return fmt.Sprintf("services: %v, providers: %v", services, providers)
   178  	} else if len(services) > 0 {
   179  		return fmt.Sprintf("services: %v", services)
   180  	} else if len(providers) > 0 {
   181  		return fmt.Sprintf("providers: %v", providers)
   182  	}
   183  	return ""
   184  }
   185  
   186  func boolTagValue(tag reflect.StructTag, key string, defval bool) (bool, error) {
   187  	opt, ok := tag.Lookup(key)
   188  	if ok {
   189  		if len(opt) > 0 {
   190  			b, err := strconv.ParseBool(opt)
   191  			if err != nil {
   192  				return defval, err
   193  			}
   194  			return b, nil
   195  		}
   196  	}
   197  	return defval, nil
   198  }
   199  
   200  func (c *providerContext) adjustDependServiceLabel(service string, field *reflect.StructField) string {
   201  	if len(c.label) == 0 || strings.Contains(service, "@") {
   202  		return service
   203  	}
   204  	inheritLabel := field.Tag.Get("inherit-label")
   205  	switch inheritLabelStrategy(inheritLabel) {
   206  	case inheritLabelTrue:
   207  		return fmt.Sprintf("%s@%s", service, c.label)
   208  	case inheritLabelPreferred:
   209  		pcs := c.hub.servicesMap[service]
   210  		for _, pc := range pcs {
   211  			if pc.label == c.label {
   212  				return fmt.Sprintf("%s@%s", service, c.label)
   213  			}
   214  		}
   215  	case inheritLabelFalse:
   216  	default:
   217  	}
   218  	return service
   219  }
   220  
   221  func (c *providerContext) fullName() string {
   222  	if len(c.label) == 0 {
   223  		return c.name
   224  	}
   225  	return fmt.Sprintf("%s@%s", c.name, c.label)
   226  }
   227  
   228  // Dependencies .
   229  func (c *providerContext) Dependencies() (services []string, providers []string) {
   230  	srvset, provset := make(map[string]bool), make(map[reflect.Type]bool)
   231  	if deps, ok := c.define.(FixedServiceDependencies); ok {
   232  		for _, service := range deps.Dependencies() {
   233  			if !srvset[service] {
   234  				services = append(services, service)
   235  				srvset[service] = true
   236  			}
   237  		}
   238  	}
   239  	if deps, ok := c.define.(ServiceDependencies); ok {
   240  		for _, service := range deps.Dependencies(c.hub) {
   241  			if !srvset[service] {
   242  				services = append(services, service)
   243  				srvset[service] = true
   244  			}
   245  		}
   246  	}
   247  	if deps, ok := c.define.(OptionalServiceDependencies); ok {
   248  		for _, service := range deps.OptionalDependencies(c.hub) {
   249  			if len(c.hub.servicesMap[service]) > 0 && !srvset[service] {
   250  				services = append(services, service)
   251  				srvset[service] = true
   252  			}
   253  		}
   254  	}
   255  	if c.structType != nil {
   256  		fields := c.structType.NumField()
   257  		for i := 0; i < fields; i++ {
   258  			field := c.structType.Field(i)
   259  			service := field.Tag.Get("service")
   260  			if len(service) <= 0 {
   261  				service = field.Tag.Get("autowired")
   262  			}
   263  			if service == "-" {
   264  				continue
   265  			}
   266  			if len(service) > 0 {
   267  				service = c.adjustDependServiceLabel(service, &field)
   268  				opt, _ := boolTagValue(field.Tag, "optional", false)
   269  				if opt {
   270  					if len(c.hub.servicesMap[service]) > 0 && !srvset[service] {
   271  						services = append(services, service)
   272  						srvset[service] = true
   273  					}
   274  				} else if !srvset[service] {
   275  					services = append(services, service)
   276  					srvset[service] = true
   277  				}
   278  				continue
   279  			}
   280  			if !c.structValue.Field(i).CanSet() {
   281  				continue
   282  			}
   283  			plist := c.hub.servicesTypes[field.Type]
   284  			if len(plist) > 0 && !provset[field.Type] {
   285  				provset[field.Type] = true
   286  				providers = append(providers, plist[0].name)
   287  			}
   288  		}
   289  	}
   290  	return
   291  }
   292  
   293  // Hub .
   294  func (c *providerContext) Hub() *Hub {
   295  	return c.hub
   296  }
   297  
   298  // Logger .
   299  func (c *providerContext) Logger() logs.Logger {
   300  	if c.hub.logger == nil {
   301  		return nil
   302  	}
   303  	return c.hub.logger.Sub(c.key)
   304  }
   305  
   306  // Config .
   307  func (c *providerContext) Config() interface{} {
   308  	return c.cfg
   309  }
   310  
   311  // Service .
   312  func (c *providerContext) Service(name string, options ...interface{}) interface{} {
   313  	return c.hub.getService(newDependencyContext(
   314  		name,
   315  		c.name,
   316  		nil,
   317  		reflect.StructTag(""),
   318  	), options...)
   319  }
   320  
   321  // AddTask .
   322  func (c *providerContext) AddTask(fn func(context.Context) error, options ...TaskOption) {
   323  	t := task{
   324  		name: "",
   325  		fn:   fn,
   326  	}
   327  	for _, opt := range options {
   328  		opt(&t)
   329  	}
   330  	c.tasks = append(c.tasks, t)
   331  }
   332  
   333  // Label .
   334  func (c *providerContext) Label() string {
   335  	return c.label
   336  }
   337  
   338  // Key .
   339  func (c *providerContext) Key() string {
   340  	return c.key
   341  }
   342  
   343  // Provider .
   344  func (c *providerContext) Provider() Provider {
   345  	return c.provider
   346  }
   347  
   348  // WithTaskName .
   349  func WithTaskName(name string) TaskOption {
   350  	return func(t *task) {
   351  		t.name = name
   352  	}
   353  }
   354  
   355  type task struct {
   356  	name string
   357  	fn   func(context.Context) error
   358  }
   359  
   360  // dependencyContext .
   361  type dependencyContext struct {
   362  	typ     reflect.Type
   363  	tags    reflect.StructTag
   364  	service string
   365  	key     string
   366  	label   string
   367  	caller  string
   368  }
   369  
   370  func (dc *dependencyContext) Type() reflect.Type      { return dc.typ }
   371  func (dc *dependencyContext) Tags() reflect.StructTag { return dc.tags }
   372  func (dc *dependencyContext) Service() string         { return dc.service }
   373  func (dc *dependencyContext) Key() string             { return dc.key }
   374  func (dc *dependencyContext) Label() string           { return dc.label }
   375  func (dc *dependencyContext) Caller() string          { return dc.caller }
   376  
   377  func newDependencyContext(service, caller string, typ reflect.Type, tags reflect.StructTag) *dependencyContext {
   378  	dc := &dependencyContext{
   379  		typ:     typ,
   380  		tags:    tags,
   381  		key:     service,
   382  		service: service,
   383  		caller:  caller,
   384  	}
   385  	idx := strings.Index(service, "@")
   386  	if idx > 0 {
   387  		dc.service = service[0:idx]
   388  		dc.label = service[idx+1:]
   389  	}
   390  	return dc
   391  }