github.com/qiuhoude/go-web@v0.0.0-20220223060959-ab545e78f20d/algorithm/datastructures/tree/segmenttree/segmentTree.go (about)

     1  package segmenttree
     2  
     3  import (
     4  	"fmt"
     5  	"strings"
     6  )
     7  
     8  type MergeFunc func(l, r interface{}) interface{}
     9  
    10  /*
    11                 [1,  6]
    12               /        \
    13        [1,  3]           [4,  6]
    14        /     \           /     \
    15     [1, 2]  [3,3]     [4, 5]   [6,6]
    16     /    \           /     \
    17  [1,1]   [2,2]     [4,4]   [5,5]
    18  
    19  */
    20  type SegmentTree struct {
    21  	data  []interface{} // 线段数存储的数据
    22  	tree  []interface{}
    23  	merge MergeFunc
    24  }
    25  
    26  func NewSegmentTree(data []interface{}, mf MergeFunc) *SegmentTree {
    27  	if data == nil || len(data) == 0 {
    28  		return nil
    29  	}
    30  	ret := &SegmentTree{
    31  		data:  data,
    32  		merge: mf,
    33  	}
    34  	// 线段数是非满二叉树, 通过等比数列求和公式可得 2^n - 1 ,n(n>=0)是层数,表示第n层前有2^n - 1个元素
    35  	// 4n就可以容纳所有线段树
    36  	size := 4 * len(data)
    37  	ret.tree = make([]interface{}, size, size) //直接申请这么多的空间
    38  	ret.buildSegmentTree(0, 0, len(data)-1)
    39  	return ret
    40  }
    41  
    42  // 构建线段树
    43  func (st *SegmentTree) buildSegmentTree(treeIndex, l, r int) {
    44  	if l == r {
    45  		st.tree[treeIndex] = st.data[l]
    46  		return
    47  	}
    48  	lTreeIndex := leftChild(treeIndex)
    49  	rTreeIndex := rightChild(treeIndex)
    50  
    51  	mid := l + (r-l)/2
    52  	st.buildSegmentTree(lTreeIndex, l, mid) // 构建左边的
    53  	st.buildSegmentTree(rTreeIndex, mid+1, r)
    54  
    55  	st.tree[treeIndex] = st.merge(st.tree[lTreeIndex], st.tree[rTreeIndex])
    56  }
    57  
    58  func (st *SegmentTree) Size() int {
    59  	if st == nil {
    60  		return 0
    61  	}
    62  	return len(st.data)
    63  }
    64  
    65  // 进行范围查询
    66  func (st *SegmentTree) Query(queryL, queryR int) interface{} {
    67  	if queryL < 0 || queryL >= len(st.data) ||
    68  		queryR < 0 || queryR >= len(st.data) || queryL > queryR {
    69  		return nil
    70  	}
    71  	return st.query(0, 0, st.Size()-1, queryL, queryR)
    72  }
    73  
    74  // 在以treeIndex为根的线段树中[l...r]的范围里,搜索区间[queryL...queryR]的值
    75  func (st *SegmentTree) query(treeIndex, l, r, queryL, queryR int) interface{} {
    76  	if l == queryL && r == queryR { //直接找到到了
    77  		return st.tree[treeIndex]
    78  	}
    79  	lTreeIndex := leftChild(treeIndex)
    80  	rTreeIndex := rightChild(treeIndex)
    81  	mid := l + (r-l)/2
    82  	if queryL >= mid+1 { // 区间全部落在右边
    83  		return st.query(rTreeIndex, mid+1, r, queryL, queryR)
    84  	} else if queryR <= mid { //区间全部落在左边
    85  		return st.query(lTreeIndex, l, mid, queryL, queryR)
    86  	} else { // 左右各占一半
    87  		leftRes := st.query(lTreeIndex, l, mid, queryL, mid)
    88  		rightRes := st.query(rTreeIndex, mid+1, r, mid+1, queryR)
    89  		return st.merge(rightRes, leftRes)
    90  	}
    91  }
    92  
    93  // 设置某个位置的值,返回false表示设置失败
    94  func (st *SegmentTree) Set(index int, e interface{}) bool {
    95  	if index < 0 || index >= st.Size() {
    96  		return false
    97  	}
    98  	st.data[index] = e
    99  	st.set(0, 0, st.Size()-1, index, e)
   100  	return true
   101  }
   102  
   103  func (st *SegmentTree) set(treeIndex, l, r, index int, e interface{}) {
   104  	if l == r { // 说明已经找到
   105  		st.tree[treeIndex] = e
   106  		return
   107  	}
   108  	lTreeIndex := leftChild(treeIndex)
   109  	rTreeIndex := rightChild(treeIndex)
   110  	mid := l + (r-l)/2
   111  	if index >= mid+1 { // 落在右边
   112  		st.set(rTreeIndex, mid+1, r, index, e)
   113  	} else { // index<mid
   114  		st.set(lTreeIndex, l, mid, index, e)
   115  	}
   116  	st.tree[treeIndex] = st.merge(st.tree[lTreeIndex], st.tree[rTreeIndex])
   117  }
   118  
   119  // 左孩子下标(和二叉堆一样)
   120  func leftChild(index int) int {
   121  	return index*2 + 1
   122  }
   123  
   124  // 右孩子下标
   125  func rightChild(index int) int {
   126  	return index*2 + 2
   127  }
   128  
   129  // 获取当前层级
   130  
   131  func (st *SegmentTree) String() string {
   132  	sb := strings.Builder{}
   133  	sb.WriteString("[")
   134  	//curDepth := 0
   135  	for i := 0; i < len(st.tree); i++ {
   136  		//d := depth(i)
   137  		//if d > curDepth {
   138  		//	curDepth = d
   139  		//	sb.WriteRune('\n')
   140  		//}
   141  		if st.tree[i] == nil {
   142  			sb.WriteString("nil")
   143  		} else {
   144  			sb.WriteString(fmt.Sprintf("%v", st.tree[i]))
   145  		}
   146  		if i != len(st.tree)-1 {
   147  			sb.WriteString(", ")
   148  		}
   149  	}
   150  	sb.WriteString("]")
   151  	return sb.String()
   152  }
   153  
   154  func depth(index int) int {
   155  	if index <= 0 {
   156  		return 0
   157  	}
   158  	i := 1 //1<<1 -1 // 公式 : i = 2^n -1
   159  	cnt := 0
   160  	for index >= i-1 {
   161  		i = i << 1
   162  		cnt++
   163  	}
   164  	// cnt 表示第n层前, index的层数是 cnt-1
   165  	return cnt - 1
   166  }