github.com/xg0n/routine@v0.0.0-20240119033701-c364deb94aee/thread_local_test.go (about)

     1  package routine
     2  
     3  import (
     4  	"math"
     5  	"sync"
     6  	"testing"
     7  
     8  	"github.com/stretchr/testify/assert"
     9  )
    10  
    11  func TestThreadLocal_Index(t *testing.T) {
    12  	tls := NewThreadLocal[string]()
    13  	assert.GreaterOrEqual(t, tls.(*threadLocal[string]).index, 0)
    14  	tls2 := NewThreadLocalWithInitial[string](func() string {
    15  		return "Hello"
    16  	})
    17  	assert.Greater(t, tls2.(*threadLocal[string]).index, tls.(*threadLocal[string]).index)
    18  }
    19  
    20  func TestThreadLocal_NextIndex(t *testing.T) {
    21  	backup := threadLocalIndex
    22  	defer func() {
    23  		threadLocalIndex = backup
    24  	}()
    25  	//
    26  	threadLocalIndex = math.MaxInt32
    27  	assert.Panics(t, func() {
    28  		nextThreadLocalIndex()
    29  	})
    30  	assert.Equal(t, math.MaxInt32, int(threadLocalIndex))
    31  }
    32  
    33  func TestThreadLocal_Common(t *testing.T) {
    34  	tls := NewThreadLocal[int]()
    35  	tls2 := NewThreadLocal[string]()
    36  	tls.Remove()
    37  	tls2.Remove()
    38  	assert.Equal(t, 0, tls.Get())
    39  	assert.Equal(t, "", tls2.Get())
    40  	//
    41  	tls.Set(1)
    42  	tls2.Set("World")
    43  	assert.Equal(t, 1, tls.Get())
    44  	assert.Equal(t, "World", tls2.Get())
    45  	//
    46  	tls.Set(0)
    47  	tls2.Set("")
    48  	assert.Equal(t, 0, tls.Get())
    49  	assert.Equal(t, "", tls2.Get())
    50  	//
    51  	tls.Set(2)
    52  	tls2.Set("!")
    53  	assert.Equal(t, 2, tls.Get())
    54  	assert.Equal(t, "!", tls2.Get())
    55  	//
    56  	tls.Remove()
    57  	tls2.Remove()
    58  	assert.Equal(t, 0, tls.Get())
    59  	assert.Equal(t, "", tls2.Get())
    60  	//
    61  	tls.Set(2)
    62  	tls2.Set("!")
    63  	assert.Equal(t, 2, tls.Get())
    64  	assert.Equal(t, "!", tls2.Get())
    65  	wg := &sync.WaitGroup{}
    66  	wg.Add(100)
    67  	for i := 0; i < 100; i++ {
    68  		Go(func() {
    69  			assert.Equal(t, 0, tls.Get())
    70  			assert.Equal(t, "", tls2.Get())
    71  			wg.Done()
    72  		})
    73  	}
    74  	wg.Wait()
    75  	assert.Equal(t, 2, tls.Get())
    76  	assert.Equal(t, "!", tls2.Get())
    77  }
    78  
    79  func TestThreadLocal_Mixed(t *testing.T) {
    80  	tls := NewThreadLocal[int]()
    81  	tls2 := NewThreadLocalWithInitial[string](func() string {
    82  		return "Hello"
    83  	})
    84  	assert.Equal(t, 0, tls.Get())
    85  	assert.Equal(t, "Hello", tls2.Get())
    86  	//
    87  	tls.Set(1)
    88  	tls2.Set("World")
    89  	assert.Equal(t, 1, tls.Get())
    90  	assert.Equal(t, "World", tls2.Get())
    91  	//
    92  	tls.Set(0)
    93  	tls2.Set("")
    94  	assert.Equal(t, 0, tls.Get())
    95  	assert.Equal(t, "", tls2.Get())
    96  	//
    97  	tls.Set(2)
    98  	tls2.Set("!")
    99  	assert.Equal(t, 2, tls.Get())
   100  	assert.Equal(t, "!", tls2.Get())
   101  	//
   102  	tls.Remove()
   103  	tls2.Remove()
   104  	assert.Equal(t, 0, tls.Get())
   105  	assert.Equal(t, "Hello", tls2.Get())
   106  	//
   107  	tls.Set(2)
   108  	tls2.Set("!")
   109  	assert.Equal(t, 2, tls.Get())
   110  	assert.Equal(t, "!", tls2.Get())
   111  	wg := &sync.WaitGroup{}
   112  	wg.Add(100)
   113  	for i := 0; i < 100; i++ {
   114  		Go(func() {
   115  			assert.Equal(t, 0, tls.Get())
   116  			assert.Equal(t, "Hello", tls2.Get())
   117  			wg.Done()
   118  		})
   119  	}
   120  	wg.Wait()
   121  	assert.Equal(t, 2, tls.Get())
   122  	assert.Equal(t, "!", tls2.Get())
   123  }
   124  
   125  func TestThreadLocal_WithInitial(t *testing.T) {
   126  	src := &person{Id: 1, Name: "Tim"}
   127  	tls := NewThreadLocalWithInitial[*person](nil)
   128  	tls2 := NewThreadLocalWithInitial[*person](func() *person {
   129  		var value *person
   130  		return value
   131  	})
   132  	tls3 := NewThreadLocalWithInitial[*person](func() *person {
   133  		return src
   134  	})
   135  	tls4 := NewThreadLocalWithInitial[person](func() person {
   136  		return *src
   137  	})
   138  
   139  	for i := 0; i < 100; i++ {
   140  		p := tls.Get()
   141  		assert.Nil(t, p)
   142  		//
   143  		p2 := tls2.Get()
   144  		assert.Nil(t, p2)
   145  		//
   146  		p3 := tls3.Get()
   147  		assert.Same(t, src, p3)
   148  
   149  		p4 := tls4.Get()
   150  		assert.NotSame(t, src, &p4)
   151  		assert.Equal(t, *src, p4)
   152  
   153  		wg := &sync.WaitGroup{}
   154  		wg.Add(1)
   155  		Go(func() {
   156  			assert.Same(t, src, tls3.Get())
   157  			p5 := tls4.Get()
   158  			assert.NotSame(t, src, &p5)
   159  			assert.Equal(t, *src, p5)
   160  			//
   161  			wg.Done()
   162  		})
   163  		wg.Wait()
   164  	}
   165  
   166  	tls3.Set(nil)
   167  	tls4.Set(person{})
   168  	assert.Nil(t, tls3.Get())
   169  	assert.Equal(t, person{}, tls4.Get())
   170  
   171  	tls3.Remove()
   172  	tls4.Remove()
   173  	assert.Same(t, src, tls3.Get())
   174  	p6 := tls4.Get()
   175  	assert.NotSame(t, src, &p6)
   176  	assert.Equal(t, *src, p6)
   177  }
   178  
   179  func TestThreadLocal_CrossCoroutine(t *testing.T) {
   180  	tls := NewThreadLocal[string]()
   181  	tls.Set("Hello")
   182  	assert.Equal(t, "Hello", tls.Get())
   183  	subWait := &sync.WaitGroup{}
   184  	subWait.Add(2)
   185  	finishWait := &sync.WaitGroup{}
   186  	finishWait.Add(2)
   187  	go func() {
   188  		subWait.Wait()
   189  		assert.Equal(t, "", tls.Get())
   190  		finishWait.Done()
   191  	}()
   192  	Go(func() {
   193  		subWait.Wait()
   194  		assert.Equal(t, "", tls.Get())
   195  		finishWait.Done()
   196  	})
   197  	tls.Remove()      //remove in parent goroutine should not affect child goroutine
   198  	subWait.Done()    //allow sub goroutine run
   199  	subWait.Done()    //allow sub goroutine run
   200  	finishWait.Wait() //wait sub goroutine done
   201  	finishWait.Wait() //wait sub goroutine done
   202  }
   203  
   204  func TestThreadLocal_CreateBatch(t *testing.T) {
   205  	const count = 128
   206  	tlsList := make([]ThreadLocal[int], count)
   207  	for i := 0; i < count; i++ {
   208  		value := i
   209  		tlsList[i] = NewThreadLocalWithInitial[int](func() int { return value })
   210  	}
   211  	for i := 0; i < count; i++ {
   212  		assert.Equal(t, i, tlsList[i].Get())
   213  	}
   214  }
   215  
   216  type person struct {
   217  	Id   int
   218  	Name string
   219  }