github.com/qiuhoude/go-web@v0.0.0-20220223060959-ab545e78f20d/algorithm/leetcode/range-sum-query-immutable_test.go (about)

     1  package leetcode
     2  
     3  import (
     4  	"github.com/bmizerany/assert"
     5  	"testing"
     6  )
     7  
     8  // LeetCode 307. Range Sum Query - Mutable
     9  // https://leetcode.com/problems/range-sum-query-mutable/description/
    10  // LeetCode 303. Range Sum Query - Immutable
    11  // https://leetcode-cn.com/problems/range-sum-query-immutable/
    12  
    13  type NumArray struct {
    14  	nums []int
    15  	tree []int
    16  }
    17  
    18  func Constructor1(nums []int) NumArray {
    19  	n := len(nums)
    20  	ret := NumArray{
    21  		nums: nums,
    22  		tree: make([]int, 4*n),
    23  	}
    24  	if n > 0 {
    25  		ret.buildTree(0, 0, n-1)
    26  	}
    27  	return ret
    28  }
    29  
    30  func (this *NumArray) query(i, l, r, ql, qr int) int {
    31  	if r < 0 {
    32  		return -1
    33  	}
    34  	if l == ql && r == qr {
    35  		return this.tree[i]
    36  	}
    37  	// 分3段
    38  	mid := l + (r-l)>>1
    39  	lci := i*2 + 1 // left child index
    40  	rci := i*2 + 2 // right child index
    41  	if qr <= mid { // 区域范围再左边
    42  		return this.query(lci, l, mid, ql, qr)
    43  	} else if ql >= mid+1 { // 区域范围全在右边
    44  		return this.query(rci, mid+1, r, ql, qr)
    45  	} else { // 两边个一半
    46  		sumL := this.query(lci, l, mid, ql, mid)
    47  		sumR := this.query(rci, mid+1, r, mid+1, qr)
    48  		return sumL + sumR
    49  	}
    50  
    51  }
    52  
    53  func (this *NumArray) Update(i int, val int) {
    54  	n := len(this.nums)
    55  	if i < 0 || i >= n {
    56  		return
    57  	}
    58  	this.nums[i] = val // 修改原数组
    59  	this.set(i, val, 0, 0, len(this.nums)-1)
    60  
    61  }
    62  
    63  func (this *NumArray) set(ni int, val int, i, li, ri int) {
    64  	if li == ri {
    65  		this.tree[i] = val
    66  		return
    67  	}
    68  	mid := li + (ri-li)>>1
    69  	lci := i*2 + 1 // left child index
    70  	rci := i*2 + 2 // right child index
    71  	if ni <= mid {
    72  		this.set(ni, val, lci, li, mid)
    73  	} else if ni > mid {
    74  		this.set(ni, val, rci, mid+1, ri)
    75  	}
    76  	this.tree[i] = this.tree[lci] + this.tree[rci]
    77  }
    78  
    79  func (this *NumArray) buildTree(i, li, ri int) {
    80  	if li == ri {
    81  		this.tree[i] = this.nums[li]
    82  		return
    83  	}
    84  	mid := li + (ri-li)>>1
    85  
    86  	lci := i*2 + 1 // left child index
    87  	rci := i*2 + 2 // right child index
    88  	this.buildTree(lci, li, mid)
    89  	this.buildTree(rci, mid+1, ri)
    90  	this.tree[i] = this.tree[lci] + this.tree[rci] // 中间位置的线段树
    91  
    92  }
    93  
    94  func (this *NumArray) SumRange(i int, j int) int {
    95  	if i > j {
    96  		return -1
    97  	}
    98  	return this.query(0, 0, len(this.nums)-1, i, j)
    99  }
   100  
   101  func TestSumRange(t *testing.T) {
   102  	nums := []int{-2, 0, 3, -5, 2, -1}
   103  	obj := Constructor1(nums)
   104  	assert.Equal(t, obj.SumRange(0, 2), 1)
   105  	assert.Equal(t, obj.SumRange(2, 5), -1)
   106  	assert.Equal(t, obj.SumRange(0, 5), -3)
   107  }
   108  
   109  /**
   110   * Your NumArray object will be instantiated and called as such:
   111   * obj := Constructor1(nums);
   112   * param_1 := obj.SumRange(i,j);
   113   */