go.mercari.io/datastore@v1.8.2/testsuite/realworld/recursivebatch/main.go (about)

     1  package recursivebatch
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"testing"
     7  
     8  	"go.mercari.io/datastore"
     9  	"go.mercari.io/datastore/boom"
    10  	"go.mercari.io/datastore/testsuite"
    11  )
    12  
    13  // TestSuite contains all the test cases that this package provides.
    14  var TestSuite = map[string]testsuite.Test{
    15  	"RealWorld_RecursiveBatch": recursiveBatch,
    16  }
    17  
    18  func init() {
    19  	testsuite.MergeTestSuite(TestSuite)
    20  }
    21  
    22  var _ datastore.PropertyLoadSaver = &depth1{}
    23  var _ datastore.PropertyLoadSaver = &depth2{}
    24  
    25  type depth1 struct {
    26  	ID         int64     `boom:"id"`
    27  	Depth2IDs  []int64   `json:"-"`
    28  	Depth2List []*depth2 `datastore:"-"`
    29  }
    30  
    31  type depth2 struct {
    32  	ID         int64     `boom:"id"`
    33  	Depth3IDs  []int64   `json:"-"`
    34  	Depth3List []*depth3 `datastore:"-"`
    35  }
    36  
    37  type depth3 struct {
    38  	ID   int64  `boom:"id"`
    39  	Name string ``
    40  }
    41  
    42  func (d *depth1) Load(ctx context.Context, ps []datastore.Property) error {
    43  	err := datastore.LoadStruct(ctx, d, ps)
    44  	if err != nil {
    45  		return err
    46  	}
    47  
    48  	bt := extractBoomBatch(ctx)
    49  
    50  	d.Depth2List = make([]*depth2, 0, len(d.Depth2IDs))
    51  	for _, depth2ID := range d.Depth2IDs {
    52  		d2 := &depth2{
    53  			ID: depth2ID,
    54  		}
    55  		bt.Get(d2, nil)
    56  		d.Depth2List = append(d.Depth2List, d2)
    57  	}
    58  
    59  	return nil
    60  }
    61  
    62  func (d *depth1) Save(ctx context.Context) ([]datastore.Property, error) {
    63  	d.Depth2IDs = make([]int64, 0, len(d.Depth2List))
    64  	for _, d2 := range d.Depth2List {
    65  		d.Depth2IDs = append(d.Depth2IDs, d2.ID)
    66  	}
    67  
    68  	return datastore.SaveStruct(ctx, d)
    69  }
    70  
    71  func (d *depth2) Load(ctx context.Context, ps []datastore.Property) error {
    72  	err := datastore.LoadStruct(ctx, d, ps)
    73  	if err != nil {
    74  		return err
    75  	}
    76  
    77  	bt := extractBoomBatch(ctx)
    78  
    79  	d.Depth3List = make([]*depth3, 0, len(d.Depth3IDs))
    80  	for _, depth3ID := range d.Depth3IDs {
    81  		d3 := &depth3{
    82  			ID: depth3ID,
    83  		}
    84  		bt.Get(d3, nil)
    85  		d.Depth3List = append(d.Depth3List, d3)
    86  	}
    87  
    88  	return nil
    89  }
    90  
    91  func (d *depth2) Save(ctx context.Context) ([]datastore.Property, error) {
    92  	d.Depth3IDs = make([]int64, 0, len(d.Depth3List))
    93  	for _, d3 := range d.Depth3List {
    94  		d.Depth3IDs = append(d.Depth3IDs, d3.ID)
    95  	}
    96  
    97  	return datastore.SaveStruct(ctx, d)
    98  }
    99  
   100  type contextBoomBatch struct{}
   101  
   102  func extractBoomBatch(ctx context.Context) *boom.Batch {
   103  	return ctx.Value(contextBoomBatch{}).(*boom.Batch)
   104  }
   105  
   106  func recursiveBatch(ctx context.Context, t *testing.T, client datastore.Client) {
   107  	defer func() {
   108  		err := client.Close()
   109  		if err != nil {
   110  			t.Fatal(err)
   111  		}
   112  	}()
   113  
   114  	bm := boom.FromClient(ctx, client)
   115  	bt := bm.Batch()
   116  	ctx = context.WithValue(ctx, contextBoomBatch{}, bt)
   117  	bm.Context = ctx
   118  
   119  	const size = 5
   120  
   121  	// make test data
   122  	for i := 1; i <= size; i++ {
   123  		d1 := &depth1{
   124  			ID: int64(i),
   125  		}
   126  		for j := 1; j <= size; j++ {
   127  			d2 := &depth2{
   128  				ID: int64(i*1000 + j),
   129  			}
   130  			for k := 1; k <= size; k++ {
   131  				d3 := &depth3{
   132  					ID:   int64(i*1000000 + j*1000 + k),
   133  					Name: fmt.Sprintf("#%d", i*1000000+j*1000+k),
   134  				}
   135  				bt.Put(d3, nil)
   136  				d2.Depth3List = append(d2.Depth3List, d3)
   137  			}
   138  			bt.Put(d2, nil)
   139  			d1.Depth2List = append(d1.Depth2List, d2)
   140  		}
   141  		bt.Put(d1, nil)
   142  	}
   143  	err := bt.Exec()
   144  	if err != nil {
   145  		t.Fatal(err)
   146  	}
   147  
   148  	// get test data
   149  	list := make([]*depth1, 0, size)
   150  	for i := 1; i <= size; i++ {
   151  		d1 := &depth1{
   152  			ID: int64(i),
   153  		}
   154  		bt.Get(d1, nil)
   155  		list = append(list, d1)
   156  	}
   157  	err = bt.Exec()
   158  	if err != nil {
   159  		t.Fatal(err)
   160  	}
   161  
   162  	if v := len(list); v != size {
   163  		t.Errorf("unexpected: %v", v)
   164  	}
   165  	for idx1, d1 := range list {
   166  		if v := d1.ID; v != int64(idx1+1) {
   167  			t.Errorf("unexpected: %v", v)
   168  		}
   169  
   170  		if v := len(d1.Depth2List); v != size {
   171  			t.Errorf("unexpected: %v", v)
   172  		}
   173  		for idx2, d2 := range d1.Depth2List {
   174  			if v := d2.ID; v != d1.ID*1000+int64(idx2+1) {
   175  				t.Errorf("unexpected: %v", v)
   176  			}
   177  
   178  			if v := len(d2.Depth3List); v != size {
   179  				t.Errorf("unexpected: %v", v)
   180  			}
   181  			for idx3, d3 := range d2.Depth3List {
   182  				if v := d3.ID; v != d2.ID*1000+int64(idx3+1) {
   183  					t.Errorf("unexpected: %v", v)
   184  					t.Errorf("unexpected: %v", d1.ID*1000000+d2.ID*1000+int64(idx3+1))
   185  					t.Errorf("unexpected: %v", d1.ID)
   186  					t.Errorf("unexpected: %v", d2.ID)
   187  				}
   188  				if v := d3.Name; v != fmt.Sprintf("#%d", d3.ID) {
   189  					t.Errorf("unexpected: %v", v)
   190  				}
   191  			}
   192  		}
   193  	}
   194  }