github.com/timandy/routine@v1.1.4-0.20240507073150-e4a3e1fe2ba5/thread_local_inheritable_test.go (about)

     1  package routine
     2  
     3  import (
     4  	"math"
     5  	"sync"
     6  	"testing"
     7  
     8  	"github.com/stretchr/testify/assert"
     9  )
    10  
    11  func TestInheritableThreadLocal_Index(t *testing.T) {
    12  	tls := NewInheritableThreadLocal[string]()
    13  	assert.GreaterOrEqual(t, tls.(*inheritableThreadLocal[string]).index, 0)
    14  	tls2 := NewInheritableThreadLocalWithInitial[string](func() string {
    15  		return "Hello"
    16  	})
    17  	assert.Greater(t, tls2.(*inheritableThreadLocal[string]).index, tls.(*inheritableThreadLocal[string]).index)
    18  }
    19  
    20  func TestInheritableThreadLocal_NextIndex(t *testing.T) {
    21  	backup := inheritableThreadLocalIndex
    22  	defer func() {
    23  		inheritableThreadLocalIndex = backup
    24  	}()
    25  	//
    26  	inheritableThreadLocalIndex = math.MaxInt32
    27  	assert.Panics(t, func() {
    28  		nextInheritableThreadLocalIndex()
    29  	})
    30  	assert.Equal(t, math.MaxInt32, int(inheritableThreadLocalIndex))
    31  }
    32  
    33  func TestInheritableThreadLocal_Common(t *testing.T) {
    34  	tls := NewInheritableThreadLocal[int]()
    35  	tls2 := NewInheritableThreadLocal[string]()
    36  	tls.Remove()
    37  	tls2.Remove()
    38  	assert.Equal(t, 0, tls.Get())
    39  	assert.Equal(t, "", tls2.Get())
    40  	//
    41  	tls.Set(1)
    42  	tls2.Set("World")
    43  	assert.Equal(t, 1, tls.Get())
    44  	assert.Equal(t, "World", tls2.Get())
    45  	//
    46  	tls.Set(0)
    47  	tls2.Set("")
    48  	assert.Equal(t, 0, tls.Get())
    49  	assert.Equal(t, "", tls2.Get())
    50  	//
    51  	tls.Set(2)
    52  	tls2.Set("!")
    53  	assert.Equal(t, 2, tls.Get())
    54  	assert.Equal(t, "!", tls2.Get())
    55  	//
    56  	tls.Remove()
    57  	tls2.Remove()
    58  	assert.Equal(t, 0, tls.Get())
    59  	assert.Equal(t, "", tls2.Get())
    60  	//
    61  	tls.Set(2)
    62  	tls2.Set("!")
    63  	assert.Equal(t, 2, tls.Get())
    64  	assert.Equal(t, "!", tls2.Get())
    65  	wg := &sync.WaitGroup{}
    66  	wg.Add(100)
    67  	for i := 0; i < 100; i++ {
    68  		Go(func() {
    69  			assert.Equal(t, 2, tls.Get())
    70  			assert.Equal(t, "!", tls2.Get())
    71  			wg.Done()
    72  		})
    73  	}
    74  	wg.Wait()
    75  	assert.Equal(t, 2, tls.Get())
    76  	assert.Equal(t, "!", tls2.Get())
    77  }
    78  
    79  func TestInheritableThreadLocal_Mixed(t *testing.T) {
    80  	tls := NewInheritableThreadLocal[int]()
    81  	tls2 := NewInheritableThreadLocalWithInitial[string](func() string {
    82  		return "Hello"
    83  	})
    84  	assert.Equal(t, 0, tls.Get())
    85  	assert.Equal(t, "Hello", tls2.Get())
    86  	//
    87  	tls.Set(1)
    88  	tls2.Set("World")
    89  	assert.Equal(t, 1, tls.Get())
    90  	assert.Equal(t, "World", tls2.Get())
    91  	//
    92  	tls.Set(0)
    93  	tls2.Set("")
    94  	assert.Equal(t, 0, tls.Get())
    95  	assert.Equal(t, "", tls2.Get())
    96  	//
    97  	tls.Set(2)
    98  	tls2.Set("!")
    99  	assert.Equal(t, 2, tls.Get())
   100  	assert.Equal(t, "!", tls2.Get())
   101  	//
   102  	tls.Remove()
   103  	tls2.Remove()
   104  	assert.Equal(t, 0, tls.Get())
   105  	assert.Equal(t, "Hello", tls2.Get())
   106  	//
   107  	tls.Set(2)
   108  	tls2.Set("!")
   109  	assert.Equal(t, 2, tls.Get())
   110  	assert.Equal(t, "!", tls2.Get())
   111  	wg := &sync.WaitGroup{}
   112  	wg.Add(100)
   113  	for i := 0; i < 100; i++ {
   114  		Go(func() {
   115  			assert.Equal(t, 2, tls.Get())
   116  			assert.Equal(t, "!", tls2.Get())
   117  			wg.Done()
   118  		})
   119  	}
   120  	wg.Wait()
   121  	assert.Equal(t, 2, tls.Get())
   122  	assert.Equal(t, "!", tls2.Get())
   123  }
   124  
   125  func TestInheritableThreadLocal_WithInitial(t *testing.T) {
   126  	src := &person{Id: 1, Name: "Tim"}
   127  	tls := NewInheritableThreadLocalWithInitial[*person](nil)
   128  	tls2 := NewInheritableThreadLocalWithInitial[*person](func() *person {
   129  		var value *person
   130  		return value
   131  	})
   132  	tls3 := NewInheritableThreadLocalWithInitial[*person](func() *person {
   133  		return src
   134  	})
   135  	tls4 := NewInheritableThreadLocalWithInitial[person](func() person {
   136  		return *src
   137  	})
   138  
   139  	for i := 0; i < 100; i++ {
   140  		p := tls.Get()
   141  		assert.Nil(t, p)
   142  		//
   143  		p2 := tls2.Get()
   144  		assert.Nil(t, p2)
   145  		//
   146  		p3 := tls3.Get()
   147  		assert.Same(t, src, p3)
   148  
   149  		p4 := tls4.Get()
   150  		assert.NotSame(t, src, &p4)
   151  		assert.Equal(t, *src, p4)
   152  
   153  		wg := &sync.WaitGroup{}
   154  		wg.Add(1)
   155  		Go(func() {
   156  			assert.Same(t, src, tls3.Get())
   157  			p5 := tls4.Get()
   158  			assert.NotSame(t, src, &p5)
   159  			assert.Equal(t, *src, p5)
   160  			//
   161  			wg.Done()
   162  		})
   163  		wg.Wait()
   164  	}
   165  
   166  	tls3.Set(nil)
   167  	tls4.Set(person{})
   168  	assert.Nil(t, tls3.Get())
   169  	assert.Equal(t, person{}, tls4.Get())
   170  
   171  	tls3.Remove()
   172  	tls4.Remove()
   173  	assert.Same(t, src, tls3.Get())
   174  	p6 := tls4.Get()
   175  	assert.NotSame(t, src, &p6)
   176  	assert.Equal(t, *src, p6)
   177  }
   178  
   179  func TestInheritableThreadLocal_CrossCoroutine(t *testing.T) {
   180  	tls := NewInheritableThreadLocal[string]()
   181  	tls.Set("Hello")
   182  	assert.Equal(t, "Hello", tls.Get())
   183  	subWait := &sync.WaitGroup{}
   184  	subWait.Add(2)
   185  	finishWait := &sync.WaitGroup{}
   186  	finishWait.Add(2)
   187  	go func() {
   188  		subWait.Wait()
   189  		assert.Equal(t, "", tls.Get())
   190  		finishWait.Done()
   191  	}()
   192  	Go(func() {
   193  		subWait.Wait()
   194  		assert.Equal(t, "Hello", tls.Get())
   195  		finishWait.Done()
   196  	})
   197  	tls.Remove()      //remove in parent goroutine should not affect child goroutine
   198  	subWait.Done()    //allow sub goroutine run
   199  	subWait.Done()    //allow sub goroutine run
   200  	finishWait.Wait() //wait sub goroutine done
   201  	finishWait.Wait() //wait sub goroutine done
   202  }
   203  
   204  func TestInheritableThreadLocal_CreateBatch(t *testing.T) {
   205  	const count = 128
   206  	tlsList := make([]ThreadLocal[int], count)
   207  	for i := 0; i < count; i++ {
   208  		value := i
   209  		tlsList[i] = NewInheritableThreadLocalWithInitial[int](func() int { return value })
   210  	}
   211  	for i := 0; i < count; i++ {
   212  		assert.Equal(t, i, tlsList[i].Get())
   213  	}
   214  }
   215  
   216  func TestInheritableThreadLocal_Copy(t *testing.T) {
   217  	tls := NewInheritableThreadLocalWithInitial[*person](func() *person {
   218  		return &person{Id: 1, Name: "Tim"}
   219  	})
   220  	tls2 := NewInheritableThreadLocalWithInitial[person](func() person {
   221  		return person{Id: 2, Name: "Andy"}
   222  	})
   223  
   224  	p1 := tls.Get()
   225  	assert.Equal(t, 1, p1.Id)
   226  	assert.Equal(t, "Tim", p1.Name)
   227  	p2 := tls2.Get()
   228  	assert.Equal(t, 2, p2.Id)
   229  	assert.Equal(t, "Andy", p2.Name)
   230  	//
   231  	task := GoWait(func(token CancelToken) {
   232  		p3 := tls.Get()
   233  		assert.Same(t, p1, p3)
   234  		assert.Equal(t, 1, p3.Id)
   235  		assert.Equal(t, "Tim", p1.Name)
   236  		p4 := tls2.Get()
   237  		assert.NotSame(t, &p2, &p4)
   238  		assert.Equal(t, p2, p4)
   239  		assert.Equal(t, 2, p4.Id)
   240  		assert.Equal(t, "Andy", p4.Name)
   241  		//
   242  		p3.Name = "Tim2"
   243  		p4.Name = "Andy2"
   244  	})
   245  	task.Get()
   246  	//
   247  	p5 := tls.Get()
   248  	assert.Same(t, p1, p5)
   249  	assert.Equal(t, 1, p5.Id)
   250  	assert.Equal(t, "Tim2", p5.Name)
   251  	p6 := tls2.Get()
   252  	assert.NotSame(t, &p2, &p6)
   253  	assert.Equal(t, p2, p6)
   254  	assert.Equal(t, 2, p6.Id)
   255  	assert.Equal(t, "Andy", p6.Name)
   256  }
   257  
   258  func TestInheritableThreadLocal_Cloneable(t *testing.T) {
   259  	tls := NewInheritableThreadLocalWithInitial[*personCloneable](func() *personCloneable {
   260  		return &personCloneable{Id: 1, Name: "Tim"}
   261  	})
   262  	tls2 := NewInheritableThreadLocalWithInitial[personCloneable](func() personCloneable {
   263  		return personCloneable{Id: 2, Name: "Andy"}
   264  	})
   265  
   266  	p1 := tls.Get()
   267  	assert.Equal(t, 1, p1.Id)
   268  	assert.Equal(t, "Tim", p1.Name)
   269  	p2 := tls2.Get()
   270  	assert.Equal(t, 2, p2.Id)
   271  	assert.Equal(t, "Andy", p2.Name)
   272  	//
   273  	task := GoWait(func(token CancelToken) {
   274  		p3 := tls.Get() //p3 is clone from p1
   275  		assert.NotSame(t, p1, p3)
   276  		assert.Equal(t, 1, p3.Id)
   277  		assert.Equal(t, "Tim", p1.Name)
   278  		p4 := tls2.Get()
   279  		assert.NotSame(t, &p2, &p4)
   280  		assert.Equal(t, p2, p4)
   281  		assert.Equal(t, 2, p4.Id)
   282  		assert.Equal(t, "Andy", p4.Name)
   283  		//
   284  		p3.Name = "Tim2"
   285  		p4.Name = "Andy2"
   286  	})
   287  	task.Get()
   288  	//
   289  	p5 := tls.Get()
   290  	assert.Same(t, p1, p5)
   291  	assert.Equal(t, 1, p5.Id)
   292  	assert.Equal(t, "Tim", p5.Name)
   293  	p6 := tls2.Get()
   294  	assert.NotSame(t, &p2, &p6)
   295  	assert.Equal(t, p2, p6)
   296  	assert.Equal(t, 2, p6.Id)
   297  	assert.Equal(t, "Andy", p6.Name)
   298  }