trpc.group/trpc-go/trpc-go@v1.0.3/trpc_util.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package trpc 15 16 import ( 17 "context" 18 "net" 19 "runtime" 20 "sync" 21 "time" 22 23 "github.com/panjf2000/ants/v2" 24 trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc" 25 26 "trpc.group/trpc-go/trpc-go/codec" 27 "trpc.group/trpc-go/trpc-go/errs" 28 "trpc.group/trpc-go/trpc-go/internal/report" 29 "trpc.group/trpc-go/trpc-go/log" 30 ) 31 32 // PanicBufLen is len of buffer used for stack trace logging 33 // when the goroutine panics, 1024 by default. 34 var PanicBufLen = 1024 35 36 // ----------------------- trpc util functions ------------------------------------ // 37 38 // Message returns msg from ctx. 39 func Message(ctx context.Context) codec.Msg { 40 return codec.Message(ctx) 41 } 42 43 // BackgroundContext puts an initialized msg into background context and returns it. 44 func BackgroundContext() context.Context { 45 cfg := GlobalConfig() 46 ctx, msg := codec.WithNewMessage(context.Background()) 47 msg.WithCalleeContainerName(cfg.Global.ContainerName) 48 msg.WithNamespace(cfg.Global.Namespace) 49 msg.WithEnvName(cfg.Global.EnvName) 50 if cfg.Global.EnableSet == "Y" { 51 msg.WithSetName(cfg.Global.FullSetName) 52 } 53 if len(cfg.Server.Service) > 0 { 54 msg.WithCalleeServiceName(cfg.Server.Service[0].Name) 55 } else { 56 msg.WithCalleeApp(cfg.Server.App) 57 msg.WithCalleeServer(cfg.Server.Server) 58 } 59 return ctx 60 } 61 62 // GetMetaData returns metadata from ctx by key. 63 func GetMetaData(ctx context.Context, key string) []byte { 64 msg := codec.Message(ctx) 65 if len(msg.ServerMetaData()) > 0 { 66 return msg.ServerMetaData()[key] 67 } 68 return nil 69 } 70 71 // SetMetaData sets metadata which will be returned to upstream. 72 // This method is not thread-safe. 73 // Notice: SetMetaData can only be called in the server side rpc entry goroutine, 74 // not in goroutine that calls the client. 75 func SetMetaData(ctx context.Context, key string, val []byte) { 76 msg := codec.Message(ctx) 77 if len(msg.ServerMetaData()) > 0 { 78 msg.ServerMetaData()[key] = val 79 return 80 } 81 md := make(map[string][]byte) 82 md[key] = val 83 msg.WithServerMetaData(md) 84 } 85 86 // Request returns RequestProtocol from ctx. 87 // If the RequestProtocol not found, a new RequestProtocol will be created and returned. 88 func Request(ctx context.Context) *trpcpb.RequestProtocol { 89 msg := codec.Message(ctx) 90 request, ok := msg.ServerReqHead().(*trpcpb.RequestProtocol) 91 if !ok { 92 return &trpcpb.RequestProtocol{} 93 } 94 return request 95 } 96 97 // Response returns ResponseProtocol from ctx. 98 // If the ResponseProtocol not found, a new ResponseProtocol will be created and returned. 99 func Response(ctx context.Context) *trpcpb.ResponseProtocol { 100 msg := codec.Message(ctx) 101 response, ok := msg.ServerRspHead().(*trpcpb.ResponseProtocol) 102 if !ok { 103 return &trpcpb.ResponseProtocol{} 104 } 105 return response 106 } 107 108 // CloneContext copies the context to get a context that retains the value and doesn't cancel. 109 // This is used when the handler is processed asynchronously to detach the original timeout control 110 // and retains the original context information. 111 // 112 // After the trpc handler function returns, ctx will be canceled, and put the ctx's Msg back into pool, 113 // and the associated Metrics and logger will be released. 114 // 115 // Before starting a goroutine to run the handler function asynchronously, 116 // this method must be called to copy context, detach the original timeout control, 117 // and retain the information in Msg for Metrics. 118 // 119 // Retain the logger context for printing the associated log, 120 // keep other value in context, such as tracing context, etc. 121 func CloneContext(ctx context.Context) context.Context { 122 oldMsg := codec.Message(ctx) 123 newCtx, newMsg := codec.WithNewMessage(detach(ctx)) 124 codec.CopyMsg(newMsg, oldMsg) 125 return newCtx 126 } 127 128 type detachedContext struct{ parent context.Context } 129 130 func detach(ctx context.Context) context.Context { return detachedContext{ctx} } 131 132 // Deadline implements context.Deadline 133 func (v detachedContext) Deadline() (time.Time, bool) { return time.Time{}, false } 134 135 // Done implements context.Done 136 func (v detachedContext) Done() <-chan struct{} { return nil } 137 138 // Err implements context.Err 139 func (v detachedContext) Err() error { return nil } 140 141 // Value implements context.Value 142 func (v detachedContext) Value(key interface{}) interface{} { return v.parent.Value(key) } 143 144 // GoAndWait provides safe concurrent handling. Per input handler, it starts a goroutine. 145 // Then it waits until all handlers are done and will recover if any handler panics. 146 // The returned error is the first non-nil error returned by one of the handlers. 147 // It can be set that non-nil error will be returned if the "key" handler fails while other handlers always 148 // return nil error. 149 func GoAndWait(handlers ...func() error) error { 150 var ( 151 wg sync.WaitGroup 152 once sync.Once 153 err error 154 ) 155 for _, f := range handlers { 156 wg.Add(1) 157 go func(handler func() error) { 158 defer func() { 159 if e := recover(); e != nil { 160 buf := make([]byte, PanicBufLen) 161 buf = buf[:runtime.Stack(buf, false)] 162 log.Errorf("[PANIC]%v\n%s\n", e, buf) 163 report.PanicNum.Incr() 164 once.Do(func() { 165 err = errs.New(errs.RetServerSystemErr, "panic found in call handlers") 166 }) 167 } 168 wg.Done() 169 }() 170 if e := handler(); e != nil { 171 once.Do(func() { 172 err = e 173 }) 174 } 175 }(f) 176 } 177 wg.Wait() 178 return err 179 } 180 181 // Goer is the interface that launches a testable and safe goroutine. 182 type Goer interface { 183 Go(ctx context.Context, timeout time.Duration, handler func(context.Context)) error 184 } 185 186 type asyncGoer struct { 187 panicBufLen int 188 shouldRecover bool 189 pool *ants.PoolWithFunc 190 } 191 192 type goerParam struct { 193 ctx context.Context 194 cancel context.CancelFunc 195 handler func(context.Context) 196 } 197 198 // NewAsyncGoer creates a goer that executes handler asynchronously with a goroutine when Go() is called. 199 func NewAsyncGoer(workerPoolSize int, panicBufLen int, shouldRecover bool) Goer { 200 g := &asyncGoer{ 201 panicBufLen: panicBufLen, 202 shouldRecover: shouldRecover, 203 } 204 if workerPoolSize == 0 { 205 return g 206 } 207 208 pool, err := ants.NewPoolWithFunc(workerPoolSize, func(args interface{}) { 209 p := args.(*goerParam) 210 g.handle(p.ctx, p.handler, p.cancel) 211 }) 212 if err != nil { 213 panic(err) 214 } 215 g.pool = pool 216 return g 217 } 218 219 func (g *asyncGoer) handle(ctx context.Context, handler func(context.Context), cancel context.CancelFunc) { 220 defer func() { 221 if g.shouldRecover { 222 if err := recover(); err != nil { 223 buf := make([]byte, g.panicBufLen) 224 buf = buf[:runtime.Stack(buf, false)] 225 log.ErrorContextf(ctx, "[PANIC]%v\n%s\n", err, buf) 226 report.PanicNum.Incr() 227 } 228 } 229 cancel() 230 }() 231 handler(ctx) 232 } 233 234 func (g *asyncGoer) Go(ctx context.Context, timeout time.Duration, handler func(context.Context)) error { 235 oldMsg := codec.Message(ctx) 236 newCtx, newMsg := codec.WithNewMessage(detach(ctx)) 237 codec.CopyMsg(newMsg, oldMsg) 238 newCtx, cancel := context.WithTimeout(newCtx, timeout) 239 if g.pool != nil { 240 p := &goerParam{ 241 ctx: newCtx, 242 cancel: cancel, 243 handler: handler, 244 } 245 return g.pool.Invoke(p) 246 } 247 go g.handle(newCtx, handler, cancel) 248 return nil 249 } 250 251 type syncGoer struct { 252 } 253 254 // NewSyncGoer creates a goer that executes handler synchronously without cloning ctx when Go() is called. 255 // it's usually used for testing. 256 func NewSyncGoer() Goer { 257 return &syncGoer{} 258 } 259 260 func (g *syncGoer) Go(ctx context.Context, timeout time.Duration, handler func(context.Context)) error { 261 newCtx, cancel := context.WithTimeout(ctx, timeout) 262 defer cancel() 263 handler(newCtx) 264 return nil 265 } 266 267 // DefaultGoer is an async goer without worker pool. 268 var DefaultGoer = NewAsyncGoer(0, PanicBufLen, true) 269 270 // Go launches a safer goroutine for async task inside rpc handler. 271 // it clones ctx and msg before the goroutine, and will recover and report metrics when the goroutine panics. 272 // you should set a suitable timeout to control the lifetime of the new goroutine to prevent goroutine leaks. 273 func Go(ctx context.Context, timeout time.Duration, handler func(context.Context)) error { 274 return DefaultGoer.Go(ctx, timeout, handler) 275 } 276 277 // --------------- the following code is IP Config related -----------------// 278 279 // nicIP defines the parameters used to record the ip address (ipv4 & ipv6) of the nic. 280 type nicIP struct { 281 nic string 282 ipv4 []string 283 ipv6 []string 284 } 285 286 // netInterfaceIP maintains the nic name to nicIP mapping. 287 type netInterfaceIP struct { 288 once sync.Once 289 ips map[string]*nicIP 290 } 291 292 // enumAllIP returns the nic name to nicIP mapping. 293 func (p *netInterfaceIP) enumAllIP() map[string]*nicIP { 294 p.once.Do(func() { 295 p.ips = make(map[string]*nicIP) 296 interfaces, err := net.Interfaces() 297 if err != nil { 298 return 299 } 300 for _, i := range interfaces { 301 p.addInterface(i) 302 } 303 }) 304 return p.ips 305 } 306 307 func (p *netInterfaceIP) addInterface(i net.Interface) { 308 addrs, err := i.Addrs() 309 if err != nil { 310 return 311 } 312 for _, addr := range addrs { 313 ipNet, ok := addr.(*net.IPNet) 314 if !ok { 315 continue 316 } 317 if ipNet.IP.To4() != nil { 318 p.addIPv4(i.Name, ipNet.IP.String()) 319 } else if ipNet.IP.To16() != nil { 320 p.addIPv6(i.Name, ipNet.IP.String()) 321 } 322 } 323 } 324 325 // addIPv4 append ipv4 address 326 func (p *netInterfaceIP) addIPv4(nic string, ip4 string) { 327 ips := p.getNicIP(nic) 328 ips.ipv4 = append(ips.ipv4, ip4) 329 } 330 331 // addIPv6 append ipv6 address 332 func (p *netInterfaceIP) addIPv6(nic string, ip6 string) { 333 ips := p.getNicIP(nic) 334 ips.ipv6 = append(ips.ipv6, ip6) 335 } 336 337 // getNicIP returns nicIP by nic name. 338 func (p *netInterfaceIP) getNicIP(nic string) *nicIP { 339 if _, ok := p.ips[nic]; !ok { 340 p.ips[nic] = &nicIP{nic: nic} 341 } 342 return p.ips[nic] 343 } 344 345 // getIPByNic returns ip address by nic name. 346 // If the ipv4 addr is not empty, it will be returned. 347 // Otherwise, the ipv6 addr will be returned. 348 func (p *netInterfaceIP) getIPByNic(nic string) string { 349 p.enumAllIP() 350 if len(p.ips) <= 0 { 351 return "" 352 } 353 if _, ok := p.ips[nic]; !ok { 354 return "" 355 } 356 ip := p.ips[nic] 357 if len(ip.ipv4) > 0 { 358 return ip.ipv4[0] 359 } 360 if len(ip.ipv6) > 0 { 361 return ip.ipv6[0] 362 } 363 return "" 364 } 365 366 // localIP records the local nic name->nicIP mapping. 367 var localIP = &netInterfaceIP{} 368 369 // getIP returns ip addr by nic name. 370 func getIP(nic string) string { 371 ip := localIP.getIPByNic(nic) 372 return ip 373 } 374 375 // deduplicate merges two slices. 376 // Order will be kept and duplication will be removed. 377 func deduplicate(a, b []string) []string { 378 r := make([]string, 0, len(a)+len(b)) 379 m := make(map[string]bool) 380 for _, s := range append(a, b...) { 381 if _, ok := m[s]; !ok { 382 m[s] = true 383 r = append(r, s) 384 } 385 } 386 return r 387 }