github.com/egonelbre/exp@v0.0.0-20240430123955-ed1d3aa93911/smime/ber/tree.go (about)

     1  package ber
     2  
import (
	"bytes"
	"errors"
	"fmt"
	"io"
)
     7  
     8  type Tree struct {
     9  	*Token
    10  	Children []*Tree
    11  }
    12  
    13  func ParseTree(d *Decoder) (tree *Tree, err error) {
    14  	tree = &Tree{}
    15  	tree.Token, err = d.Token()
    16  	if err != nil {
    17  		return
    18  	}
    19  	if tree.Kind == Constructed {
    20  		var sub *Tree
    21  		for {
    22  			sub, err = ParseTree(d)
    23  			if err != nil {
    24  				return
    25  			}
    26  			if sub.Token.Kind == EndConstructed {
    27  				break
    28  			}
    29  
    30  			tree.Children = append(tree.Children, sub)
    31  		}
    32  	}
    33  	return
    34  }
    35  
    36  func writeTree(out *forkableWriter, tree *Tree) {
    37  	if len(tree.Children) == 0 {
    38  		h := header{tree.Class, tree.Tag, len(tree.Bytes), false, false}
    39  		marshalHeader(out, h)
    40  		out.Write(tree.Bytes)
    41  	} else {
    42  		head, body := out.fork()
    43  		last := body
    44  		for _, sub := range tree.Children {
    45  			var pre *forkableWriter
    46  			pre, last = last.fork()
    47  			writeTree(pre, sub)
    48  		}
    49  
    50  		h := header{tree.Class, tree.Tag, body.Len(), true, false}
    51  		marshalHeader(head, h)
    52  	}
    53  }
    54  
    55  func EncodeTree(tree *Tree) ([]byte, error) {
    56  	var out bytes.Buffer
    57  	f := newForkableWriter()
    58  	writeTree(f, tree)
    59  	f.writeTo(&out)
    60  	return out.Bytes(), nil
    61  }
    62  
    63  func DecodeTree(data []byte) (*Tree, error) {
    64  	d := NewDecoder(bytes.NewReader(data))
    65  	return ParseTree(d)
    66  }
    67  
    68  func (t *Tree) mustNotBeEnd() {
    69  	if t.Kind == EndConstructed {
    70  		panic("invalid tree type")
    71  	}
    72  }
    73  
    74  func (t *Tree) AsOctetString() (ret []byte, err error) {
    75  	t.mustNotBeEnd()
    76  	if t.Kind == Value {
    77  		return t.Token.AsOctetString()
    78  	}
    79  	// constructed octet string
    80  	if t.Kind == Constructed {
    81  		var buf []byte
    82  		expectClass := t.Class
    83  		expectTag := t.Tag
    84  		for _, chunk := range t.Children {
    85  			if chunk.Tag != expectTag || chunk.Class != expectClass {
    86  				return nil, fmt.Errorf("invalid constructed octet string")
    87  			}
    88  			buf, err = chunk.AsOctetString()
    89  			if err != nil {
    90  				return
    91  			}
    92  			ret = append(ret, buf...)
    93  		}
    94  	}
    95  	return
    96  }