github.com/xg0n/routine@v0.0.0-20240119033701-c364deb94aee/thread_local_inheritable_test.go (about) 1 package routine 2 3 import ( 4 "math" 5 "sync" 6 "testing" 7 8 "github.com/stretchr/testify/assert" 9 ) 10 11 func TestInheritableThreadLocal_Index(t *testing.T) { 12 tls := NewInheritableThreadLocal[string]() 13 assert.GreaterOrEqual(t, tls.(*inheritableThreadLocal[string]).index, 0) 14 tls2 := NewInheritableThreadLocalWithInitial[string](func() string { 15 return "Hello" 16 }) 17 assert.Greater(t, tls2.(*inheritableThreadLocal[string]).index, tls.(*inheritableThreadLocal[string]).index) 18 } 19 20 func TestInheritableThreadLocal_NextIndex(t *testing.T) { 21 backup := inheritableThreadLocalIndex 22 defer func() { 23 inheritableThreadLocalIndex = backup 24 }() 25 // 26 inheritableThreadLocalIndex = math.MaxInt32 27 assert.Panics(t, func() { 28 nextInheritableThreadLocalIndex() 29 }) 30 assert.Equal(t, math.MaxInt32, int(inheritableThreadLocalIndex)) 31 } 32 33 func TestInheritableThreadLocal_Common(t *testing.T) { 34 tls := NewInheritableThreadLocal[int]() 35 tls2 := NewInheritableThreadLocal[string]() 36 tls.Remove() 37 tls2.Remove() 38 assert.Equal(t, 0, tls.Get()) 39 assert.Equal(t, "", tls2.Get()) 40 // 41 tls.Set(1) 42 tls2.Set("World") 43 assert.Equal(t, 1, tls.Get()) 44 assert.Equal(t, "World", tls2.Get()) 45 // 46 tls.Set(0) 47 tls2.Set("") 48 assert.Equal(t, 0, tls.Get()) 49 assert.Equal(t, "", tls2.Get()) 50 // 51 tls.Set(2) 52 tls2.Set("!") 53 assert.Equal(t, 2, tls.Get()) 54 assert.Equal(t, "!", tls2.Get()) 55 // 56 tls.Remove() 57 tls2.Remove() 58 assert.Equal(t, 0, tls.Get()) 59 assert.Equal(t, "", tls2.Get()) 60 // 61 tls.Set(2) 62 tls2.Set("!") 63 assert.Equal(t, 2, tls.Get()) 64 assert.Equal(t, "!", tls2.Get()) 65 wg := &sync.WaitGroup{} 66 wg.Add(100) 67 for i := 0; i < 100; i++ { 68 Go(func() { 69 assert.Equal(t, 2, tls.Get()) 70 assert.Equal(t, "!", tls2.Get()) 71 wg.Done() 72 }) 73 } 74 wg.Wait() 75 assert.Equal(t, 2, tls.Get()) 76 assert.Equal(t, "!", tls2.Get()) 77 } 78 79 func TestInheritableThreadLocal_Mixed(t *testing.T) { 80 tls := NewInheritableThreadLocal[int]() 81 tls2 := NewInheritableThreadLocalWithInitial[string](func() string { 82 return "Hello" 83 }) 84 assert.Equal(t, 0, tls.Get()) 85 assert.Equal(t, "Hello", tls2.Get()) 86 // 87 tls.Set(1) 88 tls2.Set("World") 89 assert.Equal(t, 1, tls.Get()) 90 assert.Equal(t, "World", tls2.Get()) 91 // 92 tls.Set(0) 93 tls2.Set("") 94 assert.Equal(t, 0, tls.Get()) 95 assert.Equal(t, "", tls2.Get()) 96 // 97 tls.Set(2) 98 tls2.Set("!") 99 assert.Equal(t, 2, tls.Get()) 100 assert.Equal(t, "!", tls2.Get()) 101 // 102 tls.Remove() 103 tls2.Remove() 104 assert.Equal(t, 0, tls.Get()) 105 assert.Equal(t, "Hello", tls2.Get()) 106 // 107 tls.Set(2) 108 tls2.Set("!") 109 assert.Equal(t, 2, tls.Get()) 110 assert.Equal(t, "!", tls2.Get()) 111 wg := &sync.WaitGroup{} 112 wg.Add(100) 113 for i := 0; i < 100; i++ { 114 Go(func() { 115 assert.Equal(t, 2, tls.Get()) 116 assert.Equal(t, "!", tls2.Get()) 117 wg.Done() 118 }) 119 } 120 wg.Wait() 121 assert.Equal(t, 2, tls.Get()) 122 assert.Equal(t, "!", tls2.Get()) 123 } 124 125 func TestInheritableThreadLocal_WithInitial(t *testing.T) { 126 src := &person{Id: 1, Name: "Tim"} 127 tls := NewInheritableThreadLocalWithInitial[*person](nil) 128 tls2 := NewInheritableThreadLocalWithInitial[*person](func() *person { 129 var value *person 130 return value 131 }) 132 tls3 := NewInheritableThreadLocalWithInitial[*person](func() *person { 133 return src 134 }) 135 tls4 := NewInheritableThreadLocalWithInitial[person](func() person { 136 return *src 137 }) 138 139 for i := 0; i < 100; i++ { 140 p := tls.Get() 141 assert.Nil(t, p) 142 // 143 p2 := tls2.Get() 144 assert.Nil(t, p2) 145 // 146 p3 := tls3.Get() 147 assert.Same(t, src, p3) 148 149 p4 := tls4.Get() 150 assert.NotSame(t, src, &p4) 151 assert.Equal(t, *src, p4) 152 153 wg := &sync.WaitGroup{} 154 wg.Add(1) 155 Go(func() { 156 assert.Same(t, src, tls3.Get()) 157 p5 := tls4.Get() 158 assert.NotSame(t, src, &p5) 159 assert.Equal(t, *src, p5) 160 // 161 wg.Done() 162 }) 163 wg.Wait() 164 } 165 166 tls3.Set(nil) 167 tls4.Set(person{}) 168 assert.Nil(t, tls3.Get()) 169 assert.Equal(t, person{}, tls4.Get()) 170 171 tls3.Remove() 172 tls4.Remove() 173 assert.Same(t, src, tls3.Get()) 174 p6 := tls4.Get() 175 assert.NotSame(t, src, &p6) 176 assert.Equal(t, *src, p6) 177 } 178 179 func TestInheritableThreadLocal_CrossCoroutine(t *testing.T) { 180 tls := NewInheritableThreadLocal[string]() 181 tls.Set("Hello") 182 assert.Equal(t, "Hello", tls.Get()) 183 subWait := &sync.WaitGroup{} 184 subWait.Add(2) 185 finishWait := &sync.WaitGroup{} 186 finishWait.Add(2) 187 go func() { 188 subWait.Wait() 189 assert.Equal(t, "", tls.Get()) 190 finishWait.Done() 191 }() 192 Go(func() { 193 subWait.Wait() 194 assert.Equal(t, "Hello", tls.Get()) 195 finishWait.Done() 196 }) 197 tls.Remove() //remove in parent goroutine should not affect child goroutine 198 subWait.Done() //allow sub goroutine run 199 subWait.Done() //allow sub goroutine run 200 finishWait.Wait() //wait sub goroutine done 201 finishWait.Wait() //wait sub goroutine done 202 } 203 204 func TestInheritableThreadLocal_CreateBatch(t *testing.T) { 205 const count = 128 206 tlsList := make([]ThreadLocal[int], count) 207 for i := 0; i < count; i++ { 208 value := i 209 tlsList[i] = NewInheritableThreadLocalWithInitial[int](func() int { return value }) 210 } 211 for i := 0; i < count; i++ { 212 assert.Equal(t, i, tlsList[i].Get()) 213 } 214 } 215 216 func TestInheritableThreadLocal_Copy(t *testing.T) { 217 tls := NewInheritableThreadLocalWithInitial[*person](func() *person { 218 return &person{Id: 1, Name: "Tim"} 219 }) 220 tls2 := NewInheritableThreadLocalWithInitial[person](func() person { 221 return person{Id: 2, Name: "Andy"} 222 }) 223 224 p1 := tls.Get() 225 assert.Equal(t, 1, p1.Id) 226 assert.Equal(t, "Tim", p1.Name) 227 p2 := tls2.Get() 228 assert.Equal(t, 2, p2.Id) 229 assert.Equal(t, "Andy", p2.Name) 230 // 231 task := GoWait(func(token CancelToken) { 232 p3 := tls.Get() 233 assert.Same(t, p1, p3) 234 assert.Equal(t, 1, p3.Id) 235 assert.Equal(t, "Tim", p1.Name) 236 p4 := tls2.Get() 237 assert.NotSame(t, &p2, &p4) 238 assert.Equal(t, p2, p4) 239 assert.Equal(t, 2, p4.Id) 240 assert.Equal(t, "Andy", p4.Name) 241 // 242 p3.Name = "Tim2" 243 p4.Name = "Andy2" 244 }) 245 task.Get() 246 // 247 p5 := tls.Get() 248 assert.Same(t, p1, p5) 249 assert.Equal(t, 1, p5.Id) 250 assert.Equal(t, "Tim2", p5.Name) 251 p6 := tls2.Get() 252 assert.NotSame(t, &p2, &p6) 253 assert.Equal(t, p2, p6) 254 assert.Equal(t, 2, p6.Id) 255 assert.Equal(t, "Andy", p6.Name) 256 } 257 258 func TestInheritableThreadLocal_Cloneable(t *testing.T) { 259 tls := NewInheritableThreadLocalWithInitial[*personCloneable](func() *personCloneable { 260 return &personCloneable{Id: 1, Name: "Tim"} 261 }) 262 tls2 := NewInheritableThreadLocalWithInitial[personCloneable](func() personCloneable { 263 return personCloneable{Id: 2, Name: "Andy"} 264 }) 265 266 p1 := tls.Get() 267 assert.Equal(t, 1, p1.Id) 268 assert.Equal(t, "Tim", p1.Name) 269 p2 := tls2.Get() 270 assert.Equal(t, 2, p2.Id) 271 assert.Equal(t, "Andy", p2.Name) 272 // 273 task := GoWait(func(token CancelToken) { 274 p3 := tls.Get() //p3 is clone from p1 275 assert.NotSame(t, p1, p3) 276 assert.Equal(t, 1, p3.Id) 277 assert.Equal(t, "Tim", p1.Name) 278 p4 := tls2.Get() 279 assert.NotSame(t, &p2, &p4) 280 assert.Equal(t, p2, p4) 281 assert.Equal(t, 2, p4.Id) 282 assert.Equal(t, "Andy", p4.Name) 283 // 284 p3.Name = "Tim2" 285 p4.Name = "Andy2" 286 }) 287 task.Get() 288 // 289 p5 := tls.Get() 290 assert.Same(t, p1, p5) 291 assert.Equal(t, 1, p5.Id) 292 assert.Equal(t, "Tim", p5.Name) 293 p6 := tls2.Get() 294 assert.NotSame(t, &p2, &p6) 295 assert.Equal(t, p2, p6) 296 assert.Equal(t, 2, p6.Id) 297 assert.Equal(t, "Andy", p6.Name) 298 }