gorgonia.org/tensor@v0.9.24/internal/execution/eng_arith_manual.go (about)

     1  package execution
     2  
     3  import (
     4  	"reflect"
     5  
     6  	"github.com/pkg/errors"
     7  	"gorgonia.org/tensor/internal/storage"
     8  )
     9  
    10  func (e E) AddSliced(t reflect.Type, dataA *storage.Header, dstStart, dstEnd int, dataB *storage.Header, srcStart, srcEnd int) (err error) {
    11  	ds := dstStart * int(t.Size())
    12  	de := dstEnd * int(t.Size())
    13  	a := &storage.Header{
    14  		Raw: dataA.Raw[ds:de],
    15  	}
    16  
    17  	ss := srcStart * int(t.Size())
    18  	se := srcEnd * int(t.Size())
    19  	b := &storage.Header{
    20  		Raw: dataB.Raw[ss:se],
    21  	}
    22  
    23  	as := isScalar(a, t)
    24  	bs := isScalar(b, t)
    25  
    26  	switch t {
    27  	case Int:
    28  		at := a.Ints()
    29  		bt := b.Ints()
    30  
    31  		switch {
    32  		case as && bs:
    33  			VecAddI(at, bt)
    34  		case as && !bs:
    35  			AddSVI(at[0], bt)
    36  		case !as && bs:
    37  			AddVSI(at, bt[0])
    38  		default:
    39  			VecAddI(at, bt)
    40  		}
    41  		return
    42  	case Int8:
    43  		at := a.Int8s()
    44  		bt := b.Int8s()
    45  		switch {
    46  		case as && bs:
    47  			VecAddI8(at, bt)
    48  		case as && !bs:
    49  			AddSVI8(at[0], bt)
    50  		case !as && bs:
    51  			AddVSI8(at, bt[0])
    52  		default:
    53  			VecAddI8(at, bt)
    54  		}
    55  		return
    56  	case Int16:
    57  		at := a.Int16s()
    58  		bt := b.Int16s()
    59  		switch {
    60  		case as && bs:
    61  			VecAddI16(at, bt)
    62  		case as && !bs:
    63  			AddSVI16(at[0], bt)
    64  		case !as && bs:
    65  			AddVSI16(at, bt[0])
    66  		default:
    67  			VecAddI16(at, bt)
    68  		}
    69  		return
    70  	case Int32:
    71  		at := a.Int32s()
    72  		bt := b.Int32s()
    73  		switch {
    74  		case as && bs:
    75  			VecAddI32(at, bt)
    76  		case as && !bs:
    77  			AddSVI32(at[0], bt)
    78  		case !as && bs:
    79  			AddVSI32(at, bt[0])
    80  		default:
    81  			VecAddI32(at, bt)
    82  		}
    83  		return
    84  	case Int64:
    85  		at := a.Int64s()
    86  		bt := b.Int64s()
    87  		switch {
    88  		case as && bs:
    89  			VecAddI64(at, bt)
    90  		case as && !bs:
    91  			AddSVI64(at[0], bt)
    92  		case !as && bs:
    93  			AddVSI64(at, bt[0])
    94  		default:
    95  			VecAddI64(at, bt)
    96  		}
    97  		return
    98  	case Uint:
    99  		at := a.Uints()
   100  		bt := b.Uints()
   101  		switch {
   102  		case as && bs:
   103  			VecAddU(at, bt)
   104  		case as && !bs:
   105  			AddSVU(at[0], bt)
   106  		case !as && bs:
   107  			AddVSU(at, bt[0])
   108  		default:
   109  			VecAddU(at, bt)
   110  		}
   111  		return
   112  	case Uint8:
   113  		at := a.Uint8s()
   114  		bt := b.Uint8s()
   115  		switch {
   116  		case as && bs:
   117  			VecAddU8(at, bt)
   118  		case as && !bs:
   119  			AddSVU8(at[0], bt)
   120  		case !as && bs:
   121  			AddVSU8(at, bt[0])
   122  		default:
   123  			VecAddU8(at, bt)
   124  		}
   125  		return
   126  	case Uint16:
   127  		at := a.Uint16s()
   128  		bt := b.Uint16s()
   129  		switch {
   130  		case as && bs:
   131  			VecAddU16(at, bt)
   132  		case as && !bs:
   133  			AddSVU16(at[0], bt)
   134  		case !as && bs:
   135  			AddVSU16(at, bt[0])
   136  		default:
   137  			VecAddU16(at, bt)
   138  		}
   139  		return
   140  	case Uint32:
   141  		at := a.Uint32s()
   142  		bt := b.Uint32s()
   143  		switch {
   144  		case as && bs:
   145  			VecAddU32(at, bt)
   146  		case as && !bs:
   147  			AddSVU32(at[0], bt)
   148  		case !as && bs:
   149  			AddVSU32(at, bt[0])
   150  		default:
   151  			VecAddU32(at, bt)
   152  		}
   153  		return
   154  	case Uint64:
   155  		at := a.Uint64s()
   156  		bt := b.Uint64s()
   157  		switch {
   158  		case as && bs:
   159  			VecAddU64(at, bt)
   160  		case as && !bs:
   161  			AddSVU64(at[0], bt)
   162  		case !as && bs:
   163  			AddVSU64(at, bt[0])
   164  		default:
   165  			VecAddU64(at, bt)
   166  		}
   167  		return
   168  	case Float32:
   169  		at := a.Float32s()
   170  		bt := b.Float32s()
   171  		switch {
   172  		case as && bs:
   173  			VecAddF32(at, bt)
   174  		case as && !bs:
   175  			AddSVF32(at[0], bt)
   176  		case !as && bs:
   177  			AddVSF32(at, bt[0])
   178  		default:
   179  			VecAddF32(at, bt)
   180  		}
   181  		return
   182  	case Float64:
   183  		at := a.Float64s()
   184  		bt := b.Float64s()
   185  		switch {
   186  		case as && bs:
   187  			VecAddF64(at, bt)
   188  		case as && !bs:
   189  			AddSVF64(at[0], bt)
   190  		case !as && bs:
   191  			AddVSF64(at, bt[0])
   192  		default:
   193  			VecAddF64(at, bt)
   194  		}
   195  		return
   196  	case Complex64:
   197  		at := a.Complex64s()
   198  		bt := b.Complex64s()
   199  		switch {
   200  		case as && bs:
   201  			VecAddC64(at, bt)
   202  		case as && !bs:
   203  			AddSVC64(at[0], bt)
   204  		case !as && bs:
   205  			AddVSC64(at, bt[0])
   206  		default:
   207  			VecAddC64(at, bt)
   208  		}
   209  		return
   210  	case Complex128:
   211  		at := a.Complex128s()
   212  		bt := b.Complex128s()
   213  		switch {
   214  		case as && bs:
   215  			VecAddC128(at, bt)
   216  		case as && !bs:
   217  			AddSVC128(at[0], bt)
   218  		case !as && bs:
   219  			AddVSC128(at, bt[0])
   220  		default:
   221  			VecAddC128(at, bt)
   222  		}
   223  		return
   224  	case String:
   225  		at := a.Strings()
   226  		bt := b.Strings()
   227  		switch {
   228  		case as && bs:
   229  			VecAddStr(at, bt)
   230  		case as && !bs:
   231  			AddSVStr(at[0], bt)
   232  		case !as && bs:
   233  			AddVSStr(at, bt[0])
   234  		default:
   235  			VecAddStr(at, bt)
   236  		}
   237  		return
   238  	default:
   239  		return errors.Errorf("Unsupported type %v for Add", t)
   240  	}
   241  }