gorgonia.org/gorgonia@v0.9.17/operatorLinAlg.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"github.com/chewxy/hm"
     5  	"github.com/pkg/errors"
     6  	"gorgonia.org/tensor"
     7  )
     8  
     9  // ā and Ā are used to denote that it's a matrix/vector type.
    10  // if you want to type it, it's Latin Letter A with Macron (lowercase and capital)
    11  // Codepoints : U+101 for the small one, and U+100 for the capital one
    12  
    13  type āBinaryOperator byte
    14  
    15  const (
    16  	matMulOperator        āBinaryOperator = iota // emits S/DGEMM BLAS calls
    17  	matVecMulOperator                            // emits S/DGEMV BLAS calls
    18  	vecDotOperator                               // emits S/DDOT BLAS calls
    19  	outerProdOperator                            // emits S/DGER BLAS calls
    20  	batchedMatMulOperator                        // just S/GEMM BLAS calls in a loop
    21  
    22  	maxĀBinaryOperator // delimits all possible linalg operators. Add above this line
    23  )
    24  
    25  func (op āBinaryOperator) String() string {
    26  	if op >= maxĀBinaryOperator {
    27  		return "UNSUPPORTED LINEAR ALGEBRA OPERATOR"
    28  	}
    29  	return āBinOpStrs[op]
    30  }
    31  
    32  func (op āBinaryOperator) Type() hm.Type {
    33  	if op >= maxĀBinaryOperator {
    34  		panic("UNSUPPORTED LINEAR ALGEBRA OPERATOR")
    35  	}
    36  	return āBinOpTypes[op]()
    37  }
    38  
    39  func (op āBinaryOperator) DiffWRT(inputs int) []bool {
    40  	if inputs != 2 {
    41  		panic("binary linear algebra operator only supports two and only two inputs")
    42  	}
    43  
    44  	if op >= maxĀBinaryOperator {
    45  		panic("Unsupported unary operator is not differentiable")
    46  	}
    47  	return []bool{true, true}
    48  }
    49  
    50  // todo: write explanation.
    51  func matMulDiffExpr(transA, transB bool, x, y, z, gradZ *Node) (retVal Nodes, err error) {
    52  	var dzdx, dzdy *Node
    53  	op := linAlgBinOp{
    54  		āBinaryOperator: matMulOperator,
    55  	}
    56  
    57  	switch {
    58  	case transA && transB:
    59  		op.transA = transA
    60  		op.transB = transB
    61  		if dzdx, err = binOpNode(op, y, gradZ); err != nil {
    62  			return nil, errors.Wrapf(err, binOpNodeFail, op)
    63  		}
    64  		if dzdy, err = binOpNode(op, gradZ, x); err != nil {
    65  			return nil, errors.Wrapf(err, binOpNodeFail, op)
    66  		}
    67  	case !transA && transB:
    68  		if dzdx, err = binOpNode(op, gradZ, y); err != nil {
    69  			return nil, errors.Wrapf(err, binOpNodeFail, op)
    70  		}
    71  
    72  		op.transA = true
    73  		if dzdy, err = binOpNode(op, gradZ, x); err != nil {
    74  			return nil, errors.Wrapf(err, binOpNodeFail, op)
    75  		}
    76  	case transA && !transB:
    77  		op.transB = true
    78  		if dzdx, err = binOpNode(op, y, gradZ); err != nil {
    79  			return nil, errors.Wrapf(err, binOpNodeFail, op)
    80  		}
    81  
    82  		op.transB = false
    83  		if dzdy, err = binOpNode(op, x, gradZ); err != nil {
    84  			return nil, errors.Wrapf(err, binOpNodeFail, op)
    85  		}
    86  	case !transA && !transB:
    87  		// dzdy
    88  		op.transA = false
    89  		op.transB = true
    90  		if dzdx, err = binOpNode(op, gradZ, y); err != nil {
    91  			return nil, errors.Wrapf(err, binOpNodeFail, op)
    92  		}
    93  		// do dzdx
    94  		op.transA = true
    95  		op.transB = false
    96  		if dzdy, err = binOpNode(op, x, gradZ); err != nil {
    97  			return nil, errors.Wrapf(err, binOpNodeFail, op)
    98  		}
    99  	}
   100  	retVal = Nodes{dzdx, dzdy}
   101  	return
   102  }
   103  
   104  func matMulDiff(ctx ExecutionContext, transA, transB bool, x, y, z *Node) (err error) {
   105  	xdv, ydv, zdv := getDV3(x, y, z)
   106  
   107  	op := linAlgBinOp{
   108  		āBinaryOperator: matMulOperator,
   109  	}
   110  
   111  	switch {
   112  	case transA && transB:
   113  		op.transA = transA
   114  		op.transB = transB
   115  
   116  		// dzdx
   117  		err = op.IncrDo(xdv.d, ydv.Value, zdv.d)
   118  		if err = checkErrSetDeriv(err, xdv); err != nil {
   119  			return errors.Wrapf(err, autodiffFail, x)
   120  		}
   121  
   122  		// dzdy
   123  		err = op.IncrDo(ydv.d, zdv.d, xdv.Value)
   124  		if err = checkErrSetDeriv(err, ydv); err != nil {
   125  			return errors.Wrapf(err, autodiffFail, y)
   126  		}
   127  
   128  		return
   129  
   130  	case !transA && transB:
   131  		// dzdx
   132  		err = op.IncrDo(xdv.d, zdv.d, ydv.Value)
   133  		if err = checkErrSetDeriv(err, xdv); err != nil {
   134  			return errors.Wrapf(err, autodiffFail, x)
   135  		}
   136  
   137  		// dzdy
   138  		op.transA = true
   139  		err = op.IncrDo(ydv.d, zdv.d, xdv.Value)
   140  		if err = checkErrSetDeriv(err, ydv); err != nil {
   141  			return errors.Wrapf(err, autodiffFail, x)
   142  		}
   143  
   144  		return
   145  
   146  	case transA && !transB:
   147  		// dzdx
   148  		op.transB = true
   149  		err = op.IncrDo(xdv.d, ydv.Value, zdv.d)
   150  		if err = checkErrSetDeriv(err, xdv); err != nil {
   151  			return errors.Wrapf(err, autodiffFail, x)
   152  		}
   153  
   154  		// dzdy
   155  		op.transA = false
   156  		op.transB = false
   157  		err = op.IncrDo(ydv.d, xdv.Value, zdv.d)
   158  		if err = checkErrSetDeriv(err, ydv); err != nil {
   159  			return errors.Wrapf(err, autodiffFail, x)
   160  		}
   161  		return
   162  	case !transA && !transB:
   163  		op.transB = true
   164  		err = op.IncrDo(xdv.d, zdv.d, ydv.Value)
   165  		if err = checkErrSetDeriv(err, xdv); err != nil {
   166  			return errors.Wrapf(err, autodiffFail, x)
   167  		}
   168  
   169  		op.transA = true
   170  		op.transB = false
   171  		err = op.IncrDo(ydv.d, xdv.Value, zdv.d)
   172  		if err = checkErrSetDeriv(err, ydv); err != nil {
   173  			return errors.Wrapf(err, autodiffFail, x)
   174  		}
   175  		return
   176  	}
   177  
   178  	panic("unreachable")
   179  }
   180  
   181  func matVecMulDiffExpr(transA, transB bool, x, y, z, gradZ *Node) (retVal Nodes, err error) {
   182  	var dzdx, dzdy *Node
   183  	if transA {
   184  		dzdx, err = OuterProd(y, gradZ)
   185  	} else {
   186  		dzdx, err = OuterProd(gradZ, y)
   187  	}
   188  
   189  	if err != nil {
   190  		return nil, errors.Wrap(err, "Failed to carry outper product")
   191  	}
   192  
   193  	op := linAlgBinOp{
   194  		āBinaryOperator: matVecMulOperator,
   195  		transA:          !transA,
   196  	}
   197  
   198  	if dzdy, err = binOpNode(op, x, gradZ); err != nil {
   199  		return nil, errors.Wrapf(err, binOpNodeFail, op)
   200  	}
   201  	return Nodes{dzdx, dzdy}, nil
   202  }
   203  
   204  func matVecMulDiff(ctx ExecutionContext, transA, transB bool, x, y, z *Node) (err error) {
   205  	xdv, ydv, zdv := getDV3(x, y, z)
   206  
   207  	op := linAlgBinOp{
   208  		āBinaryOperator: outerProdOperator,
   209  	}
   210  
   211  	if transA {
   212  		err = op.IncrDo(xdv.d, ydv.Value, zdv.d)
   213  	} else {
   214  		err = op.IncrDo(xdv.d, zdv.d, ydv.Value)
   215  	}
   216  	if err = checkErrSetDeriv(err, xdv); err != nil {
   217  		return errors.Wrapf(err, autodiffFail, x)
   218  	}
   219  
   220  	op = linAlgBinOp{
   221  		āBinaryOperator: matVecMulOperator,
   222  		transA:          !transA,
   223  	}
   224  
   225  	err = op.IncrDo(ydv.d, xdv.Value, zdv.d)
   226  	if err = checkErrSetDeriv(err, ydv); err != nil {
   227  		return errors.Wrapf(err, autodiffFail, x)
   228  	}
   229  	return
   230  }
   231  
   232  func vecDotDiffExpr(transA, transB bool, x, y, z, gradZ *Node) (retVal Nodes, err error) {
   233  	var dzdx, dzdy *Node
   234  	if dzdx, err = HadamardProd(y, gradZ); err == nil {
   235  		if dzdy, err = HadamardProd(x, gradZ); err == nil {
   236  			retVal = Nodes{dzdx, dzdy}
   237  		} else {
   238  			return nil, errors.Wrap(err, "Failed to carry HadamardProd()")
   239  		}
   240  	} else {
   241  		return nil, errors.Wrap(err, "Failed to carry HadamardProd()")
   242  	}
   243  	return
   244  }
   245  
   246  func vecDotDiff(ctx ExecutionContext, transA, transB bool, x, y, z *Node) (err error) {
   247  	xdv, ydv, zdv := getDV3(x, y, z)
   248  
   249  	mul := newElemBinOp(mulOpType, x, z)
   250  	err = mul.IncrDo(xdv.d, ydv.Value, zdv.d)
   251  	if err = checkErrSetDeriv(err, xdv); err != nil {
   252  		return errors.Wrapf(err, autodiffFail, x)
   253  	}
   254  
   255  	err = mul.IncrDo(ydv.d, xdv.Value, zdv.d)
   256  	if err = checkErrSetDeriv(err, ydv); err != nil {
   257  		return errors.Wrapf(err, autodiffFail, x)
   258  	}
   259  	return
   260  }
   261  
   262  func outerProdDiffExpr(transA, transB bool, x, y, z, gradZ *Node) (retVal Nodes, err error) {
   263  	var dzdx, dzdy *Node
   264  	if dzdx, err = Mul(x, gradZ); err == nil {
   265  		if dzdy, err = Mul(y, gradZ); err == nil {
   266  			retVal = Nodes{dzdx, dzdy}
   267  		} else {
   268  			return nil, errors.Wrap(err, "Failed to carry Mul()")
   269  		}
   270  	} else {
   271  		return nil, errors.Wrap(err, "Failed to carry Mul()")
   272  	}
   273  	return
   274  }
   275  
   276  func outerProdDiff(ctx ExecutionContext, transA, transB bool, x, y, z *Node) (err error) {
   277  	xdv, ydv, zdv := getDV3(x, y, z)
   278  
   279  	mul := newElemBinOp(mulOpType, x, z)
   280  	err = mul.IncrDo(xdv.d, xdv.Value, zdv.d)
   281  	err = mul.IncrDo(xdv.d, ydv.Value, zdv.d)
   282  	if err = checkErrSetDeriv(err, xdv); err != nil {
   283  		return errors.Wrapf(err, autodiffFail, x)
   284  	}
   285  
   286  	err = mul.IncrDo(ydv.d, ydv.Value, zdv.d)
   287  	if err = checkErrSetDeriv(err, ydv); err != nil {
   288  		return errors.Wrapf(err, autodiffFail, x)
   289  	}
   290  	return
   291  }
   292  
   293  func batchedMatMulDiffExpr(transA, transB bool, x, y, z, gradZ *Node) (retVal Nodes, err error) {
   294  	var dzdx, dzdy *Node
   295  	op := linAlgBinOp{
   296  		āBinaryOperator: batchedMatMulOperator,
   297  	}
   298  
   299  	switch {
   300  	case transA && transB:
   301  		op.transA = transA
   302  		op.transB = transB
   303  		if dzdx, err = binOpNode(op, y, gradZ); err != nil {
   304  			return nil, errors.Wrapf(err, binOpNodeFail, op)
   305  		}
   306  		if dzdy, err = binOpNode(op, gradZ, x); err != nil {
   307  			return nil, errors.Wrapf(err, binOpNodeFail, op)
   308  		}
   309  	case !transA && transB:
   310  		if dzdx, err = binOpNode(op, gradZ, y); err != nil {
   311  			return nil, errors.Wrapf(err, binOpNodeFail, op)
   312  		}
   313  
   314  		op.transA = true
   315  		if dzdy, err = binOpNode(op, gradZ, x); err != nil {
   316  			return nil, errors.Wrapf(err, binOpNodeFail, op)
   317  		}
   318  	case transA && !transB:
   319  		op.transB = true
   320  		if dzdx, err = binOpNode(op, y, gradZ); err != nil {
   321  			return nil, errors.Wrapf(err, binOpNodeFail, op)
   322  		}
   323  
   324  		op.transB = false
   325  		if dzdy, err = binOpNode(op, x, gradZ); err != nil {
   326  			return nil, errors.Wrapf(err, binOpNodeFail, op)
   327  		}
   328  	case !transA && !transB:
   329  		// dzdy
   330  		op.transA = false
   331  		op.transB = true
   332  		if dzdx, err = binOpNode(op, gradZ, y); err != nil {
   333  			return nil, errors.Wrapf(err, binOpNodeFail, op)
   334  		}
   335  		// do dzdx
   336  		op.transA = true
   337  		op.transB = false
   338  		if dzdy, err = binOpNode(op, x, gradZ); err != nil {
   339  			return nil, errors.Wrapf(err, binOpNodeFail, op)
   340  		}
   341  	}
   342  	retVal = Nodes{dzdx, dzdy}
   343  	return
   344  }
   345  
   346  func batchedMatMulDiff(ctx ExecutionContext, transA, transB bool, x, y, z *Node) (err error) {
   347  	xdv, ydv, zdv := getDV3(x, y, z)
   348  
   349  	op := linAlgBinOp{
   350  		āBinaryOperator: batchedMatMulOperator,
   351  	}
   352  
   353  	switch {
   354  	case transA && transB:
   355  		op.transA = transA
   356  		op.transB = transB
   357  
   358  		// dzdx
   359  		err = op.IncrDo(xdv.d, ydv.Value, zdv.d)
   360  		if err = checkErrSetDeriv(err, xdv); err != nil {
   361  			return errors.Wrapf(err, autodiffFail, x)
   362  		}
   363  
   364  		// dzdy
   365  		err = op.IncrDo(ydv.d, zdv.d, xdv.Value)
   366  		if err = checkErrSetDeriv(err, ydv); err != nil {
   367  			return errors.Wrapf(err, autodiffFail, y)
   368  		}
   369  
   370  		return
   371  
   372  	case !transA && transB:
   373  		// dzdx
   374  		err = op.IncrDo(xdv.d, zdv.d, ydv.Value)
   375  		if err = checkErrSetDeriv(err, xdv); err != nil {
   376  			return errors.Wrapf(err, autodiffFail, x)
   377  		}
   378  
   379  		// dzdy
   380  		op.transA = true
   381  		err = op.IncrDo(ydv.d, zdv.d, xdv.Value)
   382  		if err = checkErrSetDeriv(err, ydv); err != nil {
   383  			return errors.Wrapf(err, autodiffFail, x)
   384  		}
   385  
   386  		return
   387  
   388  	case transA && !transB:
   389  		// dzdx
   390  		op.transB = true
   391  		err = op.IncrDo(xdv.d, ydv.Value, zdv.d)
   392  		if err = checkErrSetDeriv(err, xdv); err != nil {
   393  			return errors.Wrapf(err, autodiffFail, x)
   394  		}
   395  
   396  		// dzdy
   397  		op.transA = false
   398  		op.transB = false
   399  		err = op.IncrDo(ydv.d, xdv.Value, zdv.d)
   400  		if err = checkErrSetDeriv(err, ydv); err != nil {
   401  			return errors.Wrapf(err, autodiffFail, x)
   402  		}
   403  		return
   404  	case !transA && !transB:
   405  		op.transB = true
   406  		err = op.IncrDo(xdv.d, zdv.d, ydv.Value)
   407  		if err = checkErrSetDeriv(err, xdv); err != nil {
   408  			return errors.Wrapf(err, autodiffFail, x)
   409  		}
   410  
   411  		op.transA = true
   412  		op.transB = false
   413  		err = op.IncrDo(ydv.d, xdv.Value, zdv.d)
   414  		if err = checkErrSetDeriv(err, ydv); err != nil {
   415  			return errors.Wrapf(err, autodiffFail, x)
   416  		}
   417  		return
   418  	}
   419  
   420  	panic("unreachable")
   421  }
   422  
   423  func batchedMatMul(a, b, c tensor.Tensor, transA, transB, incr bool) (retVal tensor.Tensor, err error) {
   424  	shapeA := a.Shape().Clone()
   425  	shapeB := b.Shape().Clone()
   426  	outer := shapeA[:len(shapeA)-2]
   427  	innerA := shapeA[len(shapeA)-2:]
   428  	innerB := shapeB[len(shapeB)-2:]
   429  
   430  	if c == nil {
   431  		newShape := append(outer, innerA[0], innerB[1])
   432  		c = tensor.New(tensor.Of(a.Dtype()), tensor.WithShape(newShape...), tensor.WithEngine(a.Engine()))
   433  	}
   434  
   435  	slices := make([]sli, len(outer))
   436  	ss := make([]tensor.Slice, len(slices))
   437  	for i := range slices {
   438  		slices[i].end = slices[i].start + 1
   439  		ss[i] = &slices[i]
   440  	}
   441  
   442  	var as, bs, cs tensor.Tensor
   443  	for halt := false; !halt; halt = incrSlices(slices, outer) {
   444  		if as, err = a.Slice(ss...); err != nil {
   445  			return nil, errors.Wrapf(err, "Slicing %v from a failed", ss)
   446  		}
   447  		if bs, err = b.Slice(ss...); err != nil {
   448  			return nil, errors.Wrapf(err, "Slicing %v from b failed", ss)
   449  		}
   450  		if cs, err = c.Slice(ss...); err != nil {
   451  			return nil, errors.Wrapf(err, "Slicing %v from c failed", ss)
   452  		}
   453  
   454  		if transA {
   455  			as.T()
   456  		}
   457  		if transB {
   458  			bs.T()
   459  		}
   460  
   461  		var fo tensor.FuncOpt
   462  		if incr {
   463  			fo = tensor.WithIncr(cs)
   464  		} else {
   465  			fo = tensor.WithReuse(cs)
   466  		}
   467  
   468  		if _, err = tensor.MatMul(as, bs, fo); err != nil {
   469  			return nil, errors.Wrapf(err, "MatMul on batch %v failed.", ss)
   470  		}
   471  
   472  	}
   473  
   474  	return c, nil
   475  }
   476  
   477  // incrSlices increments the slices. If everything has matched then return true
   478  func incrSlices(a []sli, shp tensor.Shape) (halt bool) {
   479  	for i := len(a) - 1; i >= 0; i-- {
   480  		if shp[i]-a[i].start == 1 {
   481  			a[i].start = 0
   482  			a[i].end = 1
   483  			if i == 0 {
   484  				return true
   485  			}
   486  			continue
   487  		}
   488  
   489  		a[i].start++
   490  		a[i].end = a[i].start + 1
   491  		return false
   492  	}
   493  	return true
   494  }