go.mercari.io/datastore@v1.8.2/testsuite/realworld/recursivebatch/main.go (about) 1 package recursivebatch 2 3 import ( 4 "context" 5 "fmt" 6 "testing" 7 8 "go.mercari.io/datastore" 9 "go.mercari.io/datastore/boom" 10 "go.mercari.io/datastore/testsuite" 11 ) 12 13 // TestSuite contains all the test cases that this package provides. 14 var TestSuite = map[string]testsuite.Test{ 15 "RealWorld_RecursiveBatch": recursiveBatch, 16 } 17 18 func init() { 19 testsuite.MergeTestSuite(TestSuite) 20 } 21 22 var _ datastore.PropertyLoadSaver = &depth1{} 23 var _ datastore.PropertyLoadSaver = &depth2{} 24 25 type depth1 struct { 26 ID int64 `boom:"id"` 27 Depth2IDs []int64 `json:"-"` 28 Depth2List []*depth2 `datastore:"-"` 29 } 30 31 type depth2 struct { 32 ID int64 `boom:"id"` 33 Depth3IDs []int64 `json:"-"` 34 Depth3List []*depth3 `datastore:"-"` 35 } 36 37 type depth3 struct { 38 ID int64 `boom:"id"` 39 Name string `` 40 } 41 42 func (d *depth1) Load(ctx context.Context, ps []datastore.Property) error { 43 err := datastore.LoadStruct(ctx, d, ps) 44 if err != nil { 45 return err 46 } 47 48 bt := extractBoomBatch(ctx) 49 50 d.Depth2List = make([]*depth2, 0, len(d.Depth2IDs)) 51 for _, depth2ID := range d.Depth2IDs { 52 d2 := &depth2{ 53 ID: depth2ID, 54 } 55 bt.Get(d2, nil) 56 d.Depth2List = append(d.Depth2List, d2) 57 } 58 59 return nil 60 } 61 62 func (d *depth1) Save(ctx context.Context) ([]datastore.Property, error) { 63 d.Depth2IDs = make([]int64, 0, len(d.Depth2List)) 64 for _, d2 := range d.Depth2List { 65 d.Depth2IDs = append(d.Depth2IDs, d2.ID) 66 } 67 68 return datastore.SaveStruct(ctx, d) 69 } 70 71 func (d *depth2) Load(ctx context.Context, ps []datastore.Property) error { 72 err := datastore.LoadStruct(ctx, d, ps) 73 if err != nil { 74 return err 75 } 76 77 bt := extractBoomBatch(ctx) 78 79 d.Depth3List = make([]*depth3, 0, len(d.Depth3IDs)) 80 for _, depth3ID := range d.Depth3IDs { 81 d3 := &depth3{ 82 ID: depth3ID, 83 } 84 bt.Get(d3, nil) 85 d.Depth3List = append(d.Depth3List, d3) 86 } 87 88 return nil 89 } 90 91 func (d *depth2) Save(ctx context.Context) ([]datastore.Property, error) { 92 d.Depth3IDs = make([]int64, 0, len(d.Depth3List)) 93 for _, d3 := range d.Depth3List { 94 d.Depth3IDs = append(d.Depth3IDs, d3.ID) 95 } 96 97 return datastore.SaveStruct(ctx, d) 98 } 99 100 type contextBoomBatch struct{} 101 102 func extractBoomBatch(ctx context.Context) *boom.Batch { 103 return ctx.Value(contextBoomBatch{}).(*boom.Batch) 104 } 105 106 func recursiveBatch(ctx context.Context, t *testing.T, client datastore.Client) { 107 defer func() { 108 err := client.Close() 109 if err != nil { 110 t.Fatal(err) 111 } 112 }() 113 114 bm := boom.FromClient(ctx, client) 115 bt := bm.Batch() 116 ctx = context.WithValue(ctx, contextBoomBatch{}, bt) 117 bm.Context = ctx 118 119 const size = 5 120 121 // make test data 122 for i := 1; i <= size; i++ { 123 d1 := &depth1{ 124 ID: int64(i), 125 } 126 for j := 1; j <= size; j++ { 127 d2 := &depth2{ 128 ID: int64(i*1000 + j), 129 } 130 for k := 1; k <= size; k++ { 131 d3 := &depth3{ 132 ID: int64(i*1000000 + j*1000 + k), 133 Name: fmt.Sprintf("#%d", i*1000000+j*1000+k), 134 } 135 bt.Put(d3, nil) 136 d2.Depth3List = append(d2.Depth3List, d3) 137 } 138 bt.Put(d2, nil) 139 d1.Depth2List = append(d1.Depth2List, d2) 140 } 141 bt.Put(d1, nil) 142 } 143 err := bt.Exec() 144 if err != nil { 145 t.Fatal(err) 146 } 147 148 // get test data 149 list := make([]*depth1, 0, size) 150 for i := 1; i <= size; i++ { 151 d1 := &depth1{ 152 ID: int64(i), 153 } 154 bt.Get(d1, nil) 155 list = append(list, d1) 156 } 157 err = bt.Exec() 158 if err != nil { 159 t.Fatal(err) 160 } 161 162 if v := len(list); v != size { 163 t.Errorf("unexpected: %v", v) 164 } 165 for idx1, d1 := range list { 166 if v := d1.ID; v != int64(idx1+1) { 167 t.Errorf("unexpected: %v", v) 168 } 169 170 if v := len(d1.Depth2List); v != size { 171 t.Errorf("unexpected: %v", v) 172 } 173 for idx2, d2 := range d1.Depth2List { 174 if v := d2.ID; v != d1.ID*1000+int64(idx2+1) { 175 t.Errorf("unexpected: %v", v) 176 } 177 178 if v := len(d2.Depth3List); v != size { 179 t.Errorf("unexpected: %v", v) 180 } 181 for idx3, d3 := range d2.Depth3List { 182 if v := d3.ID; v != d2.ID*1000+int64(idx3+1) { 183 t.Errorf("unexpected: %v", v) 184 t.Errorf("unexpected: %v", d1.ID*1000000+d2.ID*1000+int64(idx3+1)) 185 t.Errorf("unexpected: %v", d1.ID) 186 t.Errorf("unexpected: %v", d2.ID) 187 } 188 if v := d3.Name; v != fmt.Sprintf("#%d", d3.ID) { 189 t.Errorf("unexpected: %v", v) 190 } 191 } 192 } 193 } 194 }