// MergeFunc combines the aggregate of a left interval (l) with the aggregate
// of the adjacent right interval (r) and returns the aggregate of their union.
// SegmentTree always passes the operands in left-to-right order, so the
// function does not have to be commutative.
type MergeFunc func(l, r interface{}) interface{}

/*
SegmentTree is an array-backed segment tree over a slice of arbitrary values.

Example layout for six elements (intervals are shown 1-based):

	               [1, 6]
	           /            \
	       [1, 3]          [4, 6]
	       /    \          /    \
	   [1, 2]  [3,3]   [4, 5]  [6,6]
	   /   \           /   \
	[1,1] [2,2]     [4,4] [5,5]
*/
type SegmentTree struct {
	data  []interface{} // the underlying elements (leaf values)
	tree  []interface{} // heap-style node storage; unused slots stay nil
	merge MergeFunc     // combines the aggregates of two adjacent intervals
}

// NewSegmentTree builds a segment tree over data using mf to combine
// intervals. It returns nil when data is empty.
//
// data is retained, not copied: Set writes through to the caller's slice.
// mf is invoked as soon as data holds more than one element, so it must be
// non-nil in that case.
func NewSegmentTree(data []interface{}, mf MergeFunc) *SegmentTree {
	if len(data) == 0 {
		return nil
	}
	st := &SegmentTree{
		data:  data,
		merge: mf,
	}
	// A segment tree is not necessarily a full binary tree, but 4*len(data)
	// slots are always enough to hold every node in heap layout.
	st.tree = make([]interface{}, 4*len(data))
	st.buildSegmentTree(0, 0, len(data)-1)
	return st
}

// buildSegmentTree fills the subtree rooted at treeIndex, which represents
// the closed interval data[l...r].
func (st *SegmentTree) buildSegmentTree(treeIndex, l, r int) {
	if l == r {
		st.tree[treeIndex] = st.data[l]
		return
	}
	lTreeIndex := leftChild(treeIndex)
	rTreeIndex := rightChild(treeIndex)

	mid := l + (r-l)/2
	st.buildSegmentTree(lTreeIndex, l, mid)
	st.buildSegmentTree(rTreeIndex, mid+1, r)

	st.tree[treeIndex] = st.merge(st.tree[lTreeIndex], st.tree[rTreeIndex])
}

// Size returns the number of elements in the tree; a nil tree has size 0.
func (st *SegmentTree) Size() int {
	if st == nil {
		return 0
	}
	return len(st.data)
}

// Query returns the merged value of the closed interval [queryL, queryR].
// It returns nil when either bound is out of range, when queryL > queryR,
// or when the receiver is nil.
func (st *SegmentTree) Query(queryL, queryR int) interface{} {
	n := st.Size() // nil-safe, unlike reading st.data directly
	if queryL < 0 || queryL >= n ||
		queryR < 0 || queryR >= n || queryL > queryR {
		return nil
	}
	return st.query(0, 0, n-1, queryL, queryR)
}

// query searches the subtree rooted at treeIndex, which covers [l...r], for
// the merged value of [queryL...queryR]. The caller guarantees
// l <= queryL <= queryR <= r.
func (st *SegmentTree) query(treeIndex, l, r, queryL, queryR int) interface{} {
	if l == queryL && r == queryR { // exact match: the node holds the answer
		return st.tree[treeIndex]
	}
	lTreeIndex := leftChild(treeIndex)
	rTreeIndex := rightChild(treeIndex)
	mid := l + (r-l)/2

	switch {
	case queryL > mid: // the range lies entirely in the right child
		return st.query(rTreeIndex, mid+1, r, queryL, queryR)
	case queryR <= mid: // the range lies entirely in the left child
		return st.query(lTreeIndex, l, mid, queryL, queryR)
	}

	// The range straddles mid: answer each half and combine them in
	// left-to-right order, the same order used when building the tree.
	leftRes := st.query(lTreeIndex, l, mid, queryL, mid)
	rightRes := st.query(rTreeIndex, mid+1, r, mid+1, queryR)
	return st.merge(leftRes, rightRes)
}

// Set replaces the element at index with e and refreshes every aggregate
// that covers it. It returns false when index is out of range (which
// includes any call on a nil tree).
func (st *SegmentTree) Set(index int, e interface{}) bool {
	if index < 0 || index >= st.Size() {
		return false
	}
	st.data[index] = e
	st.set(0, 0, st.Size()-1, index, e)
	return true
}

// set stores e in the leaf for index inside the subtree rooted at treeIndex
// (covering [l...r]) and recomputes the aggregates on the way back up.
func (st *SegmentTree) set(treeIndex, l, r, index int, e interface{}) {
	if l == r { // reached the leaf for index
		st.tree[treeIndex] = e
		return
	}
	lTreeIndex := leftChild(treeIndex)
	rTreeIndex := rightChild(treeIndex)
	mid := l + (r-l)/2
	if index > mid { // index falls in the right half
		st.set(rTreeIndex, mid+1, r, index, e)
	} else { // index <= mid: left half
		st.set(lTreeIndex, l, mid, index, e)
	}
	st.tree[treeIndex] = st.merge(st.tree[lTreeIndex], st.tree[rTreeIndex])
}

// leftChild returns the index of the left child (same layout as a binary heap).
func leftChild(index int) int {
	return index*2 + 1
}

// rightChild returns the index of the right child.
func rightChild(index int) int {
	return index*2 + 2
}

// String renders the raw node array as "[v0, v1, ...]", printing "nil" for
// unused slots. A nil receiver renders as "<nil>".
func (st *SegmentTree) String() string {
	if st == nil {
		return "<nil>"
	}
	var sb strings.Builder
	sb.WriteString("[")
	for i, v := range st.tree {
		if i > 0 {
			sb.WriteString(", ")
		}
		if v == nil {
			sb.WriteString("nil")
		} else {
			fmt.Fprintf(&sb, "%v", v)
		}
	}
	sb.WriteString("]")
	return sb.String()
}

// depth returns the zero-based level of the node stored at index in the
// heap layout (the root is level 0). Non-positive indexes yield 0.
// It is not called by the other code in this file.
func depth(index int) int {
	if index <= 0 {
		return 0
	}
	level := 0
	// firstOfNext is the index of the first node on level+1, i.e. 2^(level+1)-1.
	for firstOfNext := 1; index >= firstOfNext; firstOfNext = firstOfNext*2 + 1 {
		level++
	}
	return level
}