github.com/insolar/vanilla@v0.0.0-20201023172447-248fdf805322/longbits/bit_builder.go (about)

     1  // Copyright 2020 Insolar Network Ltd.
     2  // All rights reserved.
     3  // This material is licensed under the Insolar License version 1.0,
     4  // available at https://github.com/insolar/assured-ledger/blob/master/LICENSE.md.
     5  
     6  package longbits
     7  
     8  import (
     9  	"math/bits"
    10  
    11  	"github.com/insolar/vanilla/throw"
    12  )
    13  
    14  type BitBuilderOrder byte
    15  
    16  const (
    17  	// Least significant bit is first - first AppendBit() appends the least significant bit
    18  	LSB BitBuilderOrder = 0
    19  	// Most significant bit is first - first AppendBit() appends the most significant bit
    20  	MSB BitBuilderOrder = 1
    21  
    22  	initLSB = 0x01
    23  	initMSB = 0x80
    24  )
    25  
    26  func NewBitBuilder(direction BitBuilderOrder, expectedByteLen int) BitBuilder {
    27  	return AppendBitBuilder(make([]byte, 0, expectedByteLen), direction)
    28  }
    29  
    30  func AppendBitBuilder(appendTo []byte, direction BitBuilderOrder) BitBuilder {
    31  	switch direction {
    32  	case LSB:
    33  		return BitBuilder{accInit: initLSB, accBit: initLSB, bytes: appendTo}
    34  	case MSB:
    35  		return BitBuilder{accInit: initMSB, accBit: initMSB, bytes: appendTo}
    36  	default:
    37  		panic("illegal value")
    38  	}
    39  }
    40  
    41  // var _ IndexedBits = &BitBuilder{} // TODO support IndexedBits
    42  
    43  // supports to be created as BitBuilder{} - it equals NewBitBuilder(LSB, 0)
    44  type BitBuilder struct {
    45  	bytes       []byte
    46  	accumulator byte
    47  	accInit     byte
    48  	accBit      byte
    49  }
    50  
    51  func (p *BitBuilder) IsZero() bool {
    52  	return p.accInit == 0
    53  }
    54  
    55  func (p *BitBuilder) BitLen() int {
    56  	return len(p.bytes)<<3 + int(p.AlignOffset())
    57  }
    58  
    59  func (p *BitBuilder) ensure() {
    60  	if p.accInit == 0 {
    61  		if len(p.bytes) != 0 {
    62  			panic("illegal state")
    63  		}
    64  		p.accInit = initLSB
    65  		p.accBit = initLSB
    66  	} else if p.accBit == 0 {
    67  		panic("illegal state")
    68  	}
    69  }
    70  
    71  func (p *BitBuilder) AppendAlignedByte(b byte) {
    72  	p.ensure()
    73  	if p.accBit != p.accInit {
    74  		panic("illegal state")
    75  	}
    76  	if p._rightShift() {
    77  		b = bits.Reverse8(b)
    78  	}
    79  	p.bytes = append(p.bytes, b)
    80  }
    81  
    82  func shiftLeft(b, n byte) byte {
    83  	return b << n
    84  }
    85  
    86  func shiftRight(b, n byte) byte {
    87  	return b >> n
    88  }
    89  
    90  // nolint:unused
    91  func (p *BitBuilder) _align(rightShift bool) uint8 {
    92  	switch {
    93  	case p.accBit == p.accInit:
    94  		return 0
    95  	case rightShift:
    96  		return uint8(bits.LeadingZeros8(p.accBit))
    97  	default:
    98  		return uint8(bits.TrailingZeros8(p.accBit))
    99  	}
   100  }
   101  
   102  // nolint:unused
   103  func (p *BitBuilder) align() (rightShift bool, ofs uint8) {
   104  	switch rightShift := p._rightShift(); {
   105  	case p.accBit == p.accInit:
   106  		return rightShift, 0
   107  	case rightShift:
   108  		return true, uint8(bits.LeadingZeros8(p.accBit))
   109  	default:
   110  		return false, uint8(bits.TrailingZeros8(p.accBit))
   111  	}
   112  }
   113  
   114  func (p *BitBuilder) _rightShift() bool {
   115  	switch {
   116  	case p.accInit == initLSB:
   117  		return false
   118  	case p.accInit == initMSB:
   119  		return true
   120  	default:
   121  		panic("illegal state")
   122  	}
   123  }
   124  
   125  func shifters(rightShift bool) (normFn, revFn func(byte, byte) byte) {
   126  	if rightShift {
   127  		return shiftRight, shiftLeft
   128  	}
   129  	return shiftLeft, shiftRight
   130  }
   131  
   132  func (p *BitBuilder) AlignOffset() uint8 {
   133  	_, ofs := p.align()
   134  	return ofs
   135  }
   136  
   137  func (p *BitBuilder) CompletedByteCount() int {
   138  	return len(p.bytes)
   139  }
   140  
   141  func (p *BitBuilder) PadWithBit(bit int) {
   142  	p.PadWith(bit != 0)
   143  }
   144  
   145  func (p *BitBuilder) PadWith(bit bool) {
   146  	if bit {
   147  		p.appendN1(-1)
   148  	}
   149  	p.appendN0(-1)
   150  }
   151  
   152  func (p *BitBuilder) AppendBit(bit int) {
   153  	p.Append(bit != 0)
   154  }
   155  
   156  func (p *BitBuilder) Append(bit bool) {
   157  	p.ensure()
   158  
   159  	if bit {
   160  		p.accumulator |= p.accBit
   161  	}
   162  
   163  	if p._rightShift() {
   164  		p.accBit >>= 1
   165  	} else {
   166  		p.accBit <<= 1
   167  	}
   168  
   169  	if p.accBit == 0 {
   170  		p.bytes = append(p.bytes, p.accumulator)
   171  		p.accumulator = 0
   172  		p.accBit = p.accInit
   173  	}
   174  }
   175  
   176  func (p *BitBuilder) AppendSubByte(value byte, bitLen uint8) {
   177  	if bitLen >= 8 {
   178  		if bitLen != 8 {
   179  			panic("illegal value")
   180  		}
   181  		p.AppendByte(value)
   182  		return
   183  	}
   184  	switch bitLen {
   185  	case 0:
   186  		return
   187  	case 1:
   188  		p.Append(value&1 != 0)
   189  		return
   190  	}
   191  
   192  	p.ensure()
   193  	rightShift, usedCount := p.align()
   194  	normFn, revFn := shifters(rightShift)
   195  	if rightShift {
   196  		value = bits.Reverse8(value)
   197  	}
   198  
   199  	value &= revFn(0xFF, 8-bitLen)
   200  
   201  	remainCount := 8 - usedCount
   202  	switch {
   203  	case usedCount == 0:
   204  		p.accBit = normFn(p.accBit, bitLen)
   205  		p.accumulator = value
   206  		return
   207  	case remainCount > bitLen:
   208  		p.accBit = normFn(p.accBit, bitLen)
   209  		p.accumulator |= normFn(value, usedCount)
   210  		return
   211  	default:
   212  		p.accumulator |= normFn(value, usedCount)
   213  		bitLen -= remainCount
   214  	}
   215  
   216  	p.bytes = append(p.bytes, p.accumulator)
   217  	p.accBit = p.accInit
   218  	if bitLen == 0 {
   219  		p.accumulator = 0
   220  		return
   221  	}
   222  	p.accBit = normFn(p.accBit, bitLen)
   223  	p.accumulator = revFn(value, remainCount)
   224  }
   225  
   226  func (p *BitBuilder) AppendNBit(bitCount int, bit int) {
   227  	p.AppendN(bitCount, bit != 0)
   228  }
   229  
   230  func (p *BitBuilder) AppendN(bitCount int, bit bool) {
   231  	p.ensure()
   232  	switch {
   233  	case bitCount == 0:
   234  	case bitCount == 1:
   235  		p.Append(bit)
   236  	case bitCount < 0:
   237  		panic("invalid bitCount value")
   238  	case bit:
   239  		p.appendN1(bitCount)
   240  	default:
   241  		p.appendN0(bitCount)
   242  	}
   243  }
   244  
   245  func (p *BitBuilder) appendN0(bitCount int) {
   246  	p.ensure()
   247  
   248  	rightShift, usedCount := p.align()
   249  	normFn, _ := shifters(rightShift)
   250  
   251  	if usedCount == 0 {
   252  		if bitCount < 0 {
   253  			return
   254  		}
   255  	} else {
   256  		switch {
   257  		case bitCount < 0:
   258  			bitCount = 0
   259  		default:
   260  			alignCount := 8 - int(usedCount)
   261  			if alignCount > bitCount {
   262  				p.accBit = normFn(p.accBit, uint8(bitCount))
   263  				return
   264  			}
   265  			bitCount -= alignCount
   266  		}
   267  		p.bytes = append(p.bytes, p.accumulator)
   268  		p.accumulator = 0
   269  		p.accBit = p.accInit
   270  		if bitCount == 0 {
   271  			return
   272  		}
   273  	}
   274  
   275  	if alignCount := uint8(bitCount) & 0x7; alignCount > 0 {
   276  		p.accBit = normFn(p.accBit, alignCount)
   277  	}
   278  	if byteCount := bitCount >> 3; byteCount > 0 {
   279  		p.bytes = append(p.bytes, make([]byte, byteCount)...)
   280  	}
   281  }
   282  
   283  func (p *BitBuilder) appendN1(bitCount int) {
   284  	p.ensure()
   285  
   286  	rightShift, usedCount := p.align()
   287  	normFn, revFn := shifters(rightShift)
   288  
   289  	if usedCount == 0 {
   290  		if bitCount < 0 {
   291  			return
   292  		}
   293  	} else {
   294  		switch {
   295  		case bitCount < 0:
   296  			bitCount = 0
   297  		default:
   298  			alignCount := 8 - int(usedCount)
   299  			if alignCount > bitCount {
   300  				filler := revFn(0xFF, uint8(alignCount-bitCount)) // make some zeros
   301  				p.accumulator |= normFn(filler, usedCount)
   302  				p.accBit = normFn(p.accBit, uint8(bitCount))
   303  				return
   304  			}
   305  			bitCount -= alignCount
   306  		}
   307  		p.accumulator |= normFn(0xFF, usedCount)
   308  		p.bytes = append(p.bytes, p.accumulator)
   309  		p.accumulator = 0
   310  		p.accBit = p.accInit
   311  		if bitCount == 0 {
   312  			return
   313  		}
   314  	}
   315  
   316  	if alignCount := uint8(bitCount) & 0x7; alignCount > 0 {
   317  		p.accBit = normFn(p.accBit, alignCount)
   318  		p.accumulator = revFn(0xFF, 8-alignCount)
   319  	}
   320  
   321  	if byteCount := bitCount >> 3; byteCount > 0 {
   322  		lowBound := len(p.bytes)
   323  		p.bytes = append(p.bytes, make([]byte, byteCount)...)
   324  		for i := len(p.bytes) - 1; i >= lowBound; i-- {
   325  			p.bytes[i] = 0xFF
   326  		}
   327  	}
   328  }
   329  
   330  func (p *BitBuilder) ToggleBit(index int) bool {
   331  	rightShift, _ := p.align()
   332  	normFn, _ := shifters(rightShift)
   333  
   334  	byteIndex, bitIndex := BitPos(index)
   335  	mask := normFn(1, bitIndex)
   336  
   337  	b := p.bytes[byteIndex] ^ mask
   338  	p.bytes[byteIndex] = b
   339  
   340  	return b & mask != 0
   341  }
   342  
   343  func (p *BitBuilder) SetBit(index, bit int) {
   344  	p.Set(index, bit != 0, false)
   345  }
   346  
   347  func (p *BitBuilder) Set(index int, bit, padding bool) {
   348  	p.ensure()
   349  	if index < 0 {
   350  		panic(throw.IllegalValue())
   351  	}
   352  
   353  	switch d := index - p.BitLen(); {
   354  	case d < 0:
   355  		rightShift, _ := p.align()
   356  		normFn, _ := shifters(rightShift)
   357  
   358  		byteIndex, bitIndex := BitPos(index)
   359  		mask := normFn(1, bitIndex)
   360  
   361  		var pb *byte
   362  		if byteIndex == len(p.bytes) {
   363  			pb = &p.accumulator
   364  		} else {
   365  			pb = &p.bytes[byteIndex]
   366  		}
   367  		if bit {
   368  			*pb |= mask
   369  		} else {
   370  			*pb &^= mask
   371  		}
   372  		return
   373  	case d > 0:
   374  		if padding {
   375  			p.appendN1(d)
   376  		} else {
   377  			p.appendN0(d)
   378  		}
   379  	}
   380  	p.Append(bit)
   381  }
   382  
   383  func (p *BitBuilder) AppendByte(b byte) {
   384  	p.ensure()
   385  
   386  	rightShift, usedCount := p.align()
   387  	normFn, revFn := shifters(rightShift)
   388  
   389  	if rightShift {
   390  		b = bits.Reverse8(b)
   391  	}
   392  
   393  	if usedCount == 0 {
   394  		p.bytes = append(p.bytes, b)
   395  		return
   396  	}
   397  	nextByte := p.accumulator | normFn(b, usedCount)
   398  	p.bytes = append(p.bytes, nextByte)
   399  
   400  	p.accumulator = revFn(b, 8-usedCount)
   401  }
   402  
   403  func (p *BitBuilder) dump() []byte { // nolint:unused
   404  	_, usedCount := p.align()
   405  
   406  	bytes := append(make([]byte, 0, cap(p.bytes)), p.bytes...)
   407  	if usedCount > 0 {
   408  		bytes = append(bytes, p.accumulator)
   409  	}
   410  	return bytes
   411  }
   412  
   413  func (p *BitBuilder) Done() (b []byte, bitLen int) {
   414  	_, usedCount := p.align()
   415  
   416  	bytes := p.bytes
   417  	p.bytes = nil
   418  	if usedCount > 0 {
   419  		bytes = append(bytes, p.accumulator)
   420  		p.accumulator = 0
   421  		p.accBit = p.accInit
   422  		return bytes, (len(p.bytes)-1)<<3 + int(usedCount)
   423  	}
   424  	return bytes, len(p.bytes) << 3
   425  }
   426  
   427  func (p *BitBuilder) TrimZeros() (skippedPrefix int, b []byte) {
   428  	sb := p.bytes
   429  
   430  	for ;skippedPrefix < len(sb) && sb[skippedPrefix] == 0; skippedPrefix++ {}
   431  
   432  	sb = sb[skippedPrefix:]
   433  
   434  	if p.accumulator != 0 {
   435  		b = make([]byte, 0, len(sb) + 1)
   436  		b = append(b, sb...)
   437  		b = append(b, p.accumulator)
   438  		return
   439  	}
   440  
   441  	if len(sb) == 0 {
   442  		return 0, nil
   443  	}
   444  
   445  	tailIndex := len(sb) - 1
   446  	for ;tailIndex > 0 && sb[tailIndex] == 0; tailIndex-- {}
   447  	b = append([]byte(nil), sb[:tailIndex + 1]...)
   448  	return
   449  }
   450  
   451  func (p *BitBuilder) DoneToBytes() []byte {
   452  	b, _ := p.Done()
   453  	return b
   454  }
   455  
   456  func (p *BitBuilder) DoneToBits() BitSlice {
   457  	b, _ := p.Done()
   458  	return b
   459  }
   460  
   461  func (p *BitBuilder) DoneToByteString() (ByteString, int) {
   462  	b, l := p.Done()
   463  	return CopyBytes(b), l
   464  }
   465  
   466  func (p *BitBuilder) Copy() *BitBuilder {
   467  	c := *p
   468  	if p.bytes != nil {
   469  		c.bytes = append(make([]byte, 0, cap(p.bytes)), p.bytes...)
   470  	}
   471  	return &c
   472  }