github.com/wzzhu/tensor@v0.9.24/dense_norms.go (about) 1 package tensor 2 3 import ( 4 "math" 5 6 "github.com/chewxy/math32" 7 "github.com/pkg/errors" 8 ) 9 10 func (t *Dense) multiSVDNorm(rowAxis, colAxis int) (retVal *Dense, err error) { 11 if rowAxis > colAxis { 12 rowAxis-- 13 } 14 dims := t.Dims() 15 16 if retVal, err = t.RollAxis(colAxis, dims, true); err != nil { 17 return 18 } 19 20 if retVal, err = retVal.RollAxis(rowAxis, dims, true); err != nil { 21 return 22 } 23 24 // manual, since SVD only works on matrices. In the future, this needs to be fixed when gonum's lapack works for float32 25 // TODO: SVDFuture 26 switch dims { 27 case 2: 28 retVal, _, _, err = retVal.SVD(false, false) 29 case 3: 30 toStack := make([]*Dense, retVal.Shape()[0]) 31 for i := 0; i < retVal.Shape()[0]; i++ { 32 var sliced, ithS *Dense 33 if sliced, err = sliceDense(retVal, ss(i)); err != nil { 34 return 35 } 36 37 if ithS, _, _, err = sliced.SVD(false, false); err != nil { 38 return 39 } 40 41 toStack[i] = ithS 42 } 43 44 retVal, err = toStack[0].Stack(0, toStack[1:]...) 45 return 46 default: 47 err = errors.Errorf("multiSVDNorm for dimensions greater than 3") 48 } 49 50 return 51 } 52 53 // Norm returns the p-ordered norm of the *Dense, given the axes. 54 // 55 // This implementation is directly adapted from Numpy, which is licenced under a BSD-like licence, and can be found here: https://docs.scipy.org/doc/numpy-1.9.1/license.html 56 func (t *Dense) Norm(ord NormOrder, axes ...int) (retVal *Dense, err error) { 57 var ret Tensor 58 var ok bool 59 var abs, norm0, normN interface{} 60 var oneOverOrd interface{} 61 switch t.t { 62 case Float64: 63 abs = math.Abs 64 norm0 = func(x float64) float64 { 65 if x != 0 { 66 return 1 67 } 68 return 0 69 } 70 normN = func(x float64) float64 { 71 return math.Pow(math.Abs(x), float64(ord)) 72 } 73 oneOverOrd = float64(1) / float64(ord) 74 case Float32: 75 abs = math32.Abs 76 norm0 = func(x float32) float32 { 77 if x != 0 { 78 return 1 79 } 80 return 0 81 } 82 normN = func(x float32) float32 { 83 return math32.Pow(math32.Abs(x), float32(ord)) 84 } 85 oneOverOrd = float32(1) / float32(ord) 86 default: 87 err = errors.Errorf("Norms only works on float types") 88 return 89 } 90 91 dims := t.Dims() 92 93 // simple case 94 if len(axes) == 0 { 95 if ord.IsUnordered() || (ord.IsFrobenius() && dims == 2) || (ord == Norm(2) && dims == 1) { 96 backup := t.AP 97 ap := makeAP(1) 98 defer ap.zero() 99 100 ap.unlock() 101 ap.SetShape(t.Size()) 102 ap.lock() 103 104 t.AP = ap 105 if ret, err = Dot(t, t); err != nil { // returns a scalar 106 err = errors.Wrapf(err, opFail, "Norm-0") 107 return 108 } 109 if retVal, ok = ret.(*Dense); !ok { 110 return nil, errors.Errorf(opFail, "Norm-0") 111 } 112 113 switch t.t { 114 case Float64: 115 retVal.SetF64(0, math.Sqrt(retVal.GetF64(0))) 116 case Float32: 117 retVal.SetF32(0, math32.Sqrt(retVal.GetF32(0))) 118 } 119 t.AP = backup 120 return 121 } 122 123 axes = make([]int, dims) 124 for i := range axes { 125 axes[i] = i 126 } 127 } 128 129 switch len(axes) { 130 case 1: 131 cloned := t.Clone().(*Dense) 132 switch { 133 case ord.IsUnordered() || ord == Norm(2): 134 if ret, err = Square(cloned); err != nil { 135 return 136 } 137 138 if retVal, ok = ret.(*Dense); !ok { 139 return nil, errors.Errorf(opFail, "UnorderedNorm-1") 140 } 141 142 if retVal, err = retVal.Sum(axes...); err != nil { 143 return 144 } 145 146 if ret, err = Sqrt(retVal); err != nil { 147 return 148 } 149 return assertDense(ret) 150 case ord.IsInf(1): 151 if ret, err = cloned.Apply(abs); err != nil { 152 return 153 } 154 if retVal, ok = ret.(*Dense); !ok { 155 return nil, errors.Errorf(opFail, "InfNorm-1") 156 } 157 return retVal.Max(axes...) 158 case ord.IsInf(-1): 159 if ret, err = cloned.Apply(abs); err != nil { 160 return 161 } 162 if retVal, ok = ret.(*Dense); !ok { 163 return nil, errors.Errorf(opFail, "-InfNorm-1") 164 } 165 return retVal.Min(axes...) 166 case ord == Norm(0): 167 if ret, err = cloned.Apply(norm0); err != nil { 168 return 169 } 170 if retVal, ok = ret.(*Dense); !ok { 171 return nil, errors.Errorf(opFail, "Norm-0") 172 } 173 return retVal.Sum(axes...) 174 case ord == Norm(1): 175 if ret, err = cloned.Apply(abs); err != nil { 176 return 177 } 178 if retVal, ok = ret.(*Dense); !ok { 179 return nil, errors.Errorf(opFail, "Norm-1") 180 } 181 return retVal.Sum(axes...) 182 default: 183 if ret, err = cloned.Apply(normN); err != nil { 184 return 185 } 186 if retVal, ok = ret.(*Dense); !ok { 187 return nil, errors.Errorf(opFail, "Norm-N") 188 } 189 190 if retVal, err = retVal.Sum(axes...); err != nil { 191 return 192 } 193 return retVal.PowScalar(oneOverOrd, true) 194 } 195 case 2: 196 rowAxis := axes[0] 197 colAxis := axes[1] 198 199 // checks 200 if rowAxis < 0 { 201 return nil, errors.Errorf("Row Axis %d is < 0", rowAxis) 202 } 203 if colAxis < 0 { 204 return nil, errors.Errorf("Col Axis %d is < 0", colAxis) 205 } 206 207 if rowAxis == colAxis { 208 return nil, errors.Errorf("Duplicate axes found. Row Axis: %d, Col Axis %d", rowAxis, colAxis) 209 } 210 211 cloned := t.Clone().(*Dense) 212 switch { 213 case ord == Norm(2): 214 // svd norm 215 if retVal, err = t.multiSVDNorm(rowAxis, colAxis); err != nil { 216 return nil, errors.Wrapf(err, opFail, "MultiSVDNorm, case 2 with Ord == Norm(2)") 217 } 218 dims := retVal.Dims() 219 return retVal.Max(dims - 1) 220 case ord == Norm(-2): 221 // svd norm 222 if retVal, err = t.multiSVDNorm(rowAxis, colAxis); err != nil { 223 return nil, errors.Wrapf(err, opFail, "MultiSVDNorm, case 2 with Ord == Norm(-2)") 224 } 225 dims := retVal.Dims() 226 return retVal.Min(dims - 1) 227 case ord == Norm(1): 228 if colAxis > rowAxis { 229 colAxis-- 230 } 231 if ret, err = cloned.Apply(abs); err != nil { 232 return nil, errors.Wrapf(err, opFail, "Apply abs in Norm. ord == Norm(1") 233 } 234 if retVal, err = assertDense(ret); err != nil { 235 return nil, errors.Wrapf(err, opFail, "Norm-1, axis=2") 236 } 237 if retVal, err = retVal.Sum(rowAxis); err != nil { 238 return 239 } 240 return retVal.Max(colAxis) 241 case ord == Norm(-1): 242 if colAxis > rowAxis { 243 colAxis-- 244 } 245 if ret, err = cloned.Apply(abs); err != nil { 246 return 247 } 248 if retVal, err = assertDense(ret); err != nil { 249 return nil, errors.Wrapf(err, opFail, "Norm-(-1), axis=2") 250 } 251 if retVal, err = retVal.Sum(rowAxis); err != nil { 252 return 253 } 254 return retVal.Min(colAxis) 255 case ord == Norm(0): 256 return nil, errors.Errorf("Norm of order 0 undefined for matrices") 257 case ord.IsInf(1): 258 if rowAxis > colAxis { 259 rowAxis-- 260 } 261 if ret, err = cloned.Apply(abs); err != nil { 262 return 263 } 264 if retVal, err = assertDense(ret); err != nil { 265 return nil, errors.Wrapf(err, opFail, "InfNorm, axis=2") 266 } 267 if retVal, err = retVal.Sum(colAxis); err != nil { 268 return nil, errors.Wrapf(err, "Sum in infNorm") 269 } 270 return retVal.Max(rowAxis) 271 case ord.IsInf(-1): 272 if rowAxis > colAxis { 273 rowAxis-- 274 } 275 if ret, err = cloned.Apply(abs); err != nil { 276 return 277 } 278 if retVal, err = assertDense(ret); err != nil { 279 return nil, errors.Wrapf(err, opFail, "-InfNorm, axis=2") 280 } 281 if retVal, err = retVal.Sum(colAxis); err != nil { 282 return nil, errors.Wrapf(err, opFail, "Sum with InfNorm") 283 } 284 return retVal.Min(rowAxis) 285 case ord.IsUnordered() || ord.IsFrobenius(): 286 if ret, err = cloned.Apply(abs); err != nil { 287 return 288 } 289 if retVal, ok = ret.(*Dense); !ok { 290 return nil, errors.Errorf(opFail, "Frobenius Norm, axis = 2") 291 } 292 if ret, err = Square(retVal); err != nil { 293 return 294 } 295 if retVal, err = assertDense(ret); err != nil { 296 return nil, errors.Wrapf(err, opFail, "Norm-0, axis=2") 297 } 298 if retVal, err = retVal.Sum(axes...); err != nil { 299 return 300 } 301 if ret, err = Sqrt(retVal); err != nil { 302 return 303 } 304 return assertDense(ret) 305 case ord.IsNuclear(): 306 // svd norm 307 if retVal, err = t.multiSVDNorm(rowAxis, colAxis); err != nil { 308 return 309 } 310 return retVal.Sum(len(t.Shape()) - 1) 311 case ord == Norm(0): 312 err = errors.Errorf("Norm order 0 undefined for matrices") 313 return 314 default: 315 return nil, errors.Errorf("Not yet implemented: Norm for Axes %v, ord %v", axes, ord) 316 } 317 default: 318 err = errors.Errorf(dimMismatch, 2, len(axes)) 319 return 320 } 321 panic("Unreachable") 322 }