trpc.group/trpc-go/trpc-go@v1.0.2/trpc_util.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package trpc 15 16 import ( 17 "context" 18 "net" 19 "os" 20 "runtime" 21 "sync" 22 "time" 23 24 "github.com/panjf2000/ants/v2" 25 trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc" 26 27 "trpc.group/trpc-go/trpc-go/codec" 28 "trpc.group/trpc-go/trpc-go/errs" 29 "trpc.group/trpc-go/trpc-go/internal/report" 30 "trpc.group/trpc-go/trpc-go/log" 31 ) 32 33 // PanicBufLen is len of buffer used for stack trace logging 34 // when the goroutine panics, 1024 by default. 35 var PanicBufLen = 1024 36 37 // ----------------------- trpc util functions ------------------------------------ // 38 39 // Message returns msg from ctx. 40 func Message(ctx context.Context) codec.Msg { 41 return codec.Message(ctx) 42 } 43 44 // BackgroundContext puts an initialized msg into background context and returns it. 45 func BackgroundContext() context.Context { 46 cfg := GlobalConfig() 47 ctx, msg := codec.WithNewMessage(context.Background()) 48 msg.WithCalleeContainerName(cfg.Global.ContainerName) 49 msg.WithNamespace(cfg.Global.Namespace) 50 msg.WithEnvName(cfg.Global.EnvName) 51 if cfg.Global.EnableSet == "Y" { 52 msg.WithSetName(cfg.Global.FullSetName) 53 } 54 if len(cfg.Server.Service) > 0 { 55 msg.WithCalleeServiceName(cfg.Server.Service[0].Name) 56 } else { 57 msg.WithCalleeApp(cfg.Server.App) 58 msg.WithCalleeServer(cfg.Server.Server) 59 } 60 return ctx 61 } 62 63 // GetMetaData returns metadata from ctx by key. 
64 func GetMetaData(ctx context.Context, key string) []byte { 65 msg := codec.Message(ctx) 66 if len(msg.ServerMetaData()) > 0 { 67 return msg.ServerMetaData()[key] 68 } 69 return nil 70 } 71 72 // SetMetaData sets metadata which will be returned to upstream. 73 // This method is not thread-safe. 74 // Notice: SetMetaData can only be called in the server side rpc entry goroutine, 75 // not in goroutine that calls the client. 76 func SetMetaData(ctx context.Context, key string, val []byte) { 77 msg := codec.Message(ctx) 78 if len(msg.ServerMetaData()) > 0 { 79 msg.ServerMetaData()[key] = val 80 return 81 } 82 md := make(map[string][]byte) 83 md[key] = val 84 msg.WithServerMetaData(md) 85 } 86 87 // Request returns RequestProtocol from ctx. 88 // If the RequestProtocol not found, a new RequestProtocol will be created and returned. 89 func Request(ctx context.Context) *trpcpb.RequestProtocol { 90 msg := codec.Message(ctx) 91 request, ok := msg.ServerReqHead().(*trpcpb.RequestProtocol) 92 if !ok { 93 return &trpcpb.RequestProtocol{} 94 } 95 return request 96 } 97 98 // Response returns ResponseProtocol from ctx. 99 // If the ResponseProtocol not found, a new ResponseProtocol will be created and returned. 100 func Response(ctx context.Context) *trpcpb.ResponseProtocol { 101 msg := codec.Message(ctx) 102 response, ok := msg.ServerRspHead().(*trpcpb.ResponseProtocol) 103 if !ok { 104 return &trpcpb.ResponseProtocol{} 105 } 106 return response 107 } 108 109 // CloneContext copies the context to get a context that retains the value and doesn't cancel. 110 // This is used when the handler is processed asynchronously to detach the original timeout control 111 // and retains the original context information. 112 // 113 // After the trpc handler function returns, ctx will be canceled, and put the ctx's Msg back into pool, 114 // and the associated Metrics and logger will be released. 
115 // 116 // Before starting a goroutine to run the handler function asynchronously, 117 // this method must be called to copy context, detach the original timeout control, 118 // and retain the information in Msg for Metrics. 119 // 120 // Retain the logger context for printing the associated log, 121 // keep other value in context, such as tracing context, etc. 122 func CloneContext(ctx context.Context) context.Context { 123 oldMsg := codec.Message(ctx) 124 newCtx, newMsg := codec.WithNewMessage(detach(ctx)) 125 codec.CopyMsg(newMsg, oldMsg) 126 return newCtx 127 } 128 129 type detachedContext struct{ parent context.Context } 130 131 func detach(ctx context.Context) context.Context { return detachedContext{ctx} } 132 133 // Deadline implements context.Deadline 134 func (v detachedContext) Deadline() (time.Time, bool) { return time.Time{}, false } 135 136 // Done implements context.Done 137 func (v detachedContext) Done() <-chan struct{} { return nil } 138 139 // Err implements context.Err 140 func (v detachedContext) Err() error { return nil } 141 142 // Value implements context.Value 143 func (v detachedContext) Value(key interface{}) interface{} { return v.parent.Value(key) } 144 145 // GoAndWait provides safe concurrent handling. Per input handler, it starts a goroutine. 146 // Then it waits until all handlers are done and will recover if any handler panics. 147 // The returned error is the first non-nil error returned by one of the handlers. 148 // It can be set that non-nil error will be returned if the "key" handler fails while other handlers always 149 // return nil error. 
150 func GoAndWait(handlers ...func() error) error { 151 var ( 152 wg sync.WaitGroup 153 once sync.Once 154 err error 155 ) 156 for _, f := range handlers { 157 wg.Add(1) 158 go func(handler func() error) { 159 defer func() { 160 if e := recover(); e != nil { 161 buf := make([]byte, PanicBufLen) 162 buf = buf[:runtime.Stack(buf, false)] 163 log.Errorf("[PANIC]%v\n%s\n", e, buf) 164 report.PanicNum.Incr() 165 once.Do(func() { 166 err = errs.New(errs.RetServerSystemErr, "panic found in call handlers") 167 }) 168 } 169 wg.Done() 170 }() 171 if e := handler(); e != nil { 172 once.Do(func() { 173 err = e 174 }) 175 } 176 }(f) 177 } 178 wg.Wait() 179 return err 180 } 181 182 // Goer is the interface that launches a testable and safe goroutine. 183 type Goer interface { 184 Go(ctx context.Context, timeout time.Duration, handler func(context.Context)) error 185 } 186 187 type asyncGoer struct { 188 panicBufLen int 189 shouldRecover bool 190 pool *ants.PoolWithFunc 191 } 192 193 type goerParam struct { 194 ctx context.Context 195 cancel context.CancelFunc 196 handler func(context.Context) 197 } 198 199 // NewAsyncGoer creates a goer that executes handler asynchronously with a goroutine when Go() is called. 
200 func NewAsyncGoer(workerPoolSize int, panicBufLen int, shouldRecover bool) Goer { 201 g := &asyncGoer{ 202 panicBufLen: panicBufLen, 203 shouldRecover: shouldRecover, 204 } 205 if workerPoolSize == 0 { 206 return g 207 } 208 209 pool, err := ants.NewPoolWithFunc(workerPoolSize, func(args interface{}) { 210 p := args.(*goerParam) 211 g.handle(p.ctx, p.handler, p.cancel) 212 }) 213 if err != nil { 214 panic(err) 215 } 216 g.pool = pool 217 return g 218 } 219 220 func (g *asyncGoer) handle(ctx context.Context, handler func(context.Context), cancel context.CancelFunc) { 221 defer func() { 222 if g.shouldRecover { 223 if err := recover(); err != nil { 224 buf := make([]byte, g.panicBufLen) 225 buf = buf[:runtime.Stack(buf, false)] 226 log.ErrorContextf(ctx, "[PANIC]%v\n%s\n", err, buf) 227 report.PanicNum.Incr() 228 } 229 } 230 cancel() 231 }() 232 handler(ctx) 233 } 234 235 func (g *asyncGoer) Go(ctx context.Context, timeout time.Duration, handler func(context.Context)) error { 236 oldMsg := codec.Message(ctx) 237 newCtx, newMsg := codec.WithNewMessage(detach(ctx)) 238 codec.CopyMsg(newMsg, oldMsg) 239 newCtx, cancel := context.WithTimeout(newCtx, timeout) 240 if g.pool != nil { 241 p := &goerParam{ 242 ctx: newCtx, 243 cancel: cancel, 244 handler: handler, 245 } 246 return g.pool.Invoke(p) 247 } 248 go g.handle(newCtx, handler, cancel) 249 return nil 250 } 251 252 type syncGoer struct { 253 } 254 255 // NewSyncGoer creates a goer that executes handler synchronously without cloning ctx when Go() is called. 256 // it's usually used for testing. 257 func NewSyncGoer() Goer { 258 return &syncGoer{} 259 } 260 261 func (g *syncGoer) Go(ctx context.Context, timeout time.Duration, handler func(context.Context)) error { 262 newCtx, cancel := context.WithTimeout(ctx, timeout) 263 defer cancel() 264 handler(newCtx) 265 return nil 266 } 267 268 // DefaultGoer is an async goer without worker pool. 
269 var DefaultGoer = NewAsyncGoer(0, PanicBufLen, true) 270 271 // Go launches a safer goroutine for async task inside rpc handler. 272 // it clones ctx and msg before the goroutine, and will recover and report metrics when the goroutine panics. 273 // you should set a suitable timeout to control the lifetime of the new goroutine to prevent goroutine leaks. 274 func Go(ctx context.Context, timeout time.Duration, handler func(context.Context)) error { 275 return DefaultGoer.Go(ctx, timeout, handler) 276 } 277 278 // expandEnv looks for ${var} in s and replaces them with value of the 279 // corresponding environment variable. 280 // $var is considered invalid. 281 // It's not like os.ExpandEnv which will handle both ${var} and $var. 282 // Since configurations like password for redis/mysql may contain $, this 283 // method is needed. 284 func expandEnv(s string) string { 285 var buf []byte 286 i := 0 287 for j := 0; j < len(s); j++ { 288 if s[j] == '$' && j+2 < len(s) && s[j+1] == '{' { // only ${var} instead of $var is valid 289 if buf == nil { 290 buf = make([]byte, 0, 2*len(s)) 291 } 292 buf = append(buf, s[i:j]...) 293 name, w := getEnvName(s[j+1:]) 294 if name == "" && w > 0 { 295 // invalid matching, remove the $ 296 } else if name == "" { 297 buf = append(buf, s[j]) // keep the $ 298 } else { 299 buf = append(buf, os.Getenv(name)...) 300 } 301 j += w 302 i = j + 1 303 } 304 } 305 if buf == nil { 306 return s 307 } 308 return string(buf) + s[i:] 309 } 310 311 // getEnvName gets env name, that is, var from ${var}. 312 // And content of var and its len will be returned. 
func getEnvName(s string) (string, int) {
	// The caller guarantees that s starts with '{' and holds at least two bytes,
	// so scanning starts right after the opening brace.
	for i := 1; i < len(s); i++ {
		switch s[i] {
		case ' ', '\n', '"': // e.g. "xx${xxx" inside a quoted value
			return "", 0 // invalid character: the caller keeps the '$'
		case '}':
			if i == 1 { // "${}"
				return "", 2 // the caller removes "${}" entirely
			}
			return s[1:i], i + 1
		}
	}
	return "", 0 // no closing '}': the caller keeps the '$'
}

// --------------- the following code is IP Config related -----------------//

// nicIP records the IPv4 and IPv6 addresses found on one network interface (nic).
type nicIP struct {
	nic  string
	ipv4 []string
	ipv6 []string
}

// netInterfaceIP lazily builds and caches the nic name -> nicIP mapping.
type netInterfaceIP struct {
	once sync.Once
	ips  map[string]*nicIP
}

// enumAllIP returns the nic name -> nicIP mapping, enumerating the host's interfaces on the
// first call only. It is best-effort: if the interfaces (or an interface's addresses) cannot
// be listed they are skipped, and the possibly empty result is cached for all later calls.
func (p *netInterfaceIP) enumAllIP() map[string]*nicIP {
	p.once.Do(func() {
		p.ips = make(map[string]*nicIP)
		interfaces, err := net.Interfaces()
		if err != nil {
			return
		}
		for _, i := range interfaces {
			p.addInterface(i)
		}
	})
	return p.ips
}

// addInterface records every *net.IPNet address of i, classifying it as IPv4 or IPv6.
// Other address types are ignored, and so is an interface whose addresses cannot be read.
func (p *netInterfaceIP) addInterface(i net.Interface) {
	addrs, err := i.Addrs()
	if err != nil {
		return
	}
	for _, addr := range addrs {
		ipNet, ok := addr.(*net.IPNet)
		if !ok {
			continue
		}
		if ipNet.IP.To4() != nil {
			p.addIPv4(i.Name, ipNet.IP.String())
		} else if ipNet.IP.To16() != nil {
			p.addIPv6(i.Name, ipNet.IP.String())
		}
	}
}

// addIPv4 appends an ipv4 address to the entry for nic.
func (p *netInterfaceIP) addIPv4(nic string, ip4 string) {
	ips := p.getNicIP(nic)
	ips.ipv4 = append(ips.ipv4, ip4)
}

// addIPv6 appends an ipv6 address to the entry for nic.
func (p *netInterfaceIP) addIPv6(nic string, ip6 string) {
	ips := p.getNicIP(nic)
	ips.ipv6 = append(ips.ipv6, ip6)
}

// getNicIP returns the nicIP entry for nic, creating an empty one on first use.
// p.ips must already be initialized (enumAllIP does this).
func (p *netInterfaceIP) getNicIP(nic string) *nicIP {
	ip, ok := p.ips[nic]
	if !ok {
		ip = &nicIP{nic: nic}
		p.ips[nic] = ip
	}
	return ip
}

// getIPByNic returns an ip address of the named nic: its first ipv4 address if it has one,
// otherwise its first ipv6 address, or "" when the nic is unknown or has no address.
func (p *netInterfaceIP) getIPByNic(nic string) string {
	ip, ok := p.enumAllIP()[nic]
	if !ok {
		return ""
	}
	if len(ip.ipv4) > 0 {
		return ip.ipv4[0]
	}
	if len(ip.ipv6) > 0 {
		return ip.ipv6[0]
	}
	return ""
}

// localIP records the local nic name->nicIP mapping.
var localIP = &netInterfaceIP{}

// getIP returns the ip address of the named local nic (see getIPByNic).
func getIP(nic string) string {
	return localIP.getIPByNic(nic)
}

// deduplicate merges a and b into a new slice, keeping only the first occurrence of each
// string, in order (a's elements first, then b's). Neither input slice nor its backing
// array is modified.
func deduplicate(a, b []string) []string {
	r := make([]string, 0, len(a)+len(b))
	seen := make(map[string]struct{}, len(a)+len(b))
	// Iterate the inputs separately: append(a, b...) would write b's elements into a's
	// backing array whenever a has spare capacity, corrupting the caller's data.
	for _, list := range [][]string{a, b} {
		for _, s := range list {
			if _, dup := seen[s]; dup {
				continue
			}
			seen[s] = struct{}{}
			r = append(r, s)
		}
	}
	return r
}