github.com/wzzhu/tensor@v0.9.24/dense_norms.go (about)

     1  package tensor
     2  
     3  import (
     4  	"math"
     5  
     6  	"github.com/chewxy/math32"
     7  	"github.com/pkg/errors"
     8  )
     9  
    10  func (t *Dense) multiSVDNorm(rowAxis, colAxis int) (retVal *Dense, err error) {
    11  	if rowAxis > colAxis {
    12  		rowAxis--
    13  	}
    14  	dims := t.Dims()
    15  
    16  	if retVal, err = t.RollAxis(colAxis, dims, true); err != nil {
    17  		return
    18  	}
    19  
    20  	if retVal, err = retVal.RollAxis(rowAxis, dims, true); err != nil {
    21  		return
    22  	}
    23  
    24  	// manual, since SVD only works on matrices. In the future, this needs to be fixed when gonum's lapack works for float32
    25  	// TODO: SVDFuture
    26  	switch dims {
    27  	case 2:
    28  		retVal, _, _, err = retVal.SVD(false, false)
    29  	case 3:
    30  		toStack := make([]*Dense, retVal.Shape()[0])
    31  		for i := 0; i < retVal.Shape()[0]; i++ {
    32  			var sliced, ithS *Dense
    33  			if sliced, err = sliceDense(retVal, ss(i)); err != nil {
    34  				return
    35  			}
    36  
    37  			if ithS, _, _, err = sliced.SVD(false, false); err != nil {
    38  				return
    39  			}
    40  
    41  			toStack[i] = ithS
    42  		}
    43  
    44  		retVal, err = toStack[0].Stack(0, toStack[1:]...)
    45  		return
    46  	default:
    47  		err = errors.Errorf("multiSVDNorm for dimensions greater than 3")
    48  	}
    49  
    50  	return
    51  }
    52  
    53  // Norm returns the p-ordered norm of the *Dense, given the axes.
    54  //
    55  // This implementation is directly adapted from Numpy, which is licenced under a BSD-like licence, and can be found here: https://docs.scipy.org/doc/numpy-1.9.1/license.html
    56  func (t *Dense) Norm(ord NormOrder, axes ...int) (retVal *Dense, err error) {
    57  	var ret Tensor
    58  	var ok bool
    59  	var abs, norm0, normN interface{}
    60  	var oneOverOrd interface{}
    61  	switch t.t {
    62  	case Float64:
    63  		abs = math.Abs
    64  		norm0 = func(x float64) float64 {
    65  			if x != 0 {
    66  				return 1
    67  			}
    68  			return 0
    69  		}
    70  		normN = func(x float64) float64 {
    71  			return math.Pow(math.Abs(x), float64(ord))
    72  		}
    73  		oneOverOrd = float64(1) / float64(ord)
    74  	case Float32:
    75  		abs = math32.Abs
    76  		norm0 = func(x float32) float32 {
    77  			if x != 0 {
    78  				return 1
    79  			}
    80  			return 0
    81  		}
    82  		normN = func(x float32) float32 {
    83  			return math32.Pow(math32.Abs(x), float32(ord))
    84  		}
    85  		oneOverOrd = float32(1) / float32(ord)
    86  	default:
    87  		err = errors.Errorf("Norms only works on float types")
    88  		return
    89  	}
    90  
    91  	dims := t.Dims()
    92  
    93  	// simple case
    94  	if len(axes) == 0 {
    95  		if ord.IsUnordered() || (ord.IsFrobenius() && dims == 2) || (ord == Norm(2) && dims == 1) {
    96  			backup := t.AP
    97  			ap := makeAP(1)
    98  			defer ap.zero()
    99  
   100  			ap.unlock()
   101  			ap.SetShape(t.Size())
   102  			ap.lock()
   103  
   104  			t.AP = ap
   105  			if ret, err = Dot(t, t); err != nil { // returns a scalar
   106  				err = errors.Wrapf(err, opFail, "Norm-0")
   107  				return
   108  			}
   109  			if retVal, ok = ret.(*Dense); !ok {
   110  				return nil, errors.Errorf(opFail, "Norm-0")
   111  			}
   112  
   113  			switch t.t {
   114  			case Float64:
   115  				retVal.SetF64(0, math.Sqrt(retVal.GetF64(0)))
   116  			case Float32:
   117  				retVal.SetF32(0, math32.Sqrt(retVal.GetF32(0)))
   118  			}
   119  			t.AP = backup
   120  			return
   121  		}
   122  
   123  		axes = make([]int, dims)
   124  		for i := range axes {
   125  			axes[i] = i
   126  		}
   127  	}
   128  
   129  	switch len(axes) {
   130  	case 1:
   131  		cloned := t.Clone().(*Dense)
   132  		switch {
   133  		case ord.IsUnordered() || ord == Norm(2):
   134  			if ret, err = Square(cloned); err != nil {
   135  				return
   136  			}
   137  
   138  			if retVal, ok = ret.(*Dense); !ok {
   139  				return nil, errors.Errorf(opFail, "UnorderedNorm-1")
   140  			}
   141  
   142  			if retVal, err = retVal.Sum(axes...); err != nil {
   143  				return
   144  			}
   145  
   146  			if ret, err = Sqrt(retVal); err != nil {
   147  				return
   148  			}
   149  			return assertDense(ret)
   150  		case ord.IsInf(1):
   151  			if ret, err = cloned.Apply(abs); err != nil {
   152  				return
   153  			}
   154  			if retVal, ok = ret.(*Dense); !ok {
   155  				return nil, errors.Errorf(opFail, "InfNorm-1")
   156  			}
   157  			return retVal.Max(axes...)
   158  		case ord.IsInf(-1):
   159  			if ret, err = cloned.Apply(abs); err != nil {
   160  				return
   161  			}
   162  			if retVal, ok = ret.(*Dense); !ok {
   163  				return nil, errors.Errorf(opFail, "-InfNorm-1")
   164  			}
   165  			return retVal.Min(axes...)
   166  		case ord == Norm(0):
   167  			if ret, err = cloned.Apply(norm0); err != nil {
   168  				return
   169  			}
   170  			if retVal, ok = ret.(*Dense); !ok {
   171  				return nil, errors.Errorf(opFail, "Norm-0")
   172  			}
   173  			return retVal.Sum(axes...)
   174  		case ord == Norm(1):
   175  			if ret, err = cloned.Apply(abs); err != nil {
   176  				return
   177  			}
   178  			if retVal, ok = ret.(*Dense); !ok {
   179  				return nil, errors.Errorf(opFail, "Norm-1")
   180  			}
   181  			return retVal.Sum(axes...)
   182  		default:
   183  			if ret, err = cloned.Apply(normN); err != nil {
   184  				return
   185  			}
   186  			if retVal, ok = ret.(*Dense); !ok {
   187  				return nil, errors.Errorf(opFail, "Norm-N")
   188  			}
   189  
   190  			if retVal, err = retVal.Sum(axes...); err != nil {
   191  				return
   192  			}
   193  			return retVal.PowScalar(oneOverOrd, true)
   194  		}
   195  	case 2:
   196  		rowAxis := axes[0]
   197  		colAxis := axes[1]
   198  
   199  		// checks
   200  		if rowAxis < 0 {
   201  			return nil, errors.Errorf("Row Axis %d is < 0", rowAxis)
   202  		}
   203  		if colAxis < 0 {
   204  			return nil, errors.Errorf("Col Axis %d is < 0", colAxis)
   205  		}
   206  
   207  		if rowAxis == colAxis {
   208  			return nil, errors.Errorf("Duplicate axes found. Row Axis: %d, Col Axis %d", rowAxis, colAxis)
   209  		}
   210  
   211  		cloned := t.Clone().(*Dense)
   212  		switch {
   213  		case ord == Norm(2):
   214  			// svd norm
   215  			if retVal, err = t.multiSVDNorm(rowAxis, colAxis); err != nil {
   216  				return nil, errors.Wrapf(err, opFail, "MultiSVDNorm, case 2 with Ord == Norm(2)")
   217  			}
   218  			dims := retVal.Dims()
   219  			return retVal.Max(dims - 1)
   220  		case ord == Norm(-2):
   221  			// svd norm
   222  			if retVal, err = t.multiSVDNorm(rowAxis, colAxis); err != nil {
   223  				return nil, errors.Wrapf(err, opFail, "MultiSVDNorm, case 2 with Ord == Norm(-2)")
   224  			}
   225  			dims := retVal.Dims()
   226  			return retVal.Min(dims - 1)
   227  		case ord == Norm(1):
   228  			if colAxis > rowAxis {
   229  				colAxis--
   230  			}
   231  			if ret, err = cloned.Apply(abs); err != nil {
   232  				return nil, errors.Wrapf(err, opFail, "Apply abs in Norm. ord == Norm(1")
   233  			}
   234  			if retVal, err = assertDense(ret); err != nil {
   235  				return nil, errors.Wrapf(err, opFail, "Norm-1, axis=2")
   236  			}
   237  			if retVal, err = retVal.Sum(rowAxis); err != nil {
   238  				return
   239  			}
   240  			return retVal.Max(colAxis)
   241  		case ord == Norm(-1):
   242  			if colAxis > rowAxis {
   243  				colAxis--
   244  			}
   245  			if ret, err = cloned.Apply(abs); err != nil {
   246  				return
   247  			}
   248  			if retVal, err = assertDense(ret); err != nil {
   249  				return nil, errors.Wrapf(err, opFail, "Norm-(-1), axis=2")
   250  			}
   251  			if retVal, err = retVal.Sum(rowAxis); err != nil {
   252  				return
   253  			}
   254  			return retVal.Min(colAxis)
   255  		case ord == Norm(0):
   256  			return nil, errors.Errorf("Norm of order 0 undefined for matrices")
   257  		case ord.IsInf(1):
   258  			if rowAxis > colAxis {
   259  				rowAxis--
   260  			}
   261  			if ret, err = cloned.Apply(abs); err != nil {
   262  				return
   263  			}
   264  			if retVal, err = assertDense(ret); err != nil {
   265  				return nil, errors.Wrapf(err, opFail, "InfNorm, axis=2")
   266  			}
   267  			if retVal, err = retVal.Sum(colAxis); err != nil {
   268  				return nil, errors.Wrapf(err, "Sum in infNorm")
   269  			}
   270  			return retVal.Max(rowAxis)
   271  		case ord.IsInf(-1):
   272  			if rowAxis > colAxis {
   273  				rowAxis--
   274  			}
   275  			if ret, err = cloned.Apply(abs); err != nil {
   276  				return
   277  			}
   278  			if retVal, err = assertDense(ret); err != nil {
   279  				return nil, errors.Wrapf(err, opFail, "-InfNorm, axis=2")
   280  			}
   281  			if retVal, err = retVal.Sum(colAxis); err != nil {
   282  				return nil, errors.Wrapf(err, opFail, "Sum with InfNorm")
   283  			}
   284  			return retVal.Min(rowAxis)
   285  		case ord.IsUnordered() || ord.IsFrobenius():
   286  			if ret, err = cloned.Apply(abs); err != nil {
   287  				return
   288  			}
   289  			if retVal, ok = ret.(*Dense); !ok {
   290  				return nil, errors.Errorf(opFail, "Frobenius Norm, axis = 2")
   291  			}
   292  			if ret, err = Square(retVal); err != nil {
   293  				return
   294  			}
   295  			if retVal, err = assertDense(ret); err != nil {
   296  				return nil, errors.Wrapf(err, opFail, "Norm-0, axis=2")
   297  			}
   298  			if retVal, err = retVal.Sum(axes...); err != nil {
   299  				return
   300  			}
   301  			if ret, err = Sqrt(retVal); err != nil {
   302  				return
   303  			}
   304  			return assertDense(ret)
   305  		case ord.IsNuclear():
   306  			// svd norm
   307  			if retVal, err = t.multiSVDNorm(rowAxis, colAxis); err != nil {
   308  				return
   309  			}
   310  			return retVal.Sum(len(t.Shape()) - 1)
   311  		case ord == Norm(0):
   312  			err = errors.Errorf("Norm order 0 undefined for matrices")
   313  			return
   314  		default:
   315  			return nil, errors.Errorf("Not yet implemented: Norm for Axes %v, ord %v", axes, ord)
   316  		}
   317  	default:
   318  		err = errors.Errorf(dimMismatch, 2, len(axes))
   319  		return
   320  	}
   321  	panic("Unreachable")
   322  }