github.com/aacfactory/fns@v1.2.86-0.20240310083819-80d667fc0a17/context/context.go (about)

     1  /*
     2   * Copyright 2023 Wang Min Xiang
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   * 	http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   *
    16   */
    17  
    18  package context
    19  
    20  import (
    21  	"context"
    22  	"sync"
    23  	"unsafe"
    24  )
    25  
    26  var (
    27  	pool = sync.Pool{}
    28  )
    29  
    30  func Acquire(ctx context.Context) Context {
    31  	cached := pool.Get()
    32  	if cached == nil {
    33  		return &context_{
    34  			Context: ctx,
    35  			users:   make(Entries, 0, 1),
    36  			locals:  make(Entries, 0, 1),
    37  		}
    38  	}
    39  	v := cached.(*context_)
    40  	v.Context = ctx
    41  	return v
    42  }
    43  
    44  func Release(ctx context.Context) {
    45  	v, ok := ctx.(*context_)
    46  	if ok {
    47  		v.Context = nil
    48  		v.users.Reset()
    49  		v.locals.Reset()
    50  		pool.Put(v)
    51  	}
    52  }
    53  
    54  type Context interface {
    55  	context.Context
    56  	UserValue(key []byte) any
    57  	SetUserValue(key []byte, val any)
    58  	RemoveUserValue(key []byte)
    59  	UserValues(fn func(key []byte, val any))
    60  	LocalValue(key []byte) any
    61  	SetLocalValue(key []byte, val any)
    62  	RemoveLocalValue(key []byte)
    63  	LocalValues(fn func(key []byte, val any))
    64  }
    65  
    66  type context_ struct {
    67  	context.Context
    68  	users  Entries
    69  	locals Entries
    70  }
    71  
    72  func (c *context_) UserValue(key []byte) any {
    73  	v := c.users.Get(key)
    74  	if v != nil {
    75  		return v
    76  	}
    77  	parent, ok := c.Context.(Context)
    78  	if ok {
    79  		return parent.UserValue(key)
    80  	}
    81  	return nil
    82  }
    83  
    84  func (c *context_) SetUserValue(key []byte, val any) {
    85  	c.users.Set(key, val)
    86  }
    87  
    88  func (c *context_) RemoveUserValue(key []byte) {
    89  	if c.users.Remove(key) {
    90  		return
    91  	}
    92  	parent, ok := c.Context.(Context)
    93  	if ok {
    94  		parent.RemoveUserValue(key)
    95  	}
    96  }
    97  
    98  func (c *context_) UserValues(fn func(key []byte, val any)) {
    99  	parent, ok := c.Context.(Context)
   100  	if ok {
   101  		parent.UserValues(fn)
   102  	}
   103  	c.users.Foreach(fn)
   104  }
   105  
   106  func (c *context_) LocalValue(key []byte) any {
   107  	v := c.locals.Get(key)
   108  	if v != nil {
   109  		return v
   110  	}
   111  	parent, ok := c.Context.(Context)
   112  	if ok {
   113  		return parent.LocalValue(key)
   114  	}
   115  	return nil
   116  }
   117  
   118  func (c *context_) SetLocalValue(key []byte, val any) {
   119  	c.locals.Set(key, val)
   120  }
   121  
   122  func (c *context_) RemoveLocalValue(key []byte) {
   123  	if c.locals.Remove(key) {
   124  		return
   125  	}
   126  	parent, ok := c.Context.(Context)
   127  	if ok {
   128  		parent.RemoveLocalValue(key)
   129  	}
   130  }
   131  
   132  func (c *context_) LocalValues(fn func(key []byte, val any)) {
   133  	parent, ok := c.Context.(Context)
   134  	if ok {
   135  		parent.LocalValues(fn)
   136  	}
   137  	c.locals.Foreach(fn)
   138  }
   139  
   140  func (c *context_) Value(key any) any {
   141  	switch k := key.(type) {
   142  	case []byte:
   143  		v := c.users.Get(k)
   144  		if v == nil {
   145  			v = c.locals.Get(k)
   146  			if v == nil {
   147  				return c.Context.Value(key)
   148  			}
   149  		}
   150  		return v
   151  	case string:
   152  		s := unsafe.Slice(unsafe.StringData(k), len(k))
   153  		v := c.users.Get(s)
   154  		if v == nil {
   155  			v = c.locals.Get(s)
   156  			if v == nil {
   157  				return c.Context.Value(key)
   158  			}
   159  		}
   160  		return v
   161  	default:
   162  		break
   163  	}
   164  	return c.Context.Value(key)
   165  }