trpc.group/trpc-go/trpc-go@v1.0.2/trpc_util.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package trpc
    15  
    16  import (
    17  	"context"
    18  	"net"
    19  	"os"
    20  	"runtime"
    21  	"sync"
    22  	"time"
    23  
    24  	"github.com/panjf2000/ants/v2"
    25  	trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc"
    26  
    27  	"trpc.group/trpc-go/trpc-go/codec"
    28  	"trpc.group/trpc-go/trpc-go/errs"
    29  	"trpc.group/trpc-go/trpc-go/internal/report"
    30  	"trpc.group/trpc-go/trpc-go/log"
    31  )
    32  
// PanicBufLen is the size, in bytes, of the buffer used to capture the stack trace
// that is logged when a goroutine panics. It is 1024 by default.
// A stack trace that does not fit into the buffer is cut off in the log.
var PanicBufLen = 1024
    36  
    37  // ----------------------- trpc util functions ------------------------------------ //
    38  
    39  // Message returns msg from ctx.
    40  func Message(ctx context.Context) codec.Msg {
    41  	return codec.Message(ctx)
    42  }
    43  
    44  // BackgroundContext puts an initialized msg into background context and returns it.
    45  func BackgroundContext() context.Context {
    46  	cfg := GlobalConfig()
    47  	ctx, msg := codec.WithNewMessage(context.Background())
    48  	msg.WithCalleeContainerName(cfg.Global.ContainerName)
    49  	msg.WithNamespace(cfg.Global.Namespace)
    50  	msg.WithEnvName(cfg.Global.EnvName)
    51  	if cfg.Global.EnableSet == "Y" {
    52  		msg.WithSetName(cfg.Global.FullSetName)
    53  	}
    54  	if len(cfg.Server.Service) > 0 {
    55  		msg.WithCalleeServiceName(cfg.Server.Service[0].Name)
    56  	} else {
    57  		msg.WithCalleeApp(cfg.Server.App)
    58  		msg.WithCalleeServer(cfg.Server.Server)
    59  	}
    60  	return ctx
    61  }
    62  
    63  // GetMetaData returns metadata from ctx by key.
    64  func GetMetaData(ctx context.Context, key string) []byte {
    65  	msg := codec.Message(ctx)
    66  	if len(msg.ServerMetaData()) > 0 {
    67  		return msg.ServerMetaData()[key]
    68  	}
    69  	return nil
    70  }
    71  
    72  // SetMetaData sets metadata which will be returned to upstream.
    73  // This method is not thread-safe.
    74  // Notice: SetMetaData can only be called in the server side rpc entry goroutine,
    75  // not in goroutine that calls the client.
    76  func SetMetaData(ctx context.Context, key string, val []byte) {
    77  	msg := codec.Message(ctx)
    78  	if len(msg.ServerMetaData()) > 0 {
    79  		msg.ServerMetaData()[key] = val
    80  		return
    81  	}
    82  	md := make(map[string][]byte)
    83  	md[key] = val
    84  	msg.WithServerMetaData(md)
    85  }
    86  
    87  // Request returns RequestProtocol from ctx.
    88  // If the RequestProtocol not found, a new RequestProtocol will be created and returned.
    89  func Request(ctx context.Context) *trpcpb.RequestProtocol {
    90  	msg := codec.Message(ctx)
    91  	request, ok := msg.ServerReqHead().(*trpcpb.RequestProtocol)
    92  	if !ok {
    93  		return &trpcpb.RequestProtocol{}
    94  	}
    95  	return request
    96  }
    97  
    98  // Response returns ResponseProtocol from ctx.
    99  // If the ResponseProtocol not found, a new ResponseProtocol will be created and returned.
   100  func Response(ctx context.Context) *trpcpb.ResponseProtocol {
   101  	msg := codec.Message(ctx)
   102  	response, ok := msg.ServerRspHead().(*trpcpb.ResponseProtocol)
   103  	if !ok {
   104  		return &trpcpb.ResponseProtocol{}
   105  	}
   106  	return response
   107  }
   108  
   109  // CloneContext copies the context to get a context that retains the value and doesn't cancel.
   110  // This is used when the handler is processed asynchronously to detach the original timeout control
   111  // and retains the original context information.
   112  //
   113  // After the trpc handler function returns, ctx will be canceled, and put the ctx's Msg back into pool,
   114  // and the associated Metrics and logger will be released.
   115  //
   116  // Before starting a goroutine to run the handler function asynchronously,
   117  // this method must be called to copy context, detach the original timeout control,
   118  // and retain the information in Msg for Metrics.
   119  //
   120  // Retain the logger context for printing the associated log,
   121  // keep other value in context, such as tracing context, etc.
   122  func CloneContext(ctx context.Context) context.Context {
   123  	oldMsg := codec.Message(ctx)
   124  	newCtx, newMsg := codec.WithNewMessage(detach(ctx))
   125  	codec.CopyMsg(newMsg, oldMsg)
   126  	return newCtx
   127  }
   128  
   129  type detachedContext struct{ parent context.Context }
   130  
   131  func detach(ctx context.Context) context.Context { return detachedContext{ctx} }
   132  
   133  // Deadline implements context.Deadline
   134  func (v detachedContext) Deadline() (time.Time, bool) { return time.Time{}, false }
   135  
   136  // Done implements context.Done
   137  func (v detachedContext) Done() <-chan struct{} { return nil }
   138  
   139  // Err implements context.Err
   140  func (v detachedContext) Err() error { return nil }
   141  
   142  // Value implements context.Value
   143  func (v detachedContext) Value(key interface{}) interface{} { return v.parent.Value(key) }
   144  
   145  // GoAndWait provides safe concurrent handling. Per input handler, it starts a goroutine.
   146  // Then it waits until all handlers are done and will recover if any handler panics.
   147  // The returned error is the first non-nil error returned by one of the handlers.
   148  // It can be set that non-nil error will be returned if the "key" handler fails while other handlers always
   149  // return nil error.
   150  func GoAndWait(handlers ...func() error) error {
   151  	var (
   152  		wg   sync.WaitGroup
   153  		once sync.Once
   154  		err  error
   155  	)
   156  	for _, f := range handlers {
   157  		wg.Add(1)
   158  		go func(handler func() error) {
   159  			defer func() {
   160  				if e := recover(); e != nil {
   161  					buf := make([]byte, PanicBufLen)
   162  					buf = buf[:runtime.Stack(buf, false)]
   163  					log.Errorf("[PANIC]%v\n%s\n", e, buf)
   164  					report.PanicNum.Incr()
   165  					once.Do(func() {
   166  						err = errs.New(errs.RetServerSystemErr, "panic found in call handlers")
   167  					})
   168  				}
   169  				wg.Done()
   170  			}()
   171  			if e := handler(); e != nil {
   172  				once.Do(func() {
   173  					err = e
   174  				})
   175  			}
   176  		}(f)
   177  	}
   178  	wg.Wait()
   179  	return err
   180  }
   181  
// Goer is the interface that launches a testable and safe goroutine.
type Goer interface {
	// Go runs handler with a context derived from ctx and limited by timeout.
	// Whether handler runs asynchronously or on the caller's goroutine depends on
	// the implementation. A non-nil error means the handler could not be launched.
	Go(ctx context.Context, timeout time.Duration, handler func(context.Context)) error
}
   186  
// asyncGoer is a Goer that runs handlers asynchronously, either on a goroutine of
// its own per call or on a worker pool.
type asyncGoer struct {
	// panicBufLen is the size of the buffer used to capture the stack trace of a recovered panic.
	panicBufLen   int
	// shouldRecover controls whether a panic raised by a handler is recovered and logged.
	shouldRecover bool
	// pool is the optional worker pool; when nil, every call starts a new goroutine.
	pool          *ants.PoolWithFunc
}
   192  
// goerParam bundles the arguments of one handler invocation so that they can be
// passed to a worker of the pool as a single value.
type goerParam struct {
	// ctx is the context passed to handler.
	ctx     context.Context
	// cancel releases ctx; it is called once handler has finished.
	cancel  context.CancelFunc
	// handler is the user function to execute.
	handler func(context.Context)
}
   198  
   199  // NewAsyncGoer creates a goer that executes handler asynchronously with a goroutine when Go() is called.
   200  func NewAsyncGoer(workerPoolSize int, panicBufLen int, shouldRecover bool) Goer {
   201  	g := &asyncGoer{
   202  		panicBufLen:   panicBufLen,
   203  		shouldRecover: shouldRecover,
   204  	}
   205  	if workerPoolSize == 0 {
   206  		return g
   207  	}
   208  
   209  	pool, err := ants.NewPoolWithFunc(workerPoolSize, func(args interface{}) {
   210  		p := args.(*goerParam)
   211  		g.handle(p.ctx, p.handler, p.cancel)
   212  	})
   213  	if err != nil {
   214  		panic(err)
   215  	}
   216  	g.pool = pool
   217  	return g
   218  }
   219  
   220  func (g *asyncGoer) handle(ctx context.Context, handler func(context.Context), cancel context.CancelFunc) {
   221  	defer func() {
   222  		if g.shouldRecover {
   223  			if err := recover(); err != nil {
   224  				buf := make([]byte, g.panicBufLen)
   225  				buf = buf[:runtime.Stack(buf, false)]
   226  				log.ErrorContextf(ctx, "[PANIC]%v\n%s\n", err, buf)
   227  				report.PanicNum.Incr()
   228  			}
   229  		}
   230  		cancel()
   231  	}()
   232  	handler(ctx)
   233  }
   234  
   235  func (g *asyncGoer) Go(ctx context.Context, timeout time.Duration, handler func(context.Context)) error {
   236  	oldMsg := codec.Message(ctx)
   237  	newCtx, newMsg := codec.WithNewMessage(detach(ctx))
   238  	codec.CopyMsg(newMsg, oldMsg)
   239  	newCtx, cancel := context.WithTimeout(newCtx, timeout)
   240  	if g.pool != nil {
   241  		p := &goerParam{
   242  			ctx:     newCtx,
   243  			cancel:  cancel,
   244  			handler: handler,
   245  		}
   246  		return g.pool.Invoke(p)
   247  	}
   248  	go g.handle(newCtx, handler, cancel)
   249  	return nil
   250  }
   251  
// syncGoer is a Goer that runs the handler on the caller's goroutine.
type syncGoer struct {
}
   254  
   255  // NewSyncGoer creates a goer that executes handler synchronously without cloning ctx when Go() is called.
   256  // it's usually used for testing.
   257  func NewSyncGoer() Goer {
   258  	return &syncGoer{}
   259  }
   260  
   261  func (g *syncGoer) Go(ctx context.Context, timeout time.Duration, handler func(context.Context)) error {
   262  	newCtx, cancel := context.WithTimeout(ctx, timeout)
   263  	defer cancel()
   264  	handler(newCtx)
   265  	return nil
   266  }
   267  
// DefaultGoer is an async goer without worker pool that recovers from handler panics.
// It is created with the value PanicBufLen has at package initialization, so
// changing PanicBufLen afterwards does not affect it.
var DefaultGoer = NewAsyncGoer(0, PanicBufLen, true)
   270  
   271  // Go launches a safer goroutine for async task inside rpc handler.
   272  // it clones ctx and msg before the goroutine, and will recover and report metrics when the goroutine panics.
   273  // you should set a suitable timeout to control the lifetime of the new goroutine to prevent goroutine leaks.
   274  func Go(ctx context.Context, timeout time.Duration, handler func(context.Context)) error {
   275  	return DefaultGoer.Go(ctx, timeout, handler)
   276  }
   277  
   278  // expandEnv looks for ${var} in s and replaces them with value of the
   279  // corresponding environment variable.
   280  // $var is considered invalid.
   281  // It's not like os.ExpandEnv which will handle both ${var} and $var.
   282  // Since configurations like password for redis/mysql may contain $, this
   283  // method is needed.
   284  func expandEnv(s string) string {
   285  	var buf []byte
   286  	i := 0
   287  	for j := 0; j < len(s); j++ {
   288  		if s[j] == '$' && j+2 < len(s) && s[j+1] == '{' { // only ${var} instead of $var is valid
   289  			if buf == nil {
   290  				buf = make([]byte, 0, 2*len(s))
   291  			}
   292  			buf = append(buf, s[i:j]...)
   293  			name, w := getEnvName(s[j+1:])
   294  			if name == "" && w > 0 {
   295  				// invalid matching, remove the $
   296  			} else if name == "" {
   297  				buf = append(buf, s[j]) // keep the $
   298  			} else {
   299  				buf = append(buf, os.Getenv(name)...)
   300  			}
   301  			j += w
   302  			i = j + 1
   303  		}
   304  	}
   305  	if buf == nil {
   306  		return s
   307  	}
   308  	return string(buf) + s[i:]
   309  }
   310  
   311  // getEnvName gets env name, that is, var from ${var}.
   312  // And content of var and its len will be returned.
   313  func getEnvName(s string) (string, int) {
   314  	// look for right curly bracket '}'
   315  	// it's guaranteed that the first char is '{' and the string has at least two char
   316  	for i := 1; i < len(s); i++ {
   317  		if s[i] == ' ' || s[i] == '\n' || s[i] == '"' { // "xx${xxx"
   318  			return "", 0 // encounter invalid char, keep the $
   319  		}
   320  		if s[i] == '}' {
   321  			if i == 1 { // ${}
   322  				return "", 2 // remove ${}
   323  			}
   324  			return s[1:i], i + 1
   325  		}
   326  	}
   327  	return "", 0 // no },keep the $
   328  }
   329  
   330  // --------------- the following code is IP Config related -----------------//
   331  
// nicIP defines the parameters used to record the ip address (ipv4 & ipv6) of the nic.
type nicIP struct {
	// nic is the name of the network interface.
	nic  string
	// ipv4 lists the IPv4 addresses of the interface in the order they were added.
	ipv4 []string
	// ipv6 lists the IPv6 addresses of the interface in the order they were added.
	ipv6 []string
}
   338  
// netInterfaceIP maintains the nic name to nicIP mapping.
type netInterfaceIP struct {
	// once guards the one-time population of ips done by enumAllIP.
	once sync.Once
	// ips maps a nic name to the addresses recorded for it.
	ips  map[string]*nicIP
}
   344  
   345  // enumAllIP returns the nic name to nicIP mapping.
   346  func (p *netInterfaceIP) enumAllIP() map[string]*nicIP {
   347  	p.once.Do(func() {
   348  		p.ips = make(map[string]*nicIP)
   349  		interfaces, err := net.Interfaces()
   350  		if err != nil {
   351  			return
   352  		}
   353  		for _, i := range interfaces {
   354  			p.addInterface(i)
   355  		}
   356  	})
   357  	return p.ips
   358  }
   359  
   360  func (p *netInterfaceIP) addInterface(i net.Interface) {
   361  	addrs, err := i.Addrs()
   362  	if err != nil {
   363  		return
   364  	}
   365  	for _, addr := range addrs {
   366  		ipNet, ok := addr.(*net.IPNet)
   367  		if !ok {
   368  			continue
   369  		}
   370  		if ipNet.IP.To4() != nil {
   371  			p.addIPv4(i.Name, ipNet.IP.String())
   372  		} else if ipNet.IP.To16() != nil {
   373  			p.addIPv6(i.Name, ipNet.IP.String())
   374  		}
   375  	}
   376  }
   377  
   378  // addIPv4 append ipv4 address
   379  func (p *netInterfaceIP) addIPv4(nic string, ip4 string) {
   380  	ips := p.getNicIP(nic)
   381  	ips.ipv4 = append(ips.ipv4, ip4)
   382  }
   383  
   384  // addIPv6 append ipv6 address
   385  func (p *netInterfaceIP) addIPv6(nic string, ip6 string) {
   386  	ips := p.getNicIP(nic)
   387  	ips.ipv6 = append(ips.ipv6, ip6)
   388  }
   389  
   390  // getNicIP returns nicIP by nic name.
   391  func (p *netInterfaceIP) getNicIP(nic string) *nicIP {
   392  	if _, ok := p.ips[nic]; !ok {
   393  		p.ips[nic] = &nicIP{nic: nic}
   394  	}
   395  	return p.ips[nic]
   396  }
   397  
   398  // getIPByNic returns ip address by nic name.
   399  // If the ipv4 addr is not empty, it will be returned.
   400  // Otherwise, the ipv6 addr will be returned.
   401  func (p *netInterfaceIP) getIPByNic(nic string) string {
   402  	p.enumAllIP()
   403  	if len(p.ips) <= 0 {
   404  		return ""
   405  	}
   406  	if _, ok := p.ips[nic]; !ok {
   407  		return ""
   408  	}
   409  	ip := p.ips[nic]
   410  	if len(ip.ipv4) > 0 {
   411  		return ip.ipv4[0]
   412  	}
   413  	if len(ip.ipv6) > 0 {
   414  		return ip.ipv6[0]
   415  	}
   416  	return ""
   417  }
   418  
// localIP records the local nic name->nicIP mapping.
// It starts empty and is populated on the first lookup.
var localIP = &netInterfaceIP{}
   421  
   422  // getIP returns ip addr by nic name.
   423  func getIP(nic string) string {
   424  	ip := localIP.getIPByNic(nic)
   425  	return ip
   426  }
   427  
   428  // deduplicate merges two slices.
   429  // Order will be kept and duplication will be removed.
   430  func deduplicate(a, b []string) []string {
   431  	r := make([]string, 0, len(a)+len(b))
   432  	m := make(map[string]bool)
   433  	for _, s := range append(a, b...) {
   434  		if _, ok := m[s]; !ok {
   435  			m[s] = true
   436  			r = append(r, s)
   437  		}
   438  	}
   439  	return r
   440  }