github.com/aloncn/graphics-go@v0.0.1/graphics/detect/opencv_parser.go (about)

     1  // Copyright 2011 The Graphics-Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package detect
     6  
     7  import (
     8  	"bytes"
     9  	"encoding/xml"
    10  	"errors"
    11  	"fmt"
    12  	"image"
    13  	"io"
    14  	"io/ioutil"
    15  	"strconv"
    16  	"strings"
    17  )
    18  
// xmlFeature is the decoding target for one entry of a stage's <trees> list.
// The tag paths say "grp" where OpenCV writes anonymous <_> elements, because
// ParseOpenCV rewrites <_> to <grp> before unmarshalling.
//
// Rects holds the raw text of each rectangle, which buildFeature parses as
// "x y w h weight". Tilted must be zero, because buildCascade rejects tilted
// features. Threshold, Left and Right are copied as-is into the Classifier
// that buildCascade builds for this entry.
//
// NOTE(review): every path reaches through one extra <grp> level (the tree's
// node), so presumably each tree is expected to hold a single node; with
// several nodes the scalar fields would not describe them all — TODO confirm
// against the cascades being loaded.
type xmlFeature struct {
	Rects     []string `xml:"grp>feature>rects>grp"`
	Tilted    int      `xml:"grp>feature>tilted"`
	Threshold float64  `xml:"grp>threshold"`
	Left      float64  `xml:"grp>left_val"`
	Right     float64  `xml:"grp>right_val"`
}
    26  
// xmlStages is the decoding target for one entry of the <stages> list (with
// OpenCV's anonymous <_> elements already renamed to <grp> by ParseOpenCV).
//
// Trees and Stage_threshold feed the CascadeStage that buildCascade builds.
// Parent and Next are decoded but not read by buildCascade; presumably they
// are OpenCV's stage-tree links — TODO confirm before relying on them.
type xmlStages struct {
	Trees           []xmlFeature `xml:"trees>grp"`
	Stage_threshold float64      `xml:"stage_threshold"`
	Parent          int          `xml:"parent"`
	Next            int          `xml:"next"`
}
    33  
// opencv_storage is the decoding target for the root element of an OpenCV
// XML file, as passed to xml.Unmarshal by ParseOpenCV.
type opencv_storage struct {
	// Any captures a child element of the root whatever its tag name (the
	// ",any" option), recording that name in XMLName; buildCascade reports
	// XMLName.Local as the cascade name. Type is the element's type_id
	// attribute, which buildCascade requires to be "opencv-haar-classifier".
	// Size is the raw text of <size>, which buildCascade reads as
	// "width height". Stages lists the <stages> entries.
	Any struct {
		XMLName xml.Name
		Type    string      `xml:"type_id,attr"`
		Size    string      `xml:"size"`
		Stages  []xmlStages `xml:"stages>grp"`
	} `xml:",any"`
}
    42  
    43  func buildFeature(r string) (f Feature, err error) {
    44  	var x, y, w, h int
    45  	var weight float64
    46  	_, err = fmt.Sscanf(r, "%d %d %d %d %f", &x, &y, &w, &h, &weight)
    47  	if err != nil {
    48  		return
    49  	}
    50  	f.Rect = image.Rect(x, y, x+w, y+h)
    51  	f.Weight = weight
    52  	return
    53  }
    54  
    55  func buildCascade(s *opencv_storage) (c *Cascade, name string, err error) {
    56  	if s.Any.Type != "opencv-haar-classifier" {
    57  		err = fmt.Errorf("got %s want opencv-haar-classifier", s.Any.Type)
    58  		return
    59  	}
    60  	name = s.Any.XMLName.Local
    61  
    62  	c = &Cascade{}
    63  	sizes := strings.Split(s.Any.Size, " ")
    64  	w, err := strconv.Atoi(sizes[0])
    65  	if err != nil {
    66  		return nil, "", err
    67  	}
    68  	h, err := strconv.Atoi(sizes[1])
    69  	if err != nil {
    70  		return nil, "", err
    71  	}
    72  	c.Size = image.Pt(w, h)
    73  	c.Stage = []CascadeStage{}
    74  
    75  	for _, stage := range s.Any.Stages {
    76  		cs := CascadeStage{
    77  			Classifier: []Classifier{},
    78  			Threshold:  stage.Stage_threshold,
    79  		}
    80  		for _, tree := range stage.Trees {
    81  			if tree.Tilted != 0 {
    82  				err = errors.New("Cascade does not support tilted features")
    83  				return
    84  			}
    85  
    86  			cls := Classifier{
    87  				Feature:   []Feature{},
    88  				Threshold: tree.Threshold,
    89  				Left:      tree.Left,
    90  				Right:     tree.Right,
    91  			}
    92  
    93  			for _, rect := range tree.Rects {
    94  				f, err := buildFeature(rect)
    95  				if err != nil {
    96  					return nil, "", err
    97  				}
    98  				cls.Feature = append(cls.Feature, f)
    99  			}
   100  
   101  			cs.Classifier = append(cs.Classifier, cls)
   102  		}
   103  		c.Stage = append(c.Stage, cs)
   104  	}
   105  
   106  	return
   107  }
   108  
   109  // ParseOpenCV produces a detection Cascade from an OpenCV XML file.
   110  func ParseOpenCV(r io.Reader) (cascade *Cascade, name string, err error) {
   111  	// BUG(crawshaw): tag-based parsing doesn't seem to work with <_>
   112  	buf, err := ioutil.ReadAll(r)
   113  	if err != nil {
   114  		return
   115  	}
   116  	buf = bytes.Replace(buf, []byte("<_>"), []byte("<grp>"), -1)
   117  	buf = bytes.Replace(buf, []byte("</_>"), []byte("</grp>"), -1)
   118  
   119  	s := &opencv_storage{}
   120  	err = xml.Unmarshal(buf, s)
   121  	if err != nil {
   122  		return
   123  	}
   124  	return buildCascade(s)
   125  }