gorgonia.org/tensor@v0.9.24/defaultengine_prep.go (about)

     1  package tensor
     2  
     3  import (
     4  	"reflect"
     5  
     6  	"github.com/pkg/errors"
     7  	"gorgonia.org/tensor/internal/storage"
     8  	// "log"
     9  )
    10  
    11  func handleFuncOpts(expShape Shape, expType Dtype, o DataOrder, strict bool, opts ...FuncOpt) (reuse DenseTensor, safe, toReuse, incr, same bool, err error) {
    12  	fo := ParseFuncOpts(opts...)
    13  
    14  	reuseT, incr := fo.IncrReuse()
    15  	safe = fo.Safe()
    16  	same = fo.Same()
    17  	toReuse = reuseT != nil
    18  
    19  	if toReuse {
    20  		if reuse, err = getDenseTensor(reuseT); err != nil {
    21  			returnOpOpt(fo)
    22  			err = errors.Wrapf(err, "Cannot reuse a Tensor that isn't a DenseTensor. Got %T instead", reuseT)
    23  			return
    24  		}
    25  
    26  		if reuse != nil && !reuse.IsNativelyAccessible() {
    27  			returnOpOpt(fo)
    28  			err = errors.Errorf(inaccessibleData, reuse)
    29  			return
    30  		}
    31  
    32  		if (strict || same) && reuse.Dtype() != expType {
    33  			returnOpOpt(fo)
    34  			err = errors.Errorf(typeMismatch, expType, reuse.Dtype())
    35  			err = errors.Wrapf(err, "Cannot use reuse")
    36  			return
    37  		}
    38  
    39  		if reuse.len() != expShape.TotalSize() && !expShape.IsScalar() {
    40  			returnOpOpt(fo)
    41  			err = errors.Errorf(shapeMismatch, reuse.Shape(), expShape)
    42  			err = errors.Wrapf(err, "Cannot use reuse: shape mismatch - reuse.len() %v, expShape.TotalSize() %v", reuse.len(), expShape.TotalSize())
    43  			return
    44  		}
    45  		if !reuse.Shape().Eq(expShape) {
    46  			cloned := expShape.Clone()
    47  			if err = reuse.Reshape(cloned...); err != nil {
    48  				return
    49  
    50  			}
    51  			ReturnInts([]int(cloned))
    52  		}
    53  
    54  		if !incr && reuse != nil {
    55  			reuse.setDataOrder(o)
    56  			// err = reuse.reshape(expShape...)
    57  		}
    58  
    59  	}
    60  	returnOpOpt(fo)
    61  	return
    62  }
    63  
    64  func binaryCheck(a, b Tensor, tc *typeclass) (err error) {
    65  	// check if the tensors are accessible
    66  	if !a.IsNativelyAccessible() {
    67  		return errors.Errorf(inaccessibleData, a)
    68  	}
    69  
    70  	if !b.IsNativelyAccessible() {
    71  		return errors.Errorf(inaccessibleData, b)
    72  	}
    73  
    74  	at := a.Dtype()
    75  	bt := b.Dtype()
    76  	if tc != nil {
    77  		if err = typeclassCheck(at, tc); err != nil {
    78  			return errors.Wrapf(err, typeclassMismatch, "a")
    79  		}
    80  		if err = typeclassCheck(bt, tc); err != nil {
    81  			return errors.Wrapf(err, typeclassMismatch, "b")
    82  		}
    83  	}
    84  
    85  	if at.Kind() != bt.Kind() {
    86  		return errors.Errorf(typeMismatch, at, bt)
    87  	}
    88  	if !a.Shape().Eq(b.Shape()) {
    89  		return errors.Errorf(shapeMismatch, b.Shape(), a.Shape())
    90  	}
    91  	return nil
    92  }
    93  
    94  func unaryCheck(a Tensor, tc *typeclass) error {
    95  	if !a.IsNativelyAccessible() {
    96  		return errors.Errorf(inaccessibleData, a)
    97  	}
    98  	at := a.Dtype()
    99  	if tc != nil {
   100  		if err := typeclassCheck(at, tc); err != nil {
   101  			return errors.Wrapf(err, typeclassMismatch, "a")
   102  		}
   103  	}
   104  	return nil
   105  }
   106  
   107  // scalarDtypeCheck checks that a scalar value has the same dtype as the dtype of a given tensor.
   108  func scalarDtypeCheck(a Tensor, b interface{}) error {
   109  	var dt Dtype
   110  	switch bt := b.(type) {
   111  	case Dtyper:
   112  		dt = bt.Dtype()
   113  	default:
   114  		t := reflect.TypeOf(b)
   115  		dt = Dtype{t}
   116  	}
   117  
   118  	if a.Dtype() != dt {
   119  		return errors.Errorf("Expected scalar to have the same Dtype as the tensor (%v). Got %T instead ", a.Dtype(), b)
   120  	}
   121  	return nil
   122  }
   123  
   124  // prepDataVV prepares the data given the input and reuse tensors. It also retruns several indicators
   125  //
   126  // useIter indicates that the iterator methods should be used.
   127  // swap indicates that the operands are swapped.
   128  func prepDataVV(a, b Tensor, reuse Tensor) (dataA, dataB, dataReuse *storage.Header, ait, bit, iit Iterator, useIter, swap bool, err error) {
   129  	// get data
   130  	dataA = a.hdr()
   131  	dataB = b.hdr()
   132  	if reuse != nil {
   133  		dataReuse = reuse.hdr()
   134  	}
   135  
   136  	// iter
   137  	useIter = a.RequiresIterator() ||
   138  		b.RequiresIterator() ||
   139  		(reuse != nil && reuse.RequiresIterator()) ||
   140  		!a.DataOrder().HasSameOrder(b.DataOrder()) ||
   141  		(reuse != nil && (!a.DataOrder().HasSameOrder(reuse.DataOrder()) || !b.DataOrder().HasSameOrder(reuse.DataOrder())))
   142  	if useIter {
   143  		ait = a.Iterator()
   144  		bit = b.Iterator()
   145  		if reuse != nil {
   146  			iit = reuse.Iterator()
   147  		}
   148  	}
   149  
   150  	// swap
   151  	if _, ok := a.(*CS); ok {
   152  		if _, ok := b.(DenseTensor); ok {
   153  			swap = true
   154  			dataA, dataB = dataB, dataA
   155  			ait, bit = bit, ait
   156  		}
   157  	}
   158  
   159  	return
   160  }
   161  
   162  func prepDataVS(a Tensor, b interface{}, reuse Tensor) (dataA, dataB, dataReuse *storage.Header, ait, iit Iterator, useIter bool, newAlloc bool, err error) {
   163  	// get data
   164  	dataA = a.hdr()
   165  	dataB, newAlloc = scalarToHeader(b)
   166  	if reuse != nil {
   167  		dataReuse = reuse.hdr()
   168  	}
   169  
   170  	if a.IsScalar() {
   171  		return
   172  	}
   173  	useIter = a.RequiresIterator() ||
   174  		(reuse != nil && reuse.RequiresIterator()) ||
   175  		(reuse != nil && !reuse.DataOrder().HasSameOrder(a.DataOrder()))
   176  	if useIter {
   177  		ait = a.Iterator()
   178  		if reuse != nil {
   179  			iit = reuse.Iterator()
   180  		}
   181  	}
   182  	return
   183  }
   184  
   185  func prepDataSV(a interface{}, b Tensor, reuse Tensor) (dataA, dataB, dataReuse *storage.Header, bit, iit Iterator, useIter bool, newAlloc bool, err error) {
   186  	// get data
   187  	dataA, newAlloc = scalarToHeader(a)
   188  	dataB = b.hdr()
   189  	if reuse != nil {
   190  		dataReuse = reuse.hdr()
   191  	}
   192  
   193  	// get iterator
   194  	if b.IsScalar() {
   195  		return
   196  	}
   197  	useIter = b.RequiresIterator() ||
   198  		(reuse != nil && reuse.RequiresIterator()) ||
   199  		(reuse != nil && !reuse.DataOrder().HasSameOrder(b.DataOrder()))
   200  
   201  	if useIter {
   202  		bit = b.Iterator()
   203  		if reuse != nil {
   204  			iit = reuse.Iterator()
   205  		}
   206  	}
   207  	return
   208  }
   209  
   210  func prepDataUnary(a Tensor, reuse Tensor) (dataA, dataReuse *storage.Header, ait, rit Iterator, useIter bool, err error) {
   211  	// get data
   212  	dataA = a.hdr()
   213  	if reuse != nil {
   214  		dataReuse = reuse.hdr()
   215  	}
   216  
   217  	// get iterator
   218  	if a.RequiresIterator() || (reuse != nil && reuse.RequiresIterator()) {
   219  		ait = a.Iterator()
   220  		if reuse != nil {
   221  			rit = reuse.Iterator()
   222  		}
   223  		useIter = true
   224  	}
   225  	return
   226  }