gorgonia.org/gorgonia@v0.9.17/cuda/utils.go (about)

     1  package cuda
     2  
     3  import (
     4  	"fmt"
     5  	"log"
     6  
     7  	"github.com/pkg/errors"
     8  	"gorgonia.org/tensor"
     9  )
    10  
    11  func getDenseTensor(t tensor.Tensor) (tensor.DenseTensor, error) {
    12  	switch tt := t.(type) {
    13  	case tensor.DenseTensor:
    14  		return tt, nil
    15  	case tensor.Densor:
    16  		return tt.Dense(), nil
    17  	default:
    18  		return nil, errors.Errorf("Tensor %T is not a DenseTensor", t)
    19  	}
    20  }
    21  
    22  func handleFuncOpts(expShape tensor.Shape, expType tensor.Dtype, o tensor.DataOrder, strict bool, opts ...tensor.FuncOpt) (reuse tensor.DenseTensor, safe, toReuse, incr, same bool, err error) {
    23  	fo := tensor.ParseFuncOpts(opts...)
    24  
    25  	reuseT, incr := fo.IncrReuse()
    26  	safe = fo.Safe()
    27  	same = fo.Same()
    28  	toReuse = reuseT != nil
    29  
    30  	if toReuse {
    31  		if reuse, err = getDenseTensor(reuseT); err != nil {
    32  			err = errors.Wrapf(err, "Expected a tensor.DenseTensor")
    33  			return
    34  		}
    35  
    36  		if (strict || same) && reuse.Dtype() != expType {
    37  			err = errors.Errorf(typeMismatch, expType, reuse.Dtype())
    38  			err = errors.Wrapf(err, "Cannot use reuse")
    39  			return
    40  		}
    41  
    42  		if reuse.DataSize() != expShape.TotalSize() && !expShape.IsScalar() {
    43  			log.Printf("REUSE CHECK reuse shape %v, expected Shape %v", reuse.Shape(), expShape)
    44  			err = errors.Errorf(shapeMismatch, reuse.Shape(), expShape)
    45  			err = errors.Wrapf(err, "Cannot use reuse: shape mismatch - reuse.len() %v, expShape.TotalSize() %v", reuse.DataSize(), expShape.TotalSize())
    46  			return
    47  		}
    48  
    49  		if !incr && reuse != nil {
    50  			// reuse.setDataOrder(o)
    51  			// err = reuse.reshape(expShape...)
    52  		}
    53  
    54  	}
    55  	return
    56  }
    57  
    58  func binaryCheck(a, b tensor.Tensor) (err error) {
    59  	at := a.Dtype()
    60  	bt := b.Dtype()
    61  
    62  	switch at {
    63  	case tensor.Float32, tensor.Float64:
    64  	default:
    65  		return errors.Errorf("Unsupported Dtype for a: %v", at)
    66  	}
    67  
    68  	switch bt {
    69  	case tensor.Float32, tensor.Float64:
    70  	default:
    71  		return errors.Errorf("Unsupported Dtype for b: %v", bt)
    72  	}
    73  
    74  	if at.Kind() != bt.Kind() {
    75  		return errors.Errorf(typeMismatch, at, bt)
    76  	}
    77  
    78  	if !a.Shape().Eq(b.Shape()) {
    79  		log.Printf("BINARY CHECK %v %v", a.Shape(), b.Shape())
    80  		return errors.Errorf(shapeMismatch, b.Shape(), a.Shape())
    81  	}
    82  
    83  	if a.RequiresIterator() {
    84  		return errors.New("unsupported operation: a requires an iterator")
    85  	}
    86  
    87  	if b.RequiresIterator() {
    88  		return errors.New("unsupported operation: b requires an iterator")
    89  	}
    90  	return nil
    91  }
    92  
    93  func unaryCheck(a tensor.Tensor) error {
    94  	at := a.Dtype()
    95  	switch at {
    96  	case tensor.Float32, tensor.Float64:
    97  	default:
    98  		return errors.Errorf("Unsupported Dtype for a: %v", at)
    99  	}
   100  
   101  	if a.RequiresIterator() {
   102  		return errors.New("unsupported operation: a requires an iterator")
   103  	}
   104  	return nil
   105  }
   106  
   107  func logicalSize(s tensor.Shape) int {
   108  	if s.IsScalar() {
   109  		return 1
   110  	}
   111  	return s.TotalSize()
   112  }
   113  
   114  func constructName2(a, b tensor.Tensor, fn string) (name string) {
   115  	dt := a.Dtype()
   116  	as := a.Shape()
   117  	bs := b.Shape()
   118  	switch {
   119  	case as.IsScalar() && bs.IsScalar():
   120  		name = fmt.Sprintf("%v.%s_ss_f%d", elemBinOpMod, fn, int(dt.Size()*8))
   121  	case as.IsScalar() && !bs.IsScalar():
   122  		name = fmt.Sprintf("%v.%s_sv_f%d", elemBinOpMod, fn, int(dt.Size()*8))
   123  	case !as.IsScalar() && bs.IsScalar():
   124  		name = fmt.Sprintf("%v.%s_vs_f%d", elemBinOpMod, fn, int(dt.Size()*8))
   125  	default:
   126  		name = fmt.Sprintf("%v.%s_vv_f%d", elemBinOpMod, fn, int(dt.Size()*8))
   127  	}
   128  	return
   129  }
   130  
   131  func constructName1(a tensor.Tensor, leftTensor bool, fn string) (name string) {
   132  	dt := a.Dtype()
   133  	if leftTensor {
   134  		name = fmt.Sprintf("%v.%s_vs_f%d", elemBinOpMod, fn, int(dt.Size()*8))
   135  	} else {
   136  		name = fmt.Sprintf("%v.%s_sv_f%d", elemBinOpMod, fn, int(dt.Size()*8))
   137  	}
   138  	return
   139  }