gorgonia.org/gorgonia@v0.9.17/op_yolo.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"fmt"
     5  	"hash"
     6  	"image"
     7  	"math"
     8  
     9  	"github.com/chewxy/hm"
    10  	"github.com/chewxy/math32"
    11  	"github.com/pkg/errors"
    12  	"gorgonia.org/tensor"
    13  )
    14  
    15  type yoloOp struct {
    16  	anchors     []float32
    17  	masks       []int
    18  	ignoreTresh float32
    19  	dimensions  int
    20  	numClasses  int
    21  	trainMode   bool
    22  }
    23  
    24  func newYoloOp(anchors []float32, masks []int, netSize, numClasses int, ignoreTresh float32, trainMode bool) *yoloOp {
    25  	yoloOp := &yoloOp{
    26  		anchors:     anchors,
    27  		dimensions:  netSize,
    28  		numClasses:  numClasses,
    29  		ignoreTresh: ignoreTresh,
    30  		masks:       masks,
    31  		trainMode:   trainMode,
    32  	}
    33  	return yoloOp
    34  }
    35  
    36  // YOLOv3 https://arxiv.org/abs/1804.02767
    37  func YOLOv3(input *Node, anchors []float32, masks []int, netSize, numClasses int, ignoreTresh float32, targets ...*Node) (*Node, error) {
    38  	if len(targets) > 0 {
    39  		inputSlice, err := Slice(input, S(0), nil, nil, nil)
    40  		if err != nil {
    41  			return nil, errors.Wrap(err, "Can't prepare YOLOv3 node for training mode due Slice() on input node error")
    42  		}
    43  		targetsSlice, err := Slice(targets[0], S(0), nil, nil, nil)
    44  		if err != nil {
    45  			return nil, errors.Wrap(err, "Can't prepare YOLOv3 node for training mode due Slice() on first node in target nodes slice error")
    46  		}
    47  		inputTargetConcat, err := Concat(0, inputSlice, targetsSlice)
    48  		if err != nil {
    49  			return nil, errors.Wrap(err, "Can't prepare YOLOv3 node for training mode due Concat() error")
    50  		}
    51  		concatShp := inputTargetConcat.Shape()
    52  		inputTargetConcat, err = Reshape(inputTargetConcat, []int{1, concatShp[0], concatShp[1], concatShp[2]})
    53  		if err != nil {
    54  			return nil, errors.Wrap(err, "Can't prepare YOLOv3 node for training mode due Reshape() error")
    55  		}
    56  		op := newYoloOp(anchors, masks, netSize, numClasses, ignoreTresh, true)
    57  		return ApplyOp(op, inputTargetConcat)
    58  	}
    59  	op := newYoloOp(anchors, masks, netSize, numClasses, ignoreTresh, false)
    60  	return ApplyOp(op, input)
    61  }
    62  
    63  func (op *yoloOp) Arity() int {
    64  	return 1
    65  }
    66  
    67  func (op *yoloOp) ReturnsPtr() bool { return false }
    68  
    69  func (op *yoloOp) CallsExtern() bool { return false }
    70  
    71  func (op *yoloOp) WriteHash(h hash.Hash) {
    72  	fmt.Fprintf(h, "YOLO{}(anchors: (%v))", op.anchors)
    73  }
    74  func (op *yoloOp) Hashcode() uint32 { return simpleHash(op) }
    75  
    76  func (op *yoloOp) String() string {
    77  	return fmt.Sprintf("YOLO{}(anchors: (%v))", op.anchors)
    78  }
    79  func (op *yoloOp) InferShape(inputs ...DimSizer) (tensor.Shape, error) {
    80  	shp := inputs[0].(tensor.Shape)
    81  	if len(shp) < 4 {
    82  		return nil, fmt.Errorf("InferShape() for YOLO must contain 4 dimensions")
    83  	}
    84  	s := shp.Clone()
    85  	if op.trainMode {
    86  		return []int{s[0], s[2] * s[3] * len(op.masks), (s[1] - 1) / len(op.masks)}, nil
    87  	}
    88  	return []int{s[0], s[2] * s[3] * len(op.masks), s[1] / len(op.masks)}, nil
    89  }
    90  
    91  func (op *yoloOp) Type() hm.Type {
    92  	a := hm.TypeVariable('a')
    93  	t := newTensorType(4, a)
    94  	o := newTensorType(3, a)
    95  	return hm.NewFnType(t, o)
    96  }
    97  
    98  func (op *yoloOp) OverwritesInput() int { return -1 }
    99  
   100  func (op *yoloOp) checkInput(inputs ...Value) (tensor.Tensor, error) {
   101  	if err := checkArity(op, len(inputs)); err != nil {
   102  		return nil, errors.Wrap(err, "Can't check arity for YOLO operation")
   103  	}
   104  	var in tensor.Tensor
   105  	var ok bool
   106  	if in, ok = inputs[0].(tensor.Tensor); !ok {
   107  		return nil, errors.Errorf("Can't check YOLO input: expected input has to be a tensor")
   108  	}
   109  	if in.Shape().Dims() != 4 {
   110  		return nil, errors.Errorf("Can't check YOLO input: expected input must have 4 dimensions")
   111  	}
   112  	return in, nil
   113  }
   114  
   115  func sigmoidSlice(v tensor.View) error {
   116  	switch v.Dtype() {
   117  	case Float32:
   118  		_, err := v.Apply(_sigmoidf32, tensor.UseUnsafe())
   119  		if err != nil {
   120  			return errors.Wrap(err, "Can't apply _sigmoidf32 as activation function to YOLO operation")
   121  		}
   122  	case Float64:
   123  		_, err := v.Apply(_sigmoidf64, tensor.UseUnsafe())
   124  		if err != nil {
   125  			return errors.Wrap(err, "Can't apply _sigmoidf64 as activation function to YOLO operation")
   126  		}
   127  	default:
   128  		return fmt.Errorf("Unsupported numeric type for YOLO sigmoid function. Please use float64 or float32")
   129  	}
   130  	return nil
   131  }
   132  
   133  func expSlice(v tensor.View) error {
   134  	switch v.Dtype() {
   135  	case Float32:
   136  		_, err := v.Apply(math32.Exp, tensor.UseUnsafe())
   137  		if err != nil {
   138  			return errors.Wrap(err, "Can't apply exp32 to YOLO operation")
   139  		}
   140  	case Float64:
   141  		_, err := v.Apply(math.Exp, tensor.UseUnsafe())
   142  		if err != nil {
   143  			return errors.Wrap(err, "Can't apply exp64 to YOLO operation")
   144  		}
   145  	default:
   146  		return fmt.Errorf("Unsupported numeric type for YOLO for exp function. Please use float64 or float32")
   147  	}
   148  	return nil
   149  }
   150  
   151  func (op *yoloOp) Do(inputs ...Value) (retVal Value, err error) {
   152  	if !op.trainMode {
   153  		inputTensor, err := op.checkInput(inputs...)
   154  		if err != nil {
   155  			return nil, errors.Wrap(err, "Can't check YOLO input")
   156  		}
   157  		batchSize := inputTensor.Shape()[0]
   158  		stride := op.dimensions / inputTensor.Shape()[2]
   159  		gridSize := inputTensor.Shape()[2]
   160  		bboxAttributes := 5 + op.numClasses
   161  		numAnchors := len(op.anchors) / 2
   162  		currentAnchors := []float32{}
   163  		for i := range op.masks {
   164  			if op.masks[i] >= numAnchors {
   165  				return nil, fmt.Errorf("Incorrect mask %v for anchors in YOLO layer", op.masks)
   166  			}
   167  			currentAnchors = append(currentAnchors, op.anchors[i*2], op.anchors[i*2+1])
   168  		}
   169  		return op.evaluateYOLO_f32(inputTensor, batchSize, stride, gridSize, bboxAttributes, len(op.masks), currentAnchors)
   170  	}
   171  
   172  	// Training mode
   173  	input, err := op.checkInput(inputs...)
   174  	if err != nil {
   175  		return nil, errors.Wrap(err, "Can't check YOLO input [Training mode]")
   176  	}
   177  	inv, err := input.Slice(nil, S(0, input.Shape()[1]-1), nil, nil)
   178  	if err != nil {
   179  		return nil, errors.Wrap(err, "Can't prepare slice in YOLO (1) [Training mode]")
   180  	}
   181  	numTargets, err := input.At(0, input.Shape()[1]-1, 0, 0)
   182  	if err != nil {
   183  		return nil, errors.Wrap(err, "Can't select targets from YOLO input [Training mode]")
   184  	}
   185  
   186  	batchSize := input.Shape()[0]
   187  	stride := op.dimensions / input.Shape()[2]
   188  	grid := input.Shape()[2]
   189  	bboxAttributes := 5 + op.numClasses
   190  	numAnchors := len(op.masks)
   191  	currentAnchors := []float32{}
   192  	for i := range op.masks {
   193  		if op.masks[i] >= (len(op.anchors) / 2) {
   194  			return nil, fmt.Errorf("Incorrect mask %v for anchors in YOLO layer [Training mode]", op.masks)
   195  		}
   196  		currentAnchors = append(currentAnchors, op.anchors[i*2], op.anchors[i*2+1])
   197  	}
   198  
   199  	targets := []float32{}
   200  	inputNumericType := input.Dtype()
   201  
   202  	switch inputNumericType {
   203  	case Float32:
   204  		lt := int(numTargets.(float32))
   205  		targets = make([]float32, lt)
   206  		for i := 1; i <= lt; i++ {
   207  			valAt, err := input.At(0, input.Shape()[1]-1, i/grid, i%grid)
   208  			if err != nil {
   209  				return nil, fmt.Errorf("Can't select float32 targets for YOLO [Training mode]")
   210  			}
   211  			targets[i-1] = valAt.(float32)
   212  		}
   213  		break
   214  	case Float64:
   215  		lt := int(numTargets.(float64))
   216  		targets = make([]float32, lt)
   217  		for i := 1; i <= lt; i++ {
   218  			valAt, err := input.At(0, input.Shape()[1]-1, i/grid, i%grid)
   219  			if err != nil {
   220  				return nil, fmt.Errorf("Can't select float64 targets for YOLO [Training mode]")
   221  			}
   222  			targets[i-1] = float32(valAt.(float64))
   223  		}
   224  		break
   225  	default:
   226  		return nil, fmt.Errorf("Unsupported numeric type while preparing targets for YOLO Please use float64 or float32 [Training mode]")
   227  	}
   228  
   229  	input = inv.Materialize()
   230  
   231  	err = input.Reshape(batchSize, bboxAttributes*numAnchors, grid*grid)
   232  	if err != nil {
   233  		return nil, errors.Wrap(err, "Can't reshape in YOLO (1) [Training mode]")
   234  	}
   235  	err = input.T(0, 2, 1)
   236  	if err != nil {
   237  		return nil, errors.Wrap(err, "Can't safely transponse in YOLO (1) [Training mode]")
   238  	}
   239  	err = input.Transpose()
   240  	if err != nil {
   241  		return nil, errors.Wrap(err, "Can't transponse in YOLO (1) [Training mode]")
   242  	}
   243  	err = input.Reshape(batchSize, grid*grid*numAnchors, bboxAttributes)
   244  	if err != nil {
   245  		return nil, errors.Wrap(err, "Can't reshape in YOLO (2) [Training mode]")
   246  	}
   247  
   248  	clonedInput := input.Clone().(tensor.Tensor)
   249  	outyolo, err := op.evaluateYOLO_f32(input, batchSize, stride, grid, bboxAttributes, numAnchors, currentAnchors)
   250  	if err != nil {
   251  		return nil, errors.Wrap(err, "Can't evaluate YOLO operation [Training mode]")
   252  	}
   253  
   254  	yoloNumericType := outyolo.Dtype()
   255  	result := &tensor.Dense{}
   256  
   257  	switch yoloNumericType {
   258  	case Float32:
   259  		yoloBBoxesF32 := make([]float32, 0)
   260  		inputF32 := make([]float32, 0)
   261  		err = clonedInput.Reshape(input.Shape()[0] * input.Shape()[1] * input.Shape()[2])
   262  		if err != nil {
   263  			return nil, errors.Wrap(err, "Can't reshape in YOLO (3) [Training mode]")
   264  		}
   265  		err = outyolo.Reshape(outyolo.Shape()[0] * outyolo.Shape()[1] * outyolo.Shape()[2])
   266  		if err != nil {
   267  			return nil, errors.Wrap(err, "Can't reshape in YOLO (3) [Training mode]")
   268  		}
   269  		for i := 0; i < outyolo.Shape()[0]; i++ {
   270  			buf, err := outyolo.At(i)
   271  			if err != nil {
   272  				return nil, errors.Wrap(err, "Can't select value from YOLO output [Training mode]")
   273  			}
   274  			yoloBBoxesF32 = append(yoloBBoxesF32, buf.(float32))
   275  			buf, err = clonedInput.At(i)
   276  			if err != nil {
   277  				return nil, errors.Wrap(err, "Can't select value from YOLO bounding boxes [Training mode]")
   278  			}
   279  			inputF32 = append(inputF32, buf.(float32))
   280  		}
   281  		preparedOut := prepareOutputYOLO_f32(inputF32, yoloBBoxesF32, targets, op.anchors, op.masks, op.numClasses, op.dimensions, grid, op.ignoreTresh)
   282  		result = tensor.New(tensor.WithShape(1, grid*grid*len(op.masks), 5+op.numClasses), tensor.Of(tensor.Float32), tensor.WithBacking(preparedOut))
   283  		break
   284  	case Float64:
   285  		// @todo
   286  		return nil, fmt.Errorf("float64 numeric type is not implemented for preparing result for YOLO [Training mode]")
   287  	default:
   288  		return nil, fmt.Errorf("Unsupported numeric type for preparing result for YOLO. Please use float64 or float32 [Training mode]")
   289  	}
   290  
   291  	return result, nil
   292  }
   293  
   294  func (op *yoloOp) evaluateYOLO_f32(input tensor.Tensor, batchSize, stride, grid, bboxAttrs, numAnchors int, currentAnchors []float32) (retVal tensor.Tensor, err error) {
   295  
   296  	inputNumericType := input.Dtype()
   297  	if inputNumericType != Float32 {
   298  		return nil, fmt.Errorf("evaluateYOLO_f32() called with input tensor of type %v. Float32 is required", inputNumericType)
   299  	}
   300  
   301  	err = input.Reshape(batchSize, bboxAttrs*numAnchors, grid*grid)
   302  	if err != nil {
   303  		return nil, errors.Wrap(err, "Can't make reshape grid^2 for YOLO")
   304  	}
   305  
   306  	err = input.T(0, 2, 1)
   307  	if err != nil {
   308  		return nil, errors.Wrap(err, "Can't safely transponse input for YOLO")
   309  	}
   310  	err = input.Transpose()
   311  	if err != nil {
   312  		return nil, errors.Wrap(err, "Can't transponse input for YOLO")
   313  	}
   314  	err = input.Reshape(batchSize, grid*grid*numAnchors, bboxAttrs)
   315  	if err != nil {
   316  		return nil, errors.Wrap(err, "Can't reshape bbox for YOLO")
   317  	}
   318  
   319  	// Activation of x, y, and objects via sigmoid function
   320  	slXY, err := input.Slice(nil, nil, S(0, 2))
   321  	err = sigmoidSlice(slXY)
   322  	if err != nil {
   323  		return nil, errors.Wrap(err, "Can't activate XY")
   324  	}
   325  	slClasses, err := input.Slice(nil, nil, S(4, 5+op.numClasses))
   326  	err = sigmoidSlice(slClasses)
   327  	if err != nil {
   328  		return nil, errors.Wrap(err, "Can't activate classes")
   329  	}
   330  
   331  	step := grid * numAnchors
   332  	for i := 0; i < grid; i++ {
   333  
   334  		vy, err := input.Slice(nil, S(i*step, i*step+step), S(1))
   335  		if err != nil {
   336  			return nil, errors.Wrap(err, "Can't slice while doing steps for grid")
   337  		}
   338  
   339  		_, err = tensor.Add(vy, float32(i), tensor.UseUnsafe())
   340  		if err != nil {
   341  			return nil, errors.Wrap(err, "Can't do tensor.Add(...) for float32; (1)")
   342  		}
   343  
   344  		for n := 0; n < numAnchors; n++ {
   345  			anchorsSlice, err := input.Slice(nil, S(i*numAnchors+n, input.Shape()[1], step), S(0))
   346  			if err != nil {
   347  				return nil, errors.Wrap(err, "Can't slice anchors while doing steps for grid")
   348  			}
   349  			_, err = tensor.Add(anchorsSlice, float32(i), tensor.UseUnsafe())
   350  			if err != nil {
   351  				return nil, errors.Wrap(err, "Can't do tensor.Add(...) for float32; (1)")
   352  			}
   353  		}
   354  
   355  	}
   356  
   357  	anchors := []float32{}
   358  	for i := 0; i < grid*grid; i++ {
   359  		anchors = append(anchors, currentAnchors...)
   360  	}
   361  
   362  	anchorsTensor := tensor.New(tensor.Of(inputNumericType), tensor.WithShape(1, grid*grid*numAnchors, 2))
   363  	for i := range anchors {
   364  		anchorsTensor.Set(i, anchors[i])
   365  	}
   366  
   367  	_, err = tensor.Div(anchorsTensor, float32(stride), tensor.UseUnsafe())
   368  	if err != nil {
   369  		return nil, errors.Wrap(err, "Can't do tensor.Div(...) for float32")
   370  	}
   371  
   372  	vhw, err := input.Slice(nil, nil, S(2, 4))
   373  	if err != nil {
   374  		return nil, errors.Wrap(err, "Can't do slice on input S(2,4)")
   375  	}
   376  
   377  	_, err = vhw.Apply(math32.Exp, tensor.UseUnsafe())
   378  	if err != nil {
   379  		return nil, errors.Wrap(err, "Can't apply exp32 to YOLO operation")
   380  	}
   381  
   382  	_, err = tensor.Mul(vhw, anchorsTensor, tensor.UseUnsafe())
   383  	if err != nil {
   384  		return nil, errors.Wrap(err, "Can't do tensor.Mul(...) for anchors")
   385  	}
   386  
   387  	vv, err := input.Slice(nil, nil, S(0, 4))
   388  	if err != nil {
   389  		return nil, errors.Wrap(err, "Can't do slice on input S(0,4)")
   390  	}
   391  
   392  	_, err = tensor.Mul(vv, float32(stride), tensor.UseUnsafe())
   393  	if err != nil {
   394  		return nil, errors.Wrap(err, "Can't do tensor.Mul(...) for float32")
   395  	}
   396  
   397  	return input, nil
   398  }
   399  
   400  func iou_f32(r1, r2 image.Rectangle) float32 {
   401  	intersection := r1.Intersect(r2)
   402  	interArea := intersection.Dx() * intersection.Dy()
   403  	r1Area := r1.Dx() * r1.Dy()
   404  	r2Area := r2.Dx() * r2.Dy()
   405  	return float32(interArea) / float32(r1Area+r2Area-interArea)
   406  }
   407  
   408  func getBestIOU_f32(input, target []float32, numClasses, dims int) [][]float32 {
   409  	ious := make([][]float32, 0)
   410  	imgsize := float32(dims)
   411  	for i := 0; i < len(input); i = i + numClasses + 5 {
   412  		ious = append(ious, []float32{0, -1})
   413  		r1 := rectifyBox_f32(input[i], input[i+1], input[i+2], input[i+3], dims)
   414  		for j := 0; j < len(target); j = j + 5 {
   415  			r2 := rectifyBox_f32(target[j+1]*imgsize, target[j+2]*imgsize, target[j+3]*imgsize, target[j+4]*imgsize, dims)
   416  			curiou := iou_f32(r1, r2)
   417  			if curiou > ious[i/(5+numClasses)][0] {
   418  				ious[i/(5+numClasses)][0] = curiou
   419  				ious[i/(5+numClasses)][1] = float32(j / 5)
   420  			}
   421  		}
   422  	}
   423  	return ious
   424  }
   425  
   426  func getBestAnchors_f32(target []float32, anchors []float32, masks []int, dims int, gridSize float32) [][]int {
   427  	bestAnchors := make([][]int, len(target)/5)
   428  	imgsize := float32(dims)
   429  	for j := 0; j < len(target); j = j + 5 {
   430  		targetRect := rectifyBox_f32(0, 0, target[j+3]*imgsize, target[j+4]*imgsize, dims) //not absolutely confident in rectangle sizes
   431  		bestIOU := float32(0.0)
   432  		bestAnchors[j/5] = make([]int, 3)
   433  		for i := 0; i < len(anchors); i = i + 2 {
   434  			anchorRect := rectifyBox_f32(0, 0, anchors[i], anchors[i+1], dims)
   435  			currentIOU := iou_f32(anchorRect, targetRect)
   436  			if currentIOU >= bestIOU {
   437  				bestAnchors[j/5][0] = i
   438  				bestIOU = currentIOU
   439  			}
   440  		}
   441  		bestAnchors[j/5][0] = findIntElement(masks, bestAnchors[j/5][0]/2)
   442  		if bestAnchors[j/5][0] != -1 {
   443  			bestAnchors[j/5][1] = int(target[j+1] * gridSize)
   444  			bestAnchors[j/5][2] = int(target[j+2] * gridSize)
   445  		}
   446  	}
   447  	return bestAnchors
   448  }
   449  
   450  func prepareOutputYOLO_f32(input, yoloBoxes, target, anchors []float32, masks []int, numClasses, dims, gridSize int, ignoreTresh float32) []float32 {
   451  	yoloBBoxes := make([]float32, len(yoloBoxes))
   452  	gridSizeF32 := float32(gridSize)
   453  	bestAnchors := getBestAnchors_f32(target, anchors, masks, dims, gridSizeF32)
   454  	bestIous := getBestIOU_f32(yoloBoxes, target, numClasses, dims)
   455  	for i := 0; i < len(yoloBoxes); i = i + (5 + numClasses) {
   456  		if bestIous[i/(5+numClasses)][0] <= ignoreTresh {
   457  			yoloBBoxes[i+4] = bceLoss32(0, yoloBoxes[i+4])
   458  		}
   459  	}
   460  	for i := 0; i < len(bestAnchors); i++ {
   461  		if bestAnchors[i][0] != -1 {
   462  			scale := (2 - target[i*5+3]*target[i*5+4])
   463  			giInt := bestAnchors[i][1]
   464  			gjInt := bestAnchors[i][2]
   465  			gx := invsigm32(target[i*5+1]*gridSizeF32 - float32(giInt))
   466  			gy := invsigm32(target[i*5+2]*gridSizeF32 - float32(gjInt))
   467  			gw := math32.Log(target[i*5+3]/anchors[bestAnchors[i][0]] + 1e-16)
   468  			gh := math32.Log(target[i*5+4]/anchors[bestAnchors[i][0]+1] + 1e-16)
   469  			bboxIdx := gjInt*gridSize*len(masks) + giInt*len(masks) + bestAnchors[i][0]
   470  			yoloBBoxes[bboxIdx] = mseLoss32(gx, input[bboxIdx], scale)
   471  			yoloBBoxes[bboxIdx+1] = mseLoss32(gy, input[bboxIdx+1], scale)
   472  			yoloBBoxes[bboxIdx+2] = mseLoss32(gw, input[bboxIdx+2], scale)
   473  			yoloBBoxes[bboxIdx+3] = mseLoss32(gh, input[bboxIdx+3], scale)
   474  			yoloBBoxes[bboxIdx+4] = bceLoss32(1, yoloBoxes[bboxIdx+4])
   475  			for j := 0; j < numClasses; j++ {
   476  				if j == int(target[i]) {
   477  					yoloBBoxes[bboxIdx+5+j] = bceLoss32(1, yoloBoxes[bboxIdx+4])
   478  				} else {
   479  					yoloBBoxes[bboxIdx+5+j] = bceLoss32(0, yoloBoxes[bboxIdx+4])
   480  				}
   481  			}
   482  		}
   483  	}
   484  	return yoloBBoxes
   485  }
   486  
   487  func findIntElement(arr []int, ele int) int {
   488  	for i := range arr {
   489  		if arr[i] == ele {
   490  			return i
   491  		}
   492  	}
   493  	return -1
   494  }
   495  
   496  func rectifyBox_f32(x, y, h, w float32, imgSize int) image.Rectangle {
   497  	return image.Rect(maxInt(int(x-w/2), 0), maxInt(int(y-h/2), 0), minInt(int(x+w/2+1), imgSize), minInt(int(y+h/2+1), imgSize))
   498  }
   499  
   500  func bceLoss32(target, pred float32) float32 {
   501  	if target == 1.0 {
   502  		return -(math32.Log(pred + 1e-16))
   503  	}
   504  	return -(math32.Log((1.0 - pred) + 1e-16))
   505  }
   506  
   507  func mseLoss32(target, pred, scale float32) float32 {
   508  	return math32.Pow(scale*(target-pred), 2) / 2.0
   509  }
   510  
   511  func invsigm32(target float32) float32 {
   512  	return -math32.Log(1-target+1e-16) + math32.Log(target+1e-16)
   513  }
   514  
   515  func (op *yoloOp) evaluateYOLO_f64(input tensor.Tensor, batchSize, stride, grid, bboxAttrs, numAnchors int, currentAnchors []float64) (retVal tensor.Tensor, err error) {
   516  	inputNumericType := input.Dtype()
   517  	if inputNumericType != Float64 {
   518  		return nil, fmt.Errorf("evaluateYOLO_f64() called with input tensor of type %v. Float64 is required", inputNumericType)
   519  	}
   520  	err = input.Reshape(batchSize, bboxAttrs*numAnchors, grid*grid)
   521  	if err != nil {
   522  		return nil, errors.Wrap(err, "Can't make reshape grid^2 for YOLO")
   523  	}
   524  	err = input.T(0, 2, 1)
   525  	if err != nil {
   526  		return nil, errors.Wrap(err, "Can't safely transponse input for YOLO")
   527  	}
   528  	err = input.Transpose()
   529  	if err != nil {
   530  		return nil, errors.Wrap(err, "Can't transponse input for YOLO")
   531  	}
   532  	err = input.Reshape(batchSize, grid*grid*numAnchors, bboxAttrs)
   533  	if err != nil {
   534  		return nil, errors.Wrap(err, "Can't reshape bbox for YOLO")
   535  	}
   536  
   537  	// Activation of x, y, and objects via sigmoid function
   538  	slXY, err := input.Slice(nil, nil, S(0, 2))
   539  	err = sigmoidSlice(slXY)
   540  	if err != nil {
   541  		return nil, errors.Wrap(err, "Can't activate XY")
   542  	}
   543  	slClasses, err := input.Slice(nil, nil, S(4, 5+op.numClasses))
   544  	err = sigmoidSlice(slClasses)
   545  	if err != nil {
   546  		return nil, errors.Wrap(err, "Can't activate classes")
   547  	}
   548  
   549  	step := grid * numAnchors
   550  	for i := 0; i < grid; i++ {
   551  		vy, err := input.Slice(nil, S(i*step, i*step+step), S(1))
   552  		if err != nil {
   553  			return nil, errors.Wrap(err, "Can't slice while doing steps for grid")
   554  		}
   555  		_, err = tensor.Add(vy, float64(i), tensor.UseUnsafe())
   556  		if err != nil {
   557  			return nil, errors.Wrap(err, "Can't do tensor.Add(...) for float64; (1)")
   558  		}
   559  		for n := 0; n < numAnchors; n++ {
   560  			anchorsSlice, err := input.Slice(nil, S(i*numAnchors+n, input.Shape()[1], step), S(0))
   561  			if err != nil {
   562  				return nil, errors.Wrap(err, "Can't slice anchors while doing steps for grid")
   563  			}
   564  			_, err = tensor.Add(anchorsSlice, float64(i), tensor.UseUnsafe())
   565  			if err != nil {
   566  				return nil, errors.Wrap(err, "Can't do tensor.Add(...) for float64; (2)")
   567  			}
   568  		}
   569  
   570  	}
   571  
   572  	anchors := []float64{}
   573  	for i := 0; i < grid*grid; i++ {
   574  		anchors = append(anchors, currentAnchors...)
   575  	}
   576  
   577  	anchorsTensor := tensor.New(tensor.Of(inputNumericType), tensor.WithShape(1, grid*grid*numAnchors, 2))
   578  	for i := range anchors {
   579  		anchorsTensor.Set(i, anchors[i])
   580  	}
   581  
   582  	_, err = tensor.Div(anchorsTensor, float64(stride), tensor.UseUnsafe())
   583  	if err != nil {
   584  		return nil, errors.Wrap(err, "Can't do tensor.Div(...) for float64")
   585  	}
   586  
   587  	vhw, err := input.Slice(nil, nil, S(2, 4))
   588  	if err != nil {
   589  		return nil, errors.Wrap(err, "Can't do slice on input S(2,4)")
   590  	}
   591  
   592  	_, err = vhw.Apply(math.Exp, tensor.UseUnsafe())
   593  	if err != nil {
   594  		return nil, errors.Wrap(err, "Can't apply exp64 to YOLO operation")
   595  	}
   596  
   597  	_, err = tensor.Mul(vhw, anchorsTensor, tensor.UseUnsafe())
   598  	if err != nil {
   599  		return nil, errors.Wrap(err, "Can't do tensor.Mul(...) for anchors")
   600  	}
   601  
   602  	vv, err := input.Slice(nil, nil, S(0, 4))
   603  	if err != nil {
   604  		return nil, errors.Wrap(err, "Can't do slice on input S(0,4)")
   605  	}
   606  
   607  	_, err = tensor.Mul(vv, float64(stride), tensor.UseUnsafe())
   608  	if err != nil {
   609  		return nil, errors.Wrap(err, "Can't do tensor.Mul(...) for float64")
   610  	}
   611  
   612  	return input, nil
   613  }
   614  
   615  func iou_f64(r1, r2 image.Rectangle) float64 {
   616  	intersection := r1.Intersect(r2)
   617  	interArea := intersection.Dx() * intersection.Dy()
   618  	r1Area := r1.Dx() * r1.Dy()
   619  	r2Area := r2.Dx() * r2.Dy()
   620  	return float64(interArea) / float64(r1Area+r2Area-interArea)
   621  }
   622  
   623  func getBestIOU_f64(input, target []float64, numClasses, dims int) [][]float64 {
   624  	ious := make([][]float64, 0)
   625  	imgsize := float64(dims)
   626  	for i := 0; i < len(input); i = i + numClasses + 5 {
   627  		ious = append(ious, []float64{0, -1})
   628  		r1 := rectifyBox_f64(input[i], input[i+1], input[i+2], input[i+3], dims)
   629  		for j := 0; j < len(target); j = j + 5 {
   630  			r2 := rectifyBox_f64(target[j+1]*imgsize, target[j+2]*imgsize, target[j+3]*imgsize, target[j+4]*imgsize, dims)
   631  			curiou := iou_f64(r1, r2)
   632  			if curiou > ious[i/(5+numClasses)][0] {
   633  				ious[i/(5+numClasses)][0] = curiou
   634  				ious[i/(5+numClasses)][1] = float64(j / 5)
   635  			}
   636  		}
   637  	}
   638  	return ious
   639  }
   640  
   641  func getBestAnchors_f64(target []float64, anchors []float64, masks []int, dims int, gridSize float64) [][]int {
   642  	bestAnchors := make([][]int, len(target)/5)
   643  	imgsize := float64(dims)
   644  	for j := 0; j < len(target); j = j + 5 {
   645  		targetRect := rectifyBox_f64(0, 0, target[j+3]*imgsize, target[j+4]*imgsize, dims) //not absolutely confident in rectangle sizes
   646  		bestIOU := float64(0.0)
   647  		bestAnchors[j/5] = make([]int, 3)
   648  		for i := 0; i < len(anchors); i = i + 2 {
   649  			anchorRect := rectifyBox_f64(0, 0, anchors[i], anchors[i+1], dims)
   650  			currentIOU := iou_f64(anchorRect, targetRect)
   651  			if currentIOU >= bestIOU {
   652  				bestAnchors[j/5][0] = i
   653  				bestIOU = currentIOU
   654  			}
   655  		}
   656  		bestAnchors[j/5][0] = findIntElement(masks, bestAnchors[j/5][0]/2)
   657  		if bestAnchors[j/5][0] != -1 {
   658  			bestAnchors[j/5][1] = int(target[j+1] * gridSize)
   659  			bestAnchors[j/5][2] = int(target[j+2] * gridSize)
   660  		}
   661  	}
   662  	return bestAnchors
   663  }
   664  
   665  func prepareOutputYOLO_f64(input, yoloBoxes, target, anchors []float64, masks []int, numClasses, dims, gridSize int, ignoreTresh float64) []float64 {
   666  	yoloBBoxes := make([]float64, len(yoloBoxes))
   667  	gridSizeF64 := float64(gridSize)
   668  	bestAnchors := getBestAnchors_f64(target, anchors, masks, dims, gridSizeF64)
   669  	bestIous := getBestIOU_f64(yoloBoxes, target, numClasses, dims)
   670  	for i := 0; i < len(yoloBoxes); i = i + (5 + numClasses) {
   671  		if bestIous[i/(5+numClasses)][0] <= ignoreTresh {
   672  			yoloBBoxes[i+4] = bceLoss64(0, yoloBoxes[i+4])
   673  		}
   674  	}
   675  	for i := 0; i < len(bestAnchors); i++ {
   676  		if bestAnchors[i][0] != -1 {
   677  			scale := (2 - target[i*5+3]*target[i*5+4])
   678  			giInt := bestAnchors[i][1]
   679  			gjInt := bestAnchors[i][2]
   680  			gx := invsigm64(target[i*5+1]*gridSizeF64 - float64(giInt))
   681  			gy := invsigm64(target[i*5+2]*gridSizeF64 - float64(gjInt))
   682  			gw := math.Log(target[i*5+3]/anchors[bestAnchors[i][0]] + 1e-16)
   683  			gh := math.Log(target[i*5+4]/anchors[bestAnchors[i][0]+1] + 1e-16)
   684  			bboxIdx := gjInt*gridSize*len(masks) + giInt*len(masks) + bestAnchors[i][0]
   685  			yoloBBoxes[bboxIdx] = mseLoss64(gx, input[bboxIdx], scale)
   686  			yoloBBoxes[bboxIdx+1] = mseLoss64(gy, input[bboxIdx+1], scale)
   687  			yoloBBoxes[bboxIdx+2] = mseLoss64(gw, input[bboxIdx+2], scale)
   688  			yoloBBoxes[bboxIdx+3] = mseLoss64(gh, input[bboxIdx+3], scale)
   689  			yoloBBoxes[bboxIdx+4] = bceLoss64(1, yoloBoxes[bboxIdx+4])
   690  			for j := 0; j < numClasses; j++ {
   691  				if j == int(target[i]) {
   692  					yoloBBoxes[bboxIdx+5+j] = bceLoss64(1, yoloBoxes[bboxIdx+4])
   693  				} else {
   694  					yoloBBoxes[bboxIdx+5+j] = bceLoss64(0, yoloBoxes[bboxIdx+4])
   695  				}
   696  			}
   697  		}
   698  	}
   699  	return yoloBBoxes
   700  }
   701  
   702  func rectifyBox_f64(x, y, h, w float64, imgSize int) image.Rectangle {
   703  	return image.Rect(maxInt(int(x-w/2), 0), maxInt(int(y-h/2), 0), minInt(int(x+w/2+1), imgSize), minInt(int(y+h/2+1), imgSize))
   704  }
   705  
   706  func bceLoss64(target, pred float64) float64 {
   707  	if target == 1.0 {
   708  		return -(math.Log(pred + 1e-16))
   709  	}
   710  	return -(math.Log((1.0 - pred) + 1e-16))
   711  }
   712  
   713  func mseLoss64(target, pred, scale float64) float64 {
   714  	return math.Pow(scale*(target-pred), 2) / 2.0
   715  }
   716  
   717  func invsigm64(target float64) float64 {
   718  	return -math.Log(1-target+1e-16) + math.Log(target+1e-16)
   719  }