trpc.group/trpc-go/trpc-go@v1.0.3/trpc_util.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package trpc
    15  
    16  import (
    17  	"context"
    18  	"net"
    19  	"runtime"
    20  	"sync"
    21  	"time"
    22  
    23  	"github.com/panjf2000/ants/v2"
    24  	trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc"
    25  
    26  	"trpc.group/trpc-go/trpc-go/codec"
    27  	"trpc.group/trpc-go/trpc-go/errs"
    28  	"trpc.group/trpc-go/trpc-go/internal/report"
    29  	"trpc.group/trpc-go/trpc-go/log"
    30  )
    31  
// PanicBufLen is the length, in bytes, of the buffer used to capture the stack
// trace that is logged when a goroutine panics. It defaults to 1024.
//
// GoAndWait reads this value each time it recovers a panic, whereas DefaultGoer
// receives the value once, when the package-level variable is initialized.
var PanicBufLen = 1024
    35  
    36  // ----------------------- trpc util functions ------------------------------------ //
    37  
// Message returns the codec.Msg associated with ctx.
// It is a thin convenience wrapper that delegates to codec.Message.
func Message(ctx context.Context) codec.Msg {
	return codec.Message(ctx)
}
    42  
    43  // BackgroundContext puts an initialized msg into background context and returns it.
    44  func BackgroundContext() context.Context {
    45  	cfg := GlobalConfig()
    46  	ctx, msg := codec.WithNewMessage(context.Background())
    47  	msg.WithCalleeContainerName(cfg.Global.ContainerName)
    48  	msg.WithNamespace(cfg.Global.Namespace)
    49  	msg.WithEnvName(cfg.Global.EnvName)
    50  	if cfg.Global.EnableSet == "Y" {
    51  		msg.WithSetName(cfg.Global.FullSetName)
    52  	}
    53  	if len(cfg.Server.Service) > 0 {
    54  		msg.WithCalleeServiceName(cfg.Server.Service[0].Name)
    55  	} else {
    56  		msg.WithCalleeApp(cfg.Server.App)
    57  		msg.WithCalleeServer(cfg.Server.Server)
    58  	}
    59  	return ctx
    60  }
    61  
    62  // GetMetaData returns metadata from ctx by key.
    63  func GetMetaData(ctx context.Context, key string) []byte {
    64  	msg := codec.Message(ctx)
    65  	if len(msg.ServerMetaData()) > 0 {
    66  		return msg.ServerMetaData()[key]
    67  	}
    68  	return nil
    69  }
    70  
    71  // SetMetaData sets metadata which will be returned to upstream.
    72  // This method is not thread-safe.
    73  // Notice: SetMetaData can only be called in the server side rpc entry goroutine,
    74  // not in goroutine that calls the client.
    75  func SetMetaData(ctx context.Context, key string, val []byte) {
    76  	msg := codec.Message(ctx)
    77  	if len(msg.ServerMetaData()) > 0 {
    78  		msg.ServerMetaData()[key] = val
    79  		return
    80  	}
    81  	md := make(map[string][]byte)
    82  	md[key] = val
    83  	msg.WithServerMetaData(md)
    84  }
    85  
    86  // Request returns RequestProtocol from ctx.
    87  // If the RequestProtocol not found, a new RequestProtocol will be created and returned.
    88  func Request(ctx context.Context) *trpcpb.RequestProtocol {
    89  	msg := codec.Message(ctx)
    90  	request, ok := msg.ServerReqHead().(*trpcpb.RequestProtocol)
    91  	if !ok {
    92  		return &trpcpb.RequestProtocol{}
    93  	}
    94  	return request
    95  }
    96  
    97  // Response returns ResponseProtocol from ctx.
    98  // If the ResponseProtocol not found, a new ResponseProtocol will be created and returned.
    99  func Response(ctx context.Context) *trpcpb.ResponseProtocol {
   100  	msg := codec.Message(ctx)
   101  	response, ok := msg.ServerRspHead().(*trpcpb.ResponseProtocol)
   102  	if !ok {
   103  		return &trpcpb.ResponseProtocol{}
   104  	}
   105  	return response
   106  }
   107  
   108  // CloneContext copies the context to get a context that retains the value and doesn't cancel.
   109  // This is used when the handler is processed asynchronously to detach the original timeout control
   110  // and retains the original context information.
   111  //
   112  // After the trpc handler function returns, ctx will be canceled, and put the ctx's Msg back into pool,
   113  // and the associated Metrics and logger will be released.
   114  //
   115  // Before starting a goroutine to run the handler function asynchronously,
   116  // this method must be called to copy context, detach the original timeout control,
   117  // and retain the information in Msg for Metrics.
   118  //
   119  // Retain the logger context for printing the associated log,
   120  // keep other value in context, such as tracing context, etc.
   121  func CloneContext(ctx context.Context) context.Context {
   122  	oldMsg := codec.Message(ctx)
   123  	newCtx, newMsg := codec.WithNewMessage(detach(ctx))
   124  	codec.CopyMsg(newMsg, oldMsg)
   125  	return newCtx
   126  }
   127  
   128  type detachedContext struct{ parent context.Context }
   129  
   130  func detach(ctx context.Context) context.Context { return detachedContext{ctx} }
   131  
   132  // Deadline implements context.Deadline
   133  func (v detachedContext) Deadline() (time.Time, bool) { return time.Time{}, false }
   134  
   135  // Done implements context.Done
   136  func (v detachedContext) Done() <-chan struct{} { return nil }
   137  
   138  // Err implements context.Err
   139  func (v detachedContext) Err() error { return nil }
   140  
   141  // Value implements context.Value
   142  func (v detachedContext) Value(key interface{}) interface{} { return v.parent.Value(key) }
   143  
   144  // GoAndWait provides safe concurrent handling. Per input handler, it starts a goroutine.
   145  // Then it waits until all handlers are done and will recover if any handler panics.
   146  // The returned error is the first non-nil error returned by one of the handlers.
   147  // It can be set that non-nil error will be returned if the "key" handler fails while other handlers always
   148  // return nil error.
   149  func GoAndWait(handlers ...func() error) error {
   150  	var (
   151  		wg   sync.WaitGroup
   152  		once sync.Once
   153  		err  error
   154  	)
   155  	for _, f := range handlers {
   156  		wg.Add(1)
   157  		go func(handler func() error) {
   158  			defer func() {
   159  				if e := recover(); e != nil {
   160  					buf := make([]byte, PanicBufLen)
   161  					buf = buf[:runtime.Stack(buf, false)]
   162  					log.Errorf("[PANIC]%v\n%s\n", e, buf)
   163  					report.PanicNum.Incr()
   164  					once.Do(func() {
   165  						err = errs.New(errs.RetServerSystemErr, "panic found in call handlers")
   166  					})
   167  				}
   168  				wg.Done()
   169  			}()
   170  			if e := handler(); e != nil {
   171  				once.Do(func() {
   172  					err = e
   173  				})
   174  			}
   175  		}(f)
   176  	}
   177  	wg.Wait()
   178  	return err
   179  }
   180  
// Goer is the interface that launches a testable and safe goroutine.
//
// Go runs handler with a context derived from ctx whose lifetime is bounded by
// timeout. Whether handler runs asynchronously or synchronously depends on the
// implementation. A non-nil error is returned by the pooled implementation in
// this file when its worker pool rejects the task.
type Goer interface {
	Go(ctx context.Context, timeout time.Duration, handler func(context.Context)) error
}
   185  
// asyncGoer is a Goer that runs handlers asynchronously, either on a new
// goroutine per call or on a worker pool.
type asyncGoer struct {
	// panicBufLen is the size of the buffer used to capture the stack trace
	// logged when a handler panics.
	panicBufLen int
	// shouldRecover controls whether a panic raised by a handler is recovered
	// and logged; when false the panic is left to propagate.
	shouldRecover bool
	// pool is the optional worker pool; when nil every call to Go starts a
	// new goroutine.
	pool *ants.PoolWithFunc
}
   191  
// goerParam bundles the arguments of a single asyncGoer.handle call so they
// can be passed to the worker pool as one value.
type goerParam struct {
	// ctx is the context handed to handler.
	ctx context.Context
	// cancel releases ctx; it is called once handler has finished.
	cancel context.CancelFunc
	// handler is the task to run.
	handler func(context.Context)
}
   197  
   198  // NewAsyncGoer creates a goer that executes handler asynchronously with a goroutine when Go() is called.
   199  func NewAsyncGoer(workerPoolSize int, panicBufLen int, shouldRecover bool) Goer {
   200  	g := &asyncGoer{
   201  		panicBufLen:   panicBufLen,
   202  		shouldRecover: shouldRecover,
   203  	}
   204  	if workerPoolSize == 0 {
   205  		return g
   206  	}
   207  
   208  	pool, err := ants.NewPoolWithFunc(workerPoolSize, func(args interface{}) {
   209  		p := args.(*goerParam)
   210  		g.handle(p.ctx, p.handler, p.cancel)
   211  	})
   212  	if err != nil {
   213  		panic(err)
   214  	}
   215  	g.pool = pool
   216  	return g
   217  }
   218  
   219  func (g *asyncGoer) handle(ctx context.Context, handler func(context.Context), cancel context.CancelFunc) {
   220  	defer func() {
   221  		if g.shouldRecover {
   222  			if err := recover(); err != nil {
   223  				buf := make([]byte, g.panicBufLen)
   224  				buf = buf[:runtime.Stack(buf, false)]
   225  				log.ErrorContextf(ctx, "[PANIC]%v\n%s\n", err, buf)
   226  				report.PanicNum.Incr()
   227  			}
   228  		}
   229  		cancel()
   230  	}()
   231  	handler(ctx)
   232  }
   233  
   234  func (g *asyncGoer) Go(ctx context.Context, timeout time.Duration, handler func(context.Context)) error {
   235  	oldMsg := codec.Message(ctx)
   236  	newCtx, newMsg := codec.WithNewMessage(detach(ctx))
   237  	codec.CopyMsg(newMsg, oldMsg)
   238  	newCtx, cancel := context.WithTimeout(newCtx, timeout)
   239  	if g.pool != nil {
   240  		p := &goerParam{
   241  			ctx:     newCtx,
   242  			cancel:  cancel,
   243  			handler: handler,
   244  		}
   245  		return g.pool.Invoke(p)
   246  	}
   247  	go g.handle(newCtx, handler, cancel)
   248  	return nil
   249  }
   250  
   251  type syncGoer struct {
   252  }
   253  
   254  // NewSyncGoer creates a goer that executes handler synchronously without cloning ctx when Go() is called.
   255  // it's usually used for testing.
   256  func NewSyncGoer() Goer {
   257  	return &syncGoer{}
   258  }
   259  
   260  func (g *syncGoer) Go(ctx context.Context, timeout time.Duration, handler func(context.Context)) error {
   261  	newCtx, cancel := context.WithTimeout(ctx, timeout)
   262  	defer cancel()
   263  	handler(newCtx)
   264  	return nil
   265  }
   266  
// DefaultGoer is an async goer without worker pool: every call to its Go method
// starts a new goroutine. It recovers handler panics and uses the value
// PanicBufLen has at package initialization as its stack trace buffer size.
var DefaultGoer = NewAsyncGoer(0, PanicBufLen, true)
   269  
// Go launches a safer goroutine for an async task inside an rpc handler by
// delegating to DefaultGoer.
// It clones ctx and msg before starting the goroutine, and recovers and reports
// metrics when the goroutine panics.
// You should set a suitable timeout to control the lifetime of the new goroutine
// to prevent goroutine leaks.
func Go(ctx context.Context, timeout time.Duration, handler func(context.Context)) error {
	return DefaultGoer.Go(ctx, timeout, handler)
}
   276  
   277  // --------------- the following code is IP Config related -----------------//
   278  
   279  // nicIP defines the parameters used to record the ip address (ipv4 & ipv6) of the nic.
   280  type nicIP struct {
   281  	nic  string
   282  	ipv4 []string
   283  	ipv6 []string
   284  }
   285  
   286  // netInterfaceIP maintains the nic name to nicIP mapping.
   287  type netInterfaceIP struct {
   288  	once sync.Once
   289  	ips  map[string]*nicIP
   290  }
   291  
   292  // enumAllIP returns the nic name to nicIP mapping.
   293  func (p *netInterfaceIP) enumAllIP() map[string]*nicIP {
   294  	p.once.Do(func() {
   295  		p.ips = make(map[string]*nicIP)
   296  		interfaces, err := net.Interfaces()
   297  		if err != nil {
   298  			return
   299  		}
   300  		for _, i := range interfaces {
   301  			p.addInterface(i)
   302  		}
   303  	})
   304  	return p.ips
   305  }
   306  
   307  func (p *netInterfaceIP) addInterface(i net.Interface) {
   308  	addrs, err := i.Addrs()
   309  	if err != nil {
   310  		return
   311  	}
   312  	for _, addr := range addrs {
   313  		ipNet, ok := addr.(*net.IPNet)
   314  		if !ok {
   315  			continue
   316  		}
   317  		if ipNet.IP.To4() != nil {
   318  			p.addIPv4(i.Name, ipNet.IP.String())
   319  		} else if ipNet.IP.To16() != nil {
   320  			p.addIPv6(i.Name, ipNet.IP.String())
   321  		}
   322  	}
   323  }
   324  
   325  // addIPv4 append ipv4 address
   326  func (p *netInterfaceIP) addIPv4(nic string, ip4 string) {
   327  	ips := p.getNicIP(nic)
   328  	ips.ipv4 = append(ips.ipv4, ip4)
   329  }
   330  
   331  // addIPv6 append ipv6 address
   332  func (p *netInterfaceIP) addIPv6(nic string, ip6 string) {
   333  	ips := p.getNicIP(nic)
   334  	ips.ipv6 = append(ips.ipv6, ip6)
   335  }
   336  
   337  // getNicIP returns nicIP by nic name.
   338  func (p *netInterfaceIP) getNicIP(nic string) *nicIP {
   339  	if _, ok := p.ips[nic]; !ok {
   340  		p.ips[nic] = &nicIP{nic: nic}
   341  	}
   342  	return p.ips[nic]
   343  }
   344  
   345  // getIPByNic returns ip address by nic name.
   346  // If the ipv4 addr is not empty, it will be returned.
   347  // Otherwise, the ipv6 addr will be returned.
   348  func (p *netInterfaceIP) getIPByNic(nic string) string {
   349  	p.enumAllIP()
   350  	if len(p.ips) <= 0 {
   351  		return ""
   352  	}
   353  	if _, ok := p.ips[nic]; !ok {
   354  		return ""
   355  	}
   356  	ip := p.ips[nic]
   357  	if len(ip.ipv4) > 0 {
   358  		return ip.ipv4[0]
   359  	}
   360  	if len(ip.ipv6) > 0 {
   361  		return ip.ipv6[0]
   362  	}
   363  	return ""
   364  }
   365  
// localIP records the local nic name->nicIP mapping.
// It starts empty; the mapping is filled on the first lookup through getIPByNic.
var localIP = &netInterfaceIP{}
   368  
   369  // getIP returns ip addr by nic name.
   370  func getIP(nic string) string {
   371  	ip := localIP.getIPByNic(nic)
   372  	return ip
   373  }
   374  
   375  // deduplicate merges two slices.
   376  // Order will be kept and duplication will be removed.
   377  func deduplicate(a, b []string) []string {
   378  	r := make([]string, 0, len(a)+len(b))
   379  	m := make(map[string]bool)
   380  	for _, s := range append(a, b...) {
   381  		if _, ok := m[s]; !ok {
   382  			m[s] = true
   383  			r = append(r, s)
   384  		}
   385  	}
   386  	return r
   387  }