go.mercari.io/datastore@v1.8.2/batch.go (about)

     1  package datastore
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  )
     7  
     8  // Batch can queue operations on Datastore and process them in batch.
     9  // Batch does nothing until you call Exec().
    10  // This helps to reduce the number of RPCs.
    11  type Batch struct {
    12  	Client Client
    13  
    14  	m      sync.Mutex
    15  	put    batchPut
    16  	get    batchGet
    17  	delete batchDelete
    18  }
    19  
    20  // BatchPutHandler represents Entity's individual callback when batching Put processing.
    21  type BatchPutHandler func(key Key, err error) error
    22  
    23  // BatchErrHandler represents Entity's individual callback when batching non-Put processing.
    24  type BatchErrHandler func(err error) error
    25  
    26  type batchPut struct {
    27  	m    sync.Mutex
    28  	keys []Key
    29  	srcs []interface{}
    30  	hs   []BatchPutHandler
    31  }
    32  
    33  type batchGet struct {
    34  	m    sync.Mutex
    35  	keys []Key
    36  	dsts []interface{}
    37  	hs   []BatchErrHandler
    38  }
    39  
    40  type batchDelete struct {
    41  	m    sync.Mutex
    42  	keys []Key
    43  	hs   []BatchErrHandler
    44  }
    45  
    46  // Put Entity operation into the queue.
    47  // This operation doesn't Put to Datastore immediately.
    48  // If a h is provided, it passes the processing result to the handler, and treats the return value as the value of the result of Putting.
    49  func (b *Batch) Put(key Key, src interface{}, h BatchPutHandler) {
    50  	b.put.Put(key, src, h)
    51  }
    52  
    53  // Get Entity operation into the queue.
    54  func (b *Batch) Get(key Key, dst interface{}, h BatchErrHandler) {
    55  	b.get.Get(key, dst, h)
    56  }
    57  
    58  // Delete Entity operation into the queue.
    59  func (b *Batch) Delete(key Key, h BatchErrHandler) {
    60  	b.delete.Delete(key, h)
    61  }
    62  
    63  // Exec will perform all the processing that was queued.
    64  // This process is done recursively until the queue is empty.
    65  // The return value may be MultiError, but the order of contents is not guaranteed.
    66  func (b *Batch) Exec(ctx context.Context) error {
    67  	// batch#Exec でロックを取る理由
    68  	// 次のようなシチュエーションで問題になる… 可能性がある
    69  	//
    70  	// 同一 *Batch に対して並列に動くジョブがあるとする。
    71  	// ジョブAがGet+error handlerを登録する
    72  	// ジョブBがGet+error handlerを登録する
    73  	// ジョブAがExecする 上記2つの操作が実行される 処理には少し時間がかかる
    74  	// ジョブBがExecする キューには何もないので高速に終了する ジョブAのExecは終わっていない
    75  	// ジョブBのGet+error handlerはまだ発火していないがジョブBはエラー無しとして処理を終了する
    76  	//
    77  	// 解決策は2種類ある
    78  	//   1. ここで行われている実装のように、ジョブがExecしている時は別ジョブのExecを待たせる
    79  	//   2. 呼び出し側でerror handlerが終わったことを sync.WaitGroup などを使って確定させる
    80  	//
    81  	// ここでは、 "Execしたら処理は全て終わっている" というモデルを維持するため 解決策1 を採用する
    82  	// 弊害として、Execがエラーを返さなかったからといってジョブが成功したとは限らなくなるということである
    83  	// 対策として、error handlerを使ったハンドリングを適切にやるか、同一の *Batch を使わない方法がある
    84  
    85  	b.m.Lock()
    86  	defer b.m.Unlock()
    87  
    88  	return b.exec(ctx)
    89  }
    90  
    91  func (b *Batch) exec(ctx context.Context) error {
    92  	var wg sync.WaitGroup
    93  	var errors []error
    94  	var m sync.Mutex
    95  	wg.Add(3)
    96  
    97  	go func() {
    98  		defer wg.Done()
    99  		errs := b.put.Exec(ctx, b.Client)
   100  		if len(errs) != 0 {
   101  			m.Lock()
   102  			errors = append(errors, errs...)
   103  			m.Unlock()
   104  		}
   105  	}()
   106  	go func() {
   107  		defer wg.Done()
   108  		errs := b.get.Exec(ctx, b.Client)
   109  		if len(errs) != 0 {
   110  			m.Lock()
   111  			errors = append(errors, errs...)
   112  			m.Unlock()
   113  		}
   114  	}()
   115  	go func() {
   116  		defer wg.Done()
   117  		errs := b.delete.Exec(ctx, b.Client)
   118  		if len(errs) != 0 {
   119  			m.Lock()
   120  			errors = append(errors, errs...)
   121  			m.Unlock()
   122  		}
   123  	}()
   124  
   125  	wg.Wait()
   126  
   127  	if len(errors) != 0 {
   128  		return MultiError(errors)
   129  	}
   130  
   131  	// Batch操作した後PropertyLoadSaverなどで追加のBatch操作が積まれたらそれがなくなるまで処理する
   132  	if len(b.put.keys) != 0 || len(b.get.keys) != 0 || len(b.delete.keys) != 0 {
   133  		return b.exec(ctx)
   134  	}
   135  
   136  	return nil
   137  }
   138  
   139  func (b *batchPut) Put(key Key, src interface{}, h BatchPutHandler) {
   140  	b.m.Lock()
   141  	defer b.m.Unlock()
   142  
   143  	b.keys = append(b.keys, key)
   144  	b.srcs = append(b.srcs, src)
   145  	b.hs = append(b.hs, h)
   146  }
   147  
   148  func (b *batchPut) Exec(ctx context.Context, client Client) []error {
   149  	if len(b.keys) == 0 {
   150  		return nil
   151  	}
   152  
   153  	b.m.Lock()
   154  	keys := b.keys
   155  	srcs := b.srcs
   156  	hs := b.hs
   157  	b.keys = nil
   158  	b.srcs = nil
   159  	b.hs = nil
   160  	b.m.Unlock()
   161  
   162  	newKeys, err := client.PutMulti(ctx, keys, srcs)
   163  
   164  	if merr, ok := err.(MultiError); ok {
   165  		trimmedError := make([]error, 0, len(merr))
   166  		for idx, err := range merr {
   167  			h := hs[idx]
   168  			if h != nil {
   169  				err = h(newKeys[idx], err)
   170  			}
   171  			if err != nil {
   172  				trimmedError = append(trimmedError, err)
   173  			}
   174  		}
   175  		return trimmedError
   176  	} else if err != nil {
   177  		for _, h := range hs {
   178  			if h != nil {
   179  				h(nil, err)
   180  			}
   181  		}
   182  		return []error{err}
   183  	}
   184  
   185  	errs := make([]error, 0, len(newKeys))
   186  	for idx, newKey := range newKeys {
   187  		h := hs[idx]
   188  		if h != nil {
   189  			err := h(newKey, nil)
   190  			if err != nil {
   191  				errs = append(errs, err)
   192  			}
   193  		}
   194  	}
   195  
   196  	if len(errs) != 0 {
   197  		return errs
   198  	}
   199  
   200  	return nil
   201  }
   202  
   203  func (b *batchGet) Get(key Key, dst interface{}, h BatchErrHandler) {
   204  	b.m.Lock()
   205  	defer b.m.Unlock()
   206  
   207  	b.keys = append(b.keys, key)
   208  	b.dsts = append(b.dsts, dst)
   209  	b.hs = append(b.hs, h)
   210  }
   211  
   212  func (b *batchGet) Exec(ctx context.Context, client Client) []error {
   213  	if len(b.keys) == 0 {
   214  		return nil
   215  	}
   216  
   217  	b.m.Lock()
   218  	keys := b.keys
   219  	dsts := b.dsts
   220  	hs := b.hs
   221  	b.keys = nil
   222  	b.dsts = nil
   223  	b.hs = nil
   224  	b.m.Unlock()
   225  
   226  	err := client.GetMulti(ctx, keys, dsts)
   227  
   228  	if merr, ok := err.(MultiError); ok {
   229  		trimmedError := make([]error, 0, len(merr))
   230  		for idx, err := range merr {
   231  			h := hs[idx]
   232  			if h != nil {
   233  				err = h(err)
   234  			}
   235  			if err != nil {
   236  				trimmedError = append(trimmedError, err)
   237  			}
   238  		}
   239  		return trimmedError
   240  	} else if err != nil {
   241  		for _, h := range hs {
   242  			if h != nil {
   243  				h(err)
   244  			}
   245  		}
   246  		return []error{err}
   247  	}
   248  
   249  	errs := make([]error, 0, len(hs))
   250  	for _, h := range hs {
   251  		if h != nil {
   252  			err := h(nil)
   253  			if err != nil {
   254  				errs = append(errs, err)
   255  			}
   256  		}
   257  	}
   258  
   259  	if len(errs) != 0 {
   260  		return errs
   261  	}
   262  
   263  	return nil
   264  }
   265  
   266  func (b *batchDelete) Delete(key Key, h BatchErrHandler) {
   267  	b.m.Lock()
   268  	defer b.m.Unlock()
   269  
   270  	b.keys = append(b.keys, key)
   271  	b.hs = append(b.hs, h)
   272  }
   273  
   274  func (b *batchDelete) Exec(ctx context.Context, client Client) []error {
   275  	if len(b.keys) == 0 {
   276  		return nil
   277  	}
   278  
   279  	b.m.Lock()
   280  	keys := b.keys
   281  	hs := b.hs
   282  	b.keys = nil
   283  	b.hs = nil
   284  	b.m.Unlock()
   285  
   286  	err := client.DeleteMulti(ctx, keys)
   287  
   288  	if merr, ok := err.(MultiError); ok {
   289  		trimmedError := make([]error, 0, len(merr))
   290  		for idx, err := range merr {
   291  			h := hs[idx]
   292  			if h != nil {
   293  				err = h(err)
   294  			}
   295  			if err != nil {
   296  				trimmedError = append(trimmedError, err)
   297  			}
   298  		}
   299  		return trimmedError
   300  	} else if err != nil {
   301  		for _, h := range hs {
   302  			if h != nil {
   303  				h(err)
   304  			}
   305  		}
   306  		return []error{err}
   307  	}
   308  
   309  	errs := make([]error, 0, len(hs))
   310  	for _, h := range hs {
   311  		if h != nil {
   312  			err := h(nil)
   313  			if err != nil {
   314  				errs = append(errs, err)
   315  			}
   316  		}
   317  	}
   318  
   319  	if len(errs) != 0 {
   320  		return errs
   321  	}
   322  
   323  	return nil
   324  }