github.com/wzzhu/tensor@v0.9.24/genlib2/agg2_body.go (about)

     1  package main
     2  
     3  import "text/template"
     4  
     5  // level 2 aggregation (tensor.StdEng) templates
     6  
// cmpPrepRaw is the option-handling preamble emitted for comparison
// operations. The generated code declares safe and same, fills them (together
// with the already-declared reuse) from handleFuncOpts using the shape, dtype
// and data order of {{.VecVar}}, and returns a wrapped error when that call
// fails. If the operation is not safe, same is forced to true.
// NOTE(review): presumably this is what gets registered as the "prep"
// sub-template for comparison ops — confirm against the generator code.
const cmpPrepRaw = `var safe, same bool
	if reuse, safe, _, _, same, err = handleFuncOpts({{.VecVar}}.Shape(), {{.VecVar}}.Dtype(),  {{.VecVar}}.DataOrder(),false, opts...); err != nil{
		return nil, errors.Wrap(err, "Unable to handle funcOpts")
	}
	if !safe {
		same = true
	}
`
    15  
// arithPrepRaw is the option-handling preamble emitted for arithmetic
// operations. The generated code declares safe, toReuse and incr, fills them
// (together with the already-declared reuse) from handleFuncOpts using the
// shape, dtype and data order of {{.VecVar}}, and returns a wrapped error when
// that call fails. The "same" result of handleFuncOpts is discarded.
// NOTE(review): presumably registered as the "prep" sub-template for
// arithmetic ops — confirm against the generator code.
const arithPrepRaw = `var safe, toReuse, incr bool
	if reuse, safe, toReuse, incr, _, err = handleFuncOpts({{.VecVar}}.Shape(), {{.VecVar}}.Dtype(), {{.VecVar}}.DataOrder(), true, opts...); err != nil{
		return nil, errors.Wrap(err, "Unable to handle funcOpts")
	}
`
    21  
// minmaxPrepRaw is the option-handling preamble emitted for min/max
// operations. The generated code declares only safe, fills it (together with
// the already-declared reuse) from handleFuncOpts using the shape, dtype and
// data order of {{.VecVar}}, and returns a wrapped error when that call fails.
// The toReuse, incr and same results of handleFuncOpts are discarded.
// NOTE(review): presumably registered as the "prep" sub-template for min/max
// ops — confirm against the generator code.
const minmaxPrepRaw = `var safe bool
	if reuse, safe, _, _, _, err = handleFuncOpts({{.VecVar}}.Shape(), {{.VecVar}}.Dtype(), {{.VecVar}}.DataOrder(), true, opts...); err != nil{
		return nil, errors.Wrap(err, "Unable to handle funcOpts")
	}
`
    27  
    28  const prepVVRaw = `if err = binaryCheck(a, b, {{.TypeClassCheck | lower}}Types); err != nil {
    29  		return nil, errors.Wrapf(err, "{{.Name}} failed")
    30  	}
    31  
    32  	var reuse DenseTensor
    33  	{{template "prep" . -}}
    34  
    35  
    36  	typ := a.Dtype().Type
    37  	var dataA, dataB, dataReuse *storage.Header
    38  	var ait, bit, iit Iterator
    39  	var useIter, swap bool
    40  	if dataA, dataB, dataReuse, ait, bit, iit, useIter, swap, err = prepDataVV(a, b, reuse); err != nil {
    41  		return nil, errors.Wrapf(err, "StdEng.{{.Name}}")
    42  	}
    43  `
    44  
// prepMixedRaw is the preamble emitted at the top of an operation that
// combines a tensor t with a scalar s. The generated code:
//   - validates t with unaryCheck against the {{.TypeClassCheck | lower}}Types class,
//   - validates the scalar against t with scalarDtypeCheck,
//   - declares reuse and expands the "prep" sub-template (option handling),
//   - aliases a := t and reads typ from t,
//   - fills the data headers, iterators, useIter and newAlloc from prepDataVS
//     when leftTensor is true, or from prepDataSV otherwise.
//
// scalarHeader records which header holds the scalar (dataB when leftTensor,
// dataA otherwise) so the body templates can release it. Only the tensor-side
// iterator is assigned: ait for leftTensor, bit otherwise.
const prepMixedRaw = `if err = unaryCheck(t, {{.TypeClassCheck | lower}}Types); err != nil {
		return nil, errors.Wrapf(err, "{{.Name}} failed")
	}

	if err = scalarDtypeCheck(t, s); err != nil {
		return nil, errors.Wrap(err, "{{.Name}} failed")
	}

	var reuse DenseTensor
	{{template "prep" . -}}

	a := t
	typ := t.Dtype().Type
	var ait, bit,  iit Iterator
	var dataA, dataB, dataReuse, scalarHeader *storage.Header
	var useIter, newAlloc bool

	if leftTensor {
		if dataA, dataB, dataReuse, ait, iit, useIter, newAlloc, err = prepDataVS(t, s, reuse); err != nil {
			return nil, errors.Wrapf(err, opFail, "StdEng.{{.Name}}")
		}
		scalarHeader = dataB
	} else {
		if dataA, dataB, dataReuse, bit, iit, useIter, newAlloc, err = prepDataSV(s, t, reuse); err != nil {
			return nil, errors.Wrapf(err, opFail, "StdEng.{{.Name}}")
		}
		scalarHeader = dataA
	}

`
    75  
// prepUnaryRaw is the preamble emitted at the top of a unary operation on a.
// The generated code:
//   - validates a with unaryCheck against the {{.TypeClassCheck | lower}}Types class
//     (on failure it assigns the wrapped error to err and does a bare return),
//   - declares reuse, safe, toReuse and incr and fills them from handleFuncOpts
//     using the shape, dtype and data order of a,
//   - declares typ, the data headers, the iterators ait/rit and useIter, and
//     fills them from prepDataUnary.
//
// Unlike prepVVRaw and prepMixedRaw, the option handling is written inline
// rather than through the "prep" sub-template.
const prepUnaryRaw = `if err = unaryCheck(a, {{.TypeClassCheck | lower}}Types); err != nil {
		err = errors.Wrapf(err, "{{.Name}} failed")
		return
	}
	var reuse DenseTensor
	var safe, toReuse, incr bool
	if reuse, safe, toReuse, incr, _, err = handleFuncOpts(a.Shape(), a.Dtype(), a.DataOrder(), true, opts...); err != nil {
		return nil, errors.Wrap(err, "Unable to handle funcOpts")
	}

	typ := a.Dtype().Type
	var ait, rit Iterator
	var dataA, dataReuse *storage.Header
	var useIter bool

	if dataA, dataReuse, ait, rit, useIter, err = prepDataUnary(a, reuse); err != nil{
		return nil, errors.Wrapf(err, opFail, "StdEng.{{.Name}}")
	}
	`
    95  
// agg2BodyRaw is the body template for arithmetic operations. It is rendered
// in two flavours selected by .VV: the tensor-tensor flavour when .VV is true,
// and otherwise the flavour that reads leftTensor, newAlloc and scalarHeader
// (the variables declared by prepMixedRaw).
//
// The generated code first branches on useIter: the iterator path calls the
// kernels suffixed Iter/IterIncr, the other path calls the plain kernels. Within
// each path the cases are tried in this order:
//   - incr:    the Incr kernel writes into dataReuse; reuse is returned.
//   - toReuse: an operand is copied into dataReuse, the kernel runs against the
//     copy, and reuse is returned (the non-iterator .VV flavour calls the Recv
//     kernel instead of copying).
//   - !safe:   the kernel runs directly on dataA/dataB and a is returned.
//   - default: an operand is cloned (b when swap is set in the .VV flavour,
//     otherwise a) and the kernel runs against the clone.
//
// In the non-.VV flavour both paths finish by calling freeScalar on
// scalarHeader.Raw when newAlloc is set and then returnHeader(scalarHeader).
//
// NOTE(review): the {{if not .VV -}} around the IsScalarEquiv copy in the
// non-iterator "toReuse && !leftTensor" case sits inside the {{else}} arm of
// {{if .VV}}, so its condition is always true there.
const agg2BodyRaw = `if useIter {
		switch {
		case incr:
			err = e.E.{{.Name}}IterIncr(typ, dataA, dataB, dataReuse, ait, bit, iit)
			retVal = reuse
		{{if .VV -}}
		case toReuse:
			storage.CopyIter(typ,dataReuse, dataA, iit, ait)
			ait.Reset()
			iit.Reset()
			err = e.E.{{.Name}}Iter(typ, dataReuse, dataB, iit, bit)
			retVal = reuse
		{{else -}}
		case toReuse && leftTensor:
			storage.CopyIter(typ, dataReuse, dataA, iit, ait)
			ait.Reset()
			iit.Reset()
			err = e.E.{{.Name}}Iter(typ, dataReuse, dataB, iit, bit)
			retVal = reuse
		case toReuse && !leftTensor:
			storage.CopyIter(typ, dataReuse, dataB, iit, bit)
			iit.Reset()
			bit.Reset()
			err = e.E.{{.Name}}Iter(typ, dataA, dataReuse, ait, iit)
			retVal = reuse
		{{end -}}
		case !safe:
			err = e.E.{{.Name}}Iter(typ, dataA, dataB, ait, bit)
			retVal = a
		default:
		{{if .VV -}}
			if swap {
				retVal = b.Clone().(Tensor)
			}else{
				retVal = a.Clone().(Tensor)
			}
			err = e.E.{{.Name}}Iter(typ, retVal.hdr(), dataB, ait, bit)
		{{else -}}
			retVal = a.Clone().(Tensor)
			if leftTensor {
				err = e.E.{{.Name}}Iter(typ, retVal.hdr(), dataB, ait, bit)
			} else {
				err = e.E.{{.Name}}Iter(typ, dataA, retVal.hdr(), ait, bit)
			}
		{{end -}}
		}
		{{if not .VV -}}
		if newAlloc{
		freeScalar(scalarHeader.Raw)
		}
		returnHeader(scalarHeader)
		{{end -}}
		return
	}
	switch {
	case incr:
		err = e.E.{{.Name}}Incr(typ, dataA, dataB, dataReuse)
		retVal = reuse
	{{if .VV -}}
	case toReuse:
		err = e.E.{{.Name}}Recv(typ, dataA, dataB, dataReuse)
		retVal = reuse
	{{else -}}
	case toReuse && leftTensor:
		storage.Copy(typ, dataReuse, dataA)
		err = e.E.{{.Name}}(typ, dataReuse, dataB)
		retVal = reuse
	case toReuse && !leftTensor:
		storage.Copy(typ, dataReuse, dataB)
		err = e.E.{{.Name}}(typ, dataA, dataReuse)
		{{if not .VV -}}
		if t.Shape().IsScalarEquiv() {
			storage.Copy(typ, dataReuse, dataA)
		}
		{{end -}}
		retVal = reuse
	{{end -}}
	case !safe:
		err = e.E.{{.Name}}(typ, dataA, dataB)
		{{if not .VV -}}
		if t.Shape().IsScalarEquiv() && !leftTensor {
			storage.Copy(typ, dataB, dataA)
		}
		{{end -}}
		retVal = a
	default:
		{{if .VV -}}
			if swap {
				retVal = b.Clone().(Tensor)
			}else{
				retVal = a.Clone().(Tensor)
			}
			err = e.E.{{.Name}}(typ, retVal.hdr(), dataB)
		{{else -}}
			retVal = a.Clone().(Tensor)
			if !leftTensor {
				storage.Fill(typ, retVal.hdr(), dataA)
			}
			err = e.E.{{.Name}}(typ, retVal.hdr(), dataB)
		{{end -}}
	}
	{{if not .VV -}}
	if newAlloc{
	freeScalar(scalarHeader.Raw)
	}
	returnHeader(scalarHeader)
	{{end -}}
	return
`
   205  
   206  const agg2CmpBodyRaw = `// check to see if anything needs to be created
   207  	switch {
   208  	case same && safe && reuse == nil:
   209  		{{if .VV -}}
   210  		if swap{
   211  			reuse = NewDense(b.Dtype(), b.Shape().Clone(), WithEngine(e))
   212  		} else{
   213  			reuse = NewDense(a.Dtype(), a.Shape().Clone(), WithEngine(e))
   214  		}
   215  		{{else -}}
   216  		reuse = NewDense(a.Dtype(), a.Shape().Clone(), WithEngine(e))
   217  		{{end -}}
   218  		dataReuse = reuse.hdr()
   219  		if useIter{
   220  		iit = IteratorFromDense(reuse)
   221  		}
   222  	case !same && safe && reuse == nil:
   223  		reuse = NewDense(Bool, a.Shape().Clone(), WithEngine(e))
   224  		dataReuse =  reuse.hdr()
   225  		if useIter{
   226  		iit = IteratorFromDense(reuse)
   227  		}
   228  	}
   229  
   230  	if useIter {
   231  		switch {
   232  		case !safe && same && reuse == nil:
   233  			err = e.E.{{.Name}}SameIter(typ, dataA, dataB, ait, bit)
   234  			retVal = a
   235  		{{if .VV -}}
   236  		case same && safe && reuse != nil:
   237  			storage.CopyIter(typ,dataReuse,dataA, iit, ait)
   238  			ait.Reset()
   239  			iit.Reset()
   240  			err = e.E.{{.Name}}SameIter(typ, dataReuse, dataB, iit, bit)
   241  			retVal = reuse
   242  		{{else -}}
   243  		case same && safe && reuse != nil && !leftTensor:
   244  			storage.CopyIter(typ,dataReuse,dataB, iit, bit)
   245  			bit.Reset()
   246  			iit.Reset()
   247  			err = e.E.{{.Name}}SameIter(typ, dataA, dataReuse, ait, bit)
   248  			retVal = reuse
   249  		case same && safe && reuse != nil && leftTensor:
   250  			storage.CopyIter(typ,dataReuse,dataA, iit, ait)
   251  			ait.Reset()
   252  			iit.Reset()
   253  			err = e.E.{{.Name}}SameIter(typ, dataReuse, dataB, iit, bit)
   254  			retVal = reuse
   255  		{{end -}}
   256  		default: // safe && bool
   257  			err = e.E.{{.Name}}Iter(typ, dataA, dataB, dataReuse, ait, bit, iit)
   258  			retVal = reuse
   259  		}
   260  		{{if not .VV -}}
   261  		if newAlloc{
   262  		freeScalar(scalarHeader.Raw)
   263  		}
   264  		returnHeader(scalarHeader)
   265  		{{end -}}
   266  		return
   267  	}
   268  
   269  	{{if not .VV -}}
   270  	// handle special case where A and B have both len 1
   271  	if len(dataA.Raw) == int(typ.Size()) && len(dataB.Raw) == int(typ.Size()) {
   272  		switch {
   273  		case same && safe && reuse != nil && leftTensor:
   274  			storage.Copy(typ,dataReuse,dataA)
   275  			err = e.E.{{.Name}}Same(typ, dataReuse, dataB)
   276  			retVal = reuse
   277  			return
   278  		case same && safe && reuse != nil && !leftTensor:
   279  			storage.Copy(typ,dataReuse,dataB)
   280  			err = e.E.{{.Inv}}Same(typ, dataReuse, dataA)
   281  			retVal = reuse
   282  			return
   283  		}
   284  	}
   285  	{{end -}}
   286  
   287  	// standard
   288  	switch {
   289  		case !safe && same && reuse == nil:
   290  			err = e.E.{{.Name}}Same(typ, dataA, dataB)
   291  			retVal = a
   292  		{{if .VV -}}
   293  		case same && safe && reuse != nil:
   294  			storage.Copy(typ,dataReuse,dataA)
   295  			err = e.E.{{.Name}}Same(typ, dataReuse, dataB)
   296  			retVal = reuse
   297  		{{else -}}
   298  		case same && safe && reuse != nil && leftTensor:
   299  			storage.Copy(typ,dataReuse,dataA)
   300  			err = e.E.{{.Name}}Same(typ, dataReuse, dataB)
   301  			retVal = reuse
   302  		case same && safe && reuse != nil && !leftTensor:
   303  			storage.Copy(typ,dataReuse,dataB)
   304  			err = e.E.{{.Name}}Same(typ, dataA, dataReuse)
   305  			retVal = reuse
   306  		{{end -}}
   307  		default:
   308  			err = e.E.{{.Name}}(typ, dataA, dataB, dataReuse)
   309  			retVal = reuse
   310  	}
   311  	{{if not .VV -}}
   312  	if newAlloc{
   313  		freeScalar(scalarHeader.Raw)
   314  	}
   315  	returnHeader(scalarHeader)
   316  	{{end -}}
   317  	return
   318  `
   319  
   320  const agg2MinMaxBodyRaw = `// check to see if anything needs to be created
   321  	if reuse == nil{
   322  		{{if .VV -}}
   323  		if swap{
   324  			reuse = NewDense(b.Dtype(), b.Shape().Clone(), WithEngine(e))
   325  		} else{
   326  			reuse = NewDense(a.Dtype(), a.Shape().Clone(), WithEngine(e))
   327  		}
   328  		{{else -}}
   329  		reuse = NewDense(a.Dtype(), a.Shape().Clone(), WithEngine(e))
   330  		{{end -}}
   331  		dataReuse = reuse.hdr()
   332  		if useIter{
   333  		iit = IteratorFromDense(reuse)
   334  		}
   335  	}
   336  
   337  
   338  	if useIter {
   339  		switch {
   340  		case !safe && reuse == nil:
   341  			err = e.E.{{.Name}}Iter(typ, dataA, dataB, ait, bit)
   342  			retVal = a
   343  		{{if .VV -}}
   344  		case  safe && reuse != nil:
   345  			storage.CopyIter(typ,dataReuse,dataA, iit, ait)
   346  			ait.Reset()
   347  			iit.Reset()
   348  			err = e.E.{{.Name}}Iter(typ, dataReuse, dataB, iit, bit)
   349  			retVal = reuse
   350  		{{else -}}
   351  		case safe && reuse != nil && !leftTensor:
   352  			storage.CopyIter(typ,dataReuse,dataB, iit, bit)
   353  			bit.Reset()
   354  			iit.Reset()
   355  			err = e.E.{{.Name}}Iter(typ, dataA, dataReuse, ait, bit)
   356  			retVal = reuse
   357  		case safe && reuse != nil && leftTensor:
   358  			storage.CopyIter(typ,dataReuse,dataA, iit, ait)
   359  			ait.Reset()
   360  			iit.Reset()
   361  			err = e.E.{{.Name}}Iter(typ, dataReuse, dataB, iit, bit)
   362  			retVal = reuse
   363  		{{end -}}
   364  		default: // safe && bool
   365  			panic("Unreachable")
   366  		}
   367  		{{if not .VV -}}
   368  		if newAlloc{
   369  		freeScalar(scalarHeader.Raw)
   370  		}
   371  		returnHeader(scalarHeader)
   372  		{{end -}}
   373  		return
   374  	}
   375  
   376  	{{if not .VV -}}
   377  	// handle special case where A and B have both len 1
   378  	if len(dataA.Raw) == int(typ.Size()) && len(dataB.Raw) == int(typ.Size()) {
   379  		switch {
   380  		case safe && reuse != nil && leftTensor:
   381  			storage.Copy(typ,dataReuse,dataA)
   382  			err = e.E.{{.Name}}(typ, dataReuse, dataB)
   383  			retVal = reuse
   384  			return
   385  		case safe && reuse != nil && !leftTensor:
   386  			storage.Copy(typ,dataReuse,dataB)
   387  			err = e.E.{{.Name}}(typ, dataReuse, dataA)
   388  			retVal = reuse
   389  			return
   390  		}
   391  	}
   392  	{{end -}}
   393  
   394  	// standard
   395  	switch {
   396  		case !safe  && reuse == nil:
   397  			err = e.E.{{.Name}}(typ, dataA, dataB)
   398  			retVal = a
   399  		{{if .VV -}}
   400  		case  safe && reuse != nil:
   401  			storage.Copy(typ,dataReuse,dataA)
   402  			err = e.E.{{.Name}}(typ, dataReuse, dataB)
   403  			retVal = reuse
   404  		{{else -}}
   405  		case safe && reuse != nil && leftTensor:
   406  			storage.Copy(typ,dataReuse,dataA)
   407  			err = e.E.{{.Name}}(typ, dataReuse, dataB)
   408  			retVal = reuse
   409  		case safe && reuse != nil && !leftTensor:
   410  			storage.Copy(typ,dataReuse,dataB)
   411  			err = e.E.{{.Name}}(typ, dataA, dataReuse)
   412  			retVal = reuse
   413  		{{end -}}
   414  		default:
   415  			panic("Unreachable")
   416  	}
   417  	{{if not .VV -}}
   418  	if newAlloc{
   419  		freeScalar(scalarHeader.Raw)
   420  	}
   421  	returnHeader(scalarHeader)
   422  	{{end -}}
   423  	return
   424  `
   425  
// agg2UnaryBodyRaw is the body template for unary operations on a. The
// generated code branches on useIter (kernels suffixed Iter, driven by the
// ait/rit iterators) versus the plain kernels, and within each path tries the
// cases in this order:
//   - incr:    the kernel runs on a clone of a, then the clone is added into
//     dataReuse with AddIter/Add; reuse is returned. A kernel failure returns
//     (nil, wrapped error) before the add.
//   - toReuse: a's data is copied into dataReuse, the kernel runs on dataReuse,
//     and reuse is returned.
//   - !safe:   the kernel runs in place on dataA and a is returned.
//   - default: the kernel runs on a clone of a and the clone is returned.
const agg2UnaryBodyRaw = `
	if useIter{
		switch {
		case incr:
			cloned:= a.Clone().(Tensor)
			if err = e.E.{{.Name}}Iter(typ, cloned.hdr(), ait); err != nil {
				return nil, errors.Wrap(err, "Unable to perform {{.Name}}")
			}
			ait.Reset()
			err = e.E.AddIter(typ, dataReuse, cloned.hdr(), rit, ait)
			retVal = reuse
		case toReuse:
			storage.CopyIter(typ, dataReuse, dataA, rit, ait)
			rit.Reset()
			err = e.E.{{.Name}}Iter(typ, dataReuse, rit)
			retVal = reuse
		case !safe:
			err = e.E.{{.Name}}Iter(typ, dataA, ait)
			retVal = a
		default: // safe by default
			cloned := a.Clone().(Tensor)
			err = e.E.{{.Name}}Iter(typ, cloned.hdr(), ait)
			retVal = cloned
		}
		return
	}
		switch {
		case incr:
			cloned := a.Clone().(Tensor)
			if err = e.E.{{.Name}}(typ, cloned.hdr()); err != nil {
				return nil, errors.Wrap(err, "Unable to perform {{.Name}}")
			}
			err = e.E.Add(typ, dataReuse, cloned.hdr())
			retVal = reuse
		case toReuse:
			storage.Copy(typ, dataReuse, dataA)
			err = e.E.{{.Name}}(typ, dataReuse)
			retVal = reuse
		case !safe:
			err = e.E.{{.Name}}(typ, dataA)
			retVal = a
		default: // safe by default
			cloned := a.Clone().(Tensor)
			err = e.E.{{.Name}}(typ, cloned.hdr())
			retVal = cloned
		}
		return
`
   474  
   475  var (
   476  	prepVV         *template.Template
   477  	prepMixed      *template.Template
   478  	prepUnary      *template.Template
   479  	agg2Body       *template.Template
   480  	agg2CmpBody    *template.Template
   481  	agg2UnaryBody  *template.Template
   482  	agg2MinMaxBody *template.Template
   483  )
   484  
   485  func init() {
   486  	prepVV = template.Must(template.New("prepVV").Funcs(funcs).Parse(prepVVRaw))
   487  	prepMixed = template.Must(template.New("prepMixed").Funcs(funcs).Parse(prepMixedRaw))
   488  	prepUnary = template.Must(template.New("prepUnary").Funcs(funcs).Parse(prepUnaryRaw))
   489  	agg2Body = template.Must(template.New("agg2body").Funcs(funcs).Parse(agg2BodyRaw))
   490  	agg2CmpBody = template.Must(template.New("agg2CmpBody").Funcs(funcs).Parse(agg2CmpBodyRaw))
   491  	agg2UnaryBody = template.Must(template.New("agg2UnaryBody").Funcs(funcs).Parse(agg2UnaryBodyRaw))
   492  	agg2MinMaxBody = template.Must(template.New("agg2MinMaxBody").Funcs(funcs).Parse(agg2MinMaxBodyRaw))
   493  }