github.com/timandy/routine@v1.1.4-0.20240507073150-e4a3e1fe2ba5/thread_local_test.go (about) 1 package routine 2 3 import ( 4 "math" 5 "sync" 6 "testing" 7 8 "github.com/stretchr/testify/assert" 9 ) 10 11 func TestThreadLocal_Index(t *testing.T) { 12 tls := NewThreadLocal[string]() 13 assert.GreaterOrEqual(t, tls.(*threadLocal[string]).index, 0) 14 tls2 := NewThreadLocalWithInitial[string](func() string { 15 return "Hello" 16 }) 17 assert.Greater(t, tls2.(*threadLocal[string]).index, tls.(*threadLocal[string]).index) 18 } 19 20 func TestThreadLocal_NextIndex(t *testing.T) { 21 backup := threadLocalIndex 22 defer func() { 23 threadLocalIndex = backup 24 }() 25 // 26 threadLocalIndex = math.MaxInt32 27 assert.Panics(t, func() { 28 nextThreadLocalIndex() 29 }) 30 assert.Equal(t, math.MaxInt32, int(threadLocalIndex)) 31 } 32 33 func TestThreadLocal_Common(t *testing.T) { 34 tls := NewThreadLocal[int]() 35 tls2 := NewThreadLocal[string]() 36 tls.Remove() 37 tls2.Remove() 38 assert.Equal(t, 0, tls.Get()) 39 assert.Equal(t, "", tls2.Get()) 40 // 41 tls.Set(1) 42 tls2.Set("World") 43 assert.Equal(t, 1, tls.Get()) 44 assert.Equal(t, "World", tls2.Get()) 45 // 46 tls.Set(0) 47 tls2.Set("") 48 assert.Equal(t, 0, tls.Get()) 49 assert.Equal(t, "", tls2.Get()) 50 // 51 tls.Set(2) 52 tls2.Set("!") 53 assert.Equal(t, 2, tls.Get()) 54 assert.Equal(t, "!", tls2.Get()) 55 // 56 tls.Remove() 57 tls2.Remove() 58 assert.Equal(t, 0, tls.Get()) 59 assert.Equal(t, "", tls2.Get()) 60 // 61 tls.Set(2) 62 tls2.Set("!") 63 assert.Equal(t, 2, tls.Get()) 64 assert.Equal(t, "!", tls2.Get()) 65 wg := &sync.WaitGroup{} 66 wg.Add(100) 67 for i := 0; i < 100; i++ { 68 Go(func() { 69 assert.Equal(t, 0, tls.Get()) 70 assert.Equal(t, "", tls2.Get()) 71 wg.Done() 72 }) 73 } 74 wg.Wait() 75 assert.Equal(t, 2, tls.Get()) 76 assert.Equal(t, "!", tls2.Get()) 77 } 78 79 func TestThreadLocal_Mixed(t *testing.T) { 80 tls := NewThreadLocal[int]() 81 tls2 := NewThreadLocalWithInitial[string](func() string { 82 return "Hello" 83 }) 84 assert.Equal(t, 0, tls.Get()) 85 assert.Equal(t, "Hello", tls2.Get()) 86 // 87 tls.Set(1) 88 tls2.Set("World") 89 assert.Equal(t, 1, tls.Get()) 90 assert.Equal(t, "World", tls2.Get()) 91 // 92 tls.Set(0) 93 tls2.Set("") 94 assert.Equal(t, 0, tls.Get()) 95 assert.Equal(t, "", tls2.Get()) 96 // 97 tls.Set(2) 98 tls2.Set("!") 99 assert.Equal(t, 2, tls.Get()) 100 assert.Equal(t, "!", tls2.Get()) 101 // 102 tls.Remove() 103 tls2.Remove() 104 assert.Equal(t, 0, tls.Get()) 105 assert.Equal(t, "Hello", tls2.Get()) 106 // 107 tls.Set(2) 108 tls2.Set("!") 109 assert.Equal(t, 2, tls.Get()) 110 assert.Equal(t, "!", tls2.Get()) 111 wg := &sync.WaitGroup{} 112 wg.Add(100) 113 for i := 0; i < 100; i++ { 114 Go(func() { 115 assert.Equal(t, 0, tls.Get()) 116 assert.Equal(t, "Hello", tls2.Get()) 117 wg.Done() 118 }) 119 } 120 wg.Wait() 121 assert.Equal(t, 2, tls.Get()) 122 assert.Equal(t, "!", tls2.Get()) 123 } 124 125 func TestThreadLocal_WithInitial(t *testing.T) { 126 src := &person{Id: 1, Name: "Tim"} 127 tls := NewThreadLocalWithInitial[*person](nil) 128 tls2 := NewThreadLocalWithInitial[*person](func() *person { 129 var value *person 130 return value 131 }) 132 tls3 := NewThreadLocalWithInitial[*person](func() *person { 133 return src 134 }) 135 tls4 := NewThreadLocalWithInitial[person](func() person { 136 return *src 137 }) 138 139 for i := 0; i < 100; i++ { 140 p := tls.Get() 141 assert.Nil(t, p) 142 // 143 p2 := tls2.Get() 144 assert.Nil(t, p2) 145 // 146 p3 := tls3.Get() 147 assert.Same(t, src, p3) 148 149 p4 := tls4.Get() 150 assert.NotSame(t, src, &p4) 151 assert.Equal(t, *src, p4) 152 153 wg := &sync.WaitGroup{} 154 wg.Add(1) 155 Go(func() { 156 assert.Same(t, src, tls3.Get()) 157 p5 := tls4.Get() 158 assert.NotSame(t, src, &p5) 159 assert.Equal(t, *src, p5) 160 // 161 wg.Done() 162 }) 163 wg.Wait() 164 } 165 166 tls3.Set(nil) 167 tls4.Set(person{}) 168 assert.Nil(t, tls3.Get()) 169 assert.Equal(t, person{}, tls4.Get()) 170 171 tls3.Remove() 172 tls4.Remove() 173 assert.Same(t, src, tls3.Get()) 174 p6 := tls4.Get() 175 assert.NotSame(t, src, &p6) 176 assert.Equal(t, *src, p6) 177 } 178 179 func TestThreadLocal_CrossCoroutine(t *testing.T) { 180 tls := NewThreadLocal[string]() 181 tls.Set("Hello") 182 assert.Equal(t, "Hello", tls.Get()) 183 subWait := &sync.WaitGroup{} 184 subWait.Add(2) 185 finishWait := &sync.WaitGroup{} 186 finishWait.Add(2) 187 go func() { 188 subWait.Wait() 189 assert.Equal(t, "", tls.Get()) 190 finishWait.Done() 191 }() 192 Go(func() { 193 subWait.Wait() 194 assert.Equal(t, "", tls.Get()) 195 finishWait.Done() 196 }) 197 tls.Remove() //remove in parent goroutine should not affect child goroutine 198 subWait.Done() //allow sub goroutine run 199 subWait.Done() //allow sub goroutine run 200 finishWait.Wait() //wait sub goroutine done 201 finishWait.Wait() //wait sub goroutine done 202 } 203 204 func TestThreadLocal_CreateBatch(t *testing.T) { 205 const count = 128 206 tlsList := make([]ThreadLocal[int], count) 207 for i := 0; i < count; i++ { 208 value := i 209 tlsList[i] = NewThreadLocalWithInitial[int](func() int { return value }) 210 } 211 for i := 0; i < count; i++ { 212 assert.Equal(t, i, tlsList[i].Get()) 213 } 214 } 215 216 type person struct { 217 Id int 218 Name string 219 }