gorgonia.org/gorgonia@v0.9.17/op_sparsemax.go

package gorgonia

import (
	"fmt"
	"hash"
	"math"
	"sort"

	"github.com/chewxy/hm"
	"github.com/pkg/errors"
	"gorgonia.org/tensor"
)

type sparsemaxOp struct {
	axis int
}

func newSparsemaxOp(axes ...int) *sparsemaxOp {
	axis := -1
	if len(axes) > 0 {
		axis = axes[0]
	}

	sparsemaxop := &sparsemaxOp{
		axis: axis,
	}

	return sparsemaxop
}

// Sparsemax implements the sparsemax operation described here: http://proceedings.mlr.press/v48/martins16.pdf
func Sparsemax(x *Node, axes ...int) (*Node, error) {
	op := newSparsemaxOp(axes...)

	return ApplyOp(op, x)
}
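// A minimal usage sketch (illustrative, not part of the original file):
// Sparsemax plugs into a graph like any other unary op. For the input
// z = [1.0, 0.5, -1.0] the output is approximately [0.75, 0.25, 0], which
// sums to 1 with the smallest entry clamped to zero.
//
//	g := NewGraph()
//	xT := tensor.New(tensor.WithShape(3), tensor.WithBacking([]float64{1.0, 0.5, -1.0}))
//	x := NewVector(g, tensor.Float64, WithShape(3), WithValue(xT), WithName("x"))
//	sm, err := Sparsemax(x)
//	if err != nil {
//		log.Fatal(err)
//	}
//	machine := NewTapeMachine(g)
//	defer machine.Close()
//	if err := machine.RunAll(); err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println(sm.Value()) // ~ [0.75, 0.25, 0]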
func (op *sparsemaxOp) Arity() int {
	return 1
}

func (op *sparsemaxOp) ReturnsPtr() bool { return false }

func (op *sparsemaxOp) CallsExtern() bool { return false }

func (op *sparsemaxOp) WriteHash(h hash.Hash) {
	fmt.Fprintf(h, "Sparsemax{}()")
}

func (op *sparsemaxOp) Hashcode() uint32 { return simpleHash(op) }

func (op *sparsemaxOp) String() string {
	return "Sparsemax{}()"
}

func (op *sparsemaxOp) InferShape(inputs ...DimSizer) (tensor.Shape, error) {
	s := inputs[0].(tensor.Shape).Clone()
	return s, nil
}

func (op *sparsemaxOp) Type() hm.Type {
	a := hm.TypeVariable('a')
	return hm.NewFnType(a, a)
}

func (op *sparsemaxOp) OverwritesInput() int { return -1 }

func (op *sparsemaxOp) checkInput(inputs ...Value) (tensor.Tensor, error) {
	if err := checkArity(op, len(inputs)); err != nil {
		return nil, err
	}

	var in tensor.Tensor
	var ok bool

	if in, ok = inputs[0].(tensor.Tensor); !ok {
		return nil, errors.Errorf("Expected input to be a tensor, got %T", inputs[0])
	}

	return in, nil
}

func (op *sparsemaxOp) Do(inputs ...Value) (Value, error) {
	inputTensor, err := op.checkInput(inputs...)
	if err != nil {
		return nil, fmt.Errorf("Can't check Sparsemax input: %w", err)
	}

	var pattern []int

	if op.axis != -1 {
		// the kernels below reduce over the innermost dimension, so swap the
		// requested axis with the last one before computing
		pattern = make([]int, inputTensor.Dims())
		for i := range pattern {
			pattern[i] = i
		}
		pattern[op.axis], pattern[len(pattern)-1] = pattern[len(pattern)-1], pattern[op.axis]

		inputTensor, err = tensor.Transpose(inputTensor, pattern...)
		if err != nil {
			return nil, fmt.Errorf("error transposing the input tensor: %w", err)
		}
	}

	var output interface{}

	switch inputTensor.Dtype() {
	case tensor.Float64:
		output = op.float64sparseMax(inputTensor)
	case tensor.Float32:
		output = op.float32sparseMax(inputTensor)
	default:
		return nil, fmt.Errorf("invalid input type for Sparsemax, expected float64 or float32, got: %v", inputTensor.Dtype())
	}

	ret := tensor.New(tensor.Of(inputTensor.Dtype()), tensor.WithShape(inputTensor.Shape().Clone()...), tensor.WithEngine(inputTensor.Engine()), tensor.WithBacking(output))

	if op.axis != -1 {
		// a swap is its own inverse, so the same pattern restores the
		// original axis order
		out, err := tensor.Transpose(ret, pattern...)
		if err != nil {
			return nil, fmt.Errorf("error transposing the output tensor: %w", err)
		}

		return out, nil
	}

	return ret, nil
}
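// Both scalar kernels below follow the closed form from the paper: with a
// row sorted in decreasing order as z(1) >= z(2) >= ..., find the largest k
// such that 1 + k*z(k) > z(1) + ... + z(k), set
//
//	tau = (z(1) + ... + z(k) - 1) / k
//
// and return max(z_i - tau, 0) for every element of the row. For example,
// for the row [1.0, 0.5, -1.0]: k = 2, tau = (1.5 - 1)/2 = 0.25, and the
// output is [0.75, 0.25, 0].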
// FIXME: go2 generics
func (op *sparsemaxOp) float32sparseMax(inputTensor tensor.Tensor) interface{} {
	inputData := inputTensor.Data().([]float32)
	dims := inputTensor.Dims()
	it := 0

	to := inputTensor.Shape()[dims-1]
	from := tensor.Shape(inputTensor.Shape()[0 : dims-1]).TotalSize()
	if from == 0 {
		from = 1
	}

	maxValues := make([]float32, from)

	for i := 0; i < from; i++ {
		maxValue := float32(-math.MaxFloat32)

		for j := 0; j < to; j++ {
			if inputData[it] > maxValue {
				maxValue = inputData[it]
			}

			it++
		}

		maxValues[i] = maxValue
	}

	// subtracting the row maximum is the usual trick for numerical
	// stability; sparsemax is shift-invariant, so the result is unchanged
	stableInput := make([]float32, len(inputData))
	it = 0

	for i := 0; i < from; i++ {
		for j := 0; j < to; j++ {
			stableInput[it] = inputData[it] - maxValues[i]
			it++
		}
	}

	// sort each row in decreasing order; the threshold search below walks
	// the rows independently
	sortedData := make([]float32, len(inputData))
	copy(sortedData, stableInput)

	for i := 0; i < from; i++ {
		row := sortedData[i*to : (i+1)*to]
		sort.Slice(row, func(a, b int) bool {
			return row[a] > row[b]
		})
	}

	thresholds := make([]float32, from)
	it = 0

	for i := 0; i < from; i++ {
		cumSum := float32(0.0)
		prevCum := float32(0.0)
		maxIndex := 0

		for j := 0; j < to; j++ {
			// support check: the (j+1)-th sorted value stays in the support
			// iff 1 + (j+1)*z(j+1) exceeds the running cumulative sum
			k := 1 + float32(j+1)*sortedData[it]

			prevCum += sortedData[it]

			if k > prevCum {
				maxIndex = j + 1

				cumSum += sortedData[it]
			}

			it++
		}

		thresholds[i] = (cumSum - 1) / float32(maxIndex)
	}

	output := make([]float32, len(stableInput))
	it = 0

	for i := 0; i < from; i++ {
		for j := 0; j < to; j++ {
			vF := stableInput[it]

			if vF-thresholds[i] > 0 {
				output[it] = vF - thresholds[i]
			}

			it++
		}
	}

	return output
}

func (op *sparsemaxOp) float64sparseMax(inputTensor tensor.Tensor) interface{} {
	inputData := inputTensor.Data().([]float64)
	dims := inputTensor.Dims()
	it := 0

	to := inputTensor.Shape()[dims-1]
	from := tensor.Shape(inputTensor.Shape()[0 : dims-1]).TotalSize()
	if from == 0 {
		from = 1
	}

	maxValues := make([]float64, from)

	for i := 0; i < from; i++ {
		maxValue := -math.MaxFloat64

		for j := 0; j < to; j++ {
			if inputData[it] > maxValue {
				maxValue = inputData[it]
			}

			it++
		}

		maxValues[i] = maxValue
	}

	// subtracting the row maximum is the usual trick for numerical
	// stability; sparsemax is shift-invariant, so the result is unchanged
	stableInput := make([]float64, len(inputData))
	it = 0

	for i := 0; i < from; i++ {
		for j := 0; j < to; j++ {
			stableInput[it] = inputData[it] - maxValues[i]
			it++
		}
	}

	// sort each row in decreasing order; the threshold search below walks
	// the rows independently
	sortedData := make([]float64, len(inputData))
	copy(sortedData, stableInput)

	for i := 0; i < from; i++ {
		row := sortedData[i*to : (i+1)*to]
		sort.Slice(row, func(a, b int) bool {
			return row[a] > row[b]
		})
	}

	thresholds := make([]float64, from)
	it = 0

	for i := 0; i < from; i++ {
		cumSum := 0.0
		prevCum := 0.0
		maxIndex := 0

		for j := 0; j < to; j++ {
			// support check: the (j+1)-th sorted value stays in the support
			// iff 1 + (j+1)*z(j+1) exceeds the running cumulative sum
			k := 1 + float64(j+1)*sortedData[it]

			prevCum += sortedData[it]

			if k > prevCum {
				maxIndex = j + 1

				cumSum += sortedData[it]
			}

			it++
		}

		thresholds[i] = (cumSum - 1) / float64(maxIndex)
	}

	output := make([]float64, len(stableInput))
	it = 0

	for i := 0; i < from; i++ {
		for j := 0; j < to; j++ {
			vF := stableInput[it]

			if vF-thresholds[i] > 0 {
				output[it] = vF - thresholds[i]
			}

			it++
		}
	}

	return output
}

// DoDiff calculates the diff and sets its value to the output node. Implementation for ADOp interface.
func (op *sparsemaxOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) error {
	if len(inputs) != 2 {
		return fmt.Errorf("SparsemaxOp.DoDiff needs 2 arguments")
	}

	odv := output.boundTo.(*dualValue)
	odvd := odv.Value.(tensor.Tensor)
	diffOp := &sparsemaxDiffOp{}

	result, err := diffOp.Do(inputs[0].boundTo, inputs[1].boundTo)
	if err != nil {
		return err
	}

	err = result.(*tensor.Dense).Reshape(odvd.Shape()...)
	if err != nil {
		return err
	}

	sum, err := odvd.(*tensor.Dense).Add(result.(*tensor.Dense), tensor.UseUnsafe())
	if err != nil {
		return err
	}

	odv.d = sum

	return nil
}

// SymDiff applies the diff op. Implementation for SDOp interface.
func (op *sparsemaxOp) SymDiff(inputs Nodes, output, grad *Node) (Nodes, error) {
	err := checkArity(op, len(inputs))
	if err != nil {
		return nil, err
	}

	t := inputs[0]

	diffOp := &sparsemaxDiffOp{}
	nodes := make(Nodes, 1)

	nodes[0], err = ApplyOp(diffOp, t, grad)

	return nodes, err
}

// DiffWRT is an implementation for the SDOp interface
func (op *sparsemaxOp) DiffWRT(inputs int) []bool {
	if inputs != 1 {
		panic(fmt.Sprintf("sparsemax operator only supports one input, got %d instead", inputs))
	}

	return []bool{true}
}

type sparsemaxDiffOp struct{}

func newSparsemaxOpDiff() *sparsemaxDiffOp {
	return &sparsemaxDiffOp{}
}

func (op *sparsemaxDiffOp) Arity() int {
	return 2
}

func (op *sparsemaxDiffOp) ReturnsPtr() bool { return false }

func (op *sparsemaxDiffOp) CallsExtern() bool { return false }

func (op *sparsemaxDiffOp) WriteHash(h hash.Hash) {
	fmt.Fprintf(h, "SparsemaxDiff{}()")
}

func (op *sparsemaxDiffOp) Hashcode() uint32 { return simpleHash(op) }

func (op *sparsemaxDiffOp) String() string {
	return "SparsemaxDiff{}()"
}

func (op *sparsemaxDiffOp) InferShape(inputs ...DimSizer) (tensor.Shape, error) {
	s := inputs[0].(tensor.Shape).Clone()

	return s, nil
}

func (op *sparsemaxDiffOp) Type() hm.Type {
	a := hm.TypeVariable('a')
	return hm.NewFnType(a, a, a)
}

func (op *sparsemaxDiffOp) OverwritesInput() int { return -1 }

func (op *sparsemaxDiffOp) checkInput(inputs ...Value) (*tensor.Dense, *tensor.Dense, error) {
	if err := checkArity(op, len(inputs)); err != nil {
		return nil, nil, err
	}

	var (
		in       *tensor.Dense
		gradient *tensor.Dense
		ok       bool
	)

	switch t := inputs[0].(type) {
	case *dualValue:
		if in, ok = t.Value.(*tensor.Dense); !ok {
			return nil, nil, errors.Errorf("input should be a tensor.Tensor, got %T", inputs[0])
		}
	case *tensor.Dense:
		in = t
	default:
		return nil, nil, errors.Errorf("input type is not supported, got %T", inputs[0])
	}

	switch t := inputs[1].(type) {
	case *dualValue:
		if gradient, ok = t.Value.(*tensor.Dense); !ok {
			return nil, nil, errors.Errorf("gradient should be a tensor, got %T", inputs[1])
		}
	case *tensor.Dense:
		gradient = t
	default:
		return nil, nil, errors.Errorf("gradient type is not supported, got %T", inputs[1])
	}

	return in, gradient, nil
}

func (op *sparsemaxDiffOp) mul(a tensor.Tensor, b tensor.Tensor) (tensor.Tensor, error) {
	if a.Dims() != b.Dims() {
		return tensor.Outer(a, b)
	}

	return tensor.Mul(a, b)
}
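// The backward pass uses the sparsity of the forward output: with support
// S = {i : sparsemax(z)_i > 0}, the Jacobian-vector product for an incoming
// gradient g is
//
//	grad_i = g_i - (sum of g_j over j in S) / |S|   for i in S, and 0 otherwise,
//
// which Do computes below with a mask, a row sum, and a broadcast
// subtraction.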
got %T", inputs[0]) 402 } 403 case *tensor.Dense: 404 in = t 405 default: 406 return nil, nil, errors.Errorf("input type is not supported, got %T", inputs[0]) 407 } 408 409 switch t := inputs[1].(type) { 410 case *dualValue: 411 if gradient, ok = t.Value.(*tensor.Dense); !ok { 412 return nil, nil, errors.Errorf("gradient should be a tensor, got %T", inputs[1]) 413 } 414 case *tensor.Dense: 415 gradient = t 416 default: 417 return nil, nil, errors.Errorf("gradient type is not supported, got %T", inputs[1]) 418 } 419 420 return in, gradient, nil 421 } 422 423 func (op *sparsemaxDiffOp) mul(a tensor.Tensor, b tensor.Tensor) (tensor.Tensor, error) { 424 if a.Dims() != b.Dims() { 425 return tensor.Outer(a, b) 426 } 427 428 return tensor.Mul(a, b) 429 } 430 431 func (op *sparsemaxDiffOp) Do(inputs ...Value) (Value, error) { 432 inputTensor, gradTensor, err := op.checkInput(inputs...) 433 if err != nil { 434 return nil, fmt.Errorf("Can't check SparsemaxDiff input: %w", err) 435 } 436 437 if inputTensor.Size() != gradTensor.Size() { 438 return nil, fmt.Errorf("sparsemaxDiffOp.Do inputs sizes should be equal") 439 } 440 441 var zero interface{} 442 443 if inputTensor.Dtype() == tensor.Float32 { 444 zero = float32(0.0) 445 } else { 446 zero = float64(0.0) 447 } 448 449 nonZeros, err := inputTensor.ElNeScalar(zero, false, tensor.AsSameType()) 450 if err != nil { 451 return nil, fmt.Errorf("sparsemaxDiffOp.Do failed to get non-zeros: %w", err) 452 } 453 454 mul, err := op.mul(nonZeros, gradTensor) 455 if err != nil { 456 return nil, fmt.Errorf("sparsemaxDiffOp.Do failed to mul grad tensor: %w", err) 457 } 458 459 a, err := tensor.Sum(mul, 1) 460 if err != nil { 461 return nil, err 462 } 463 464 b, err := tensor.Sum(nonZeros, 1) 465 if err != nil { 466 return nil, err 467 } 468 469 sum, err := tensor.Div(a, b) 470 if err != nil { 471 return nil, err 472 } 473 474 if sum.Dims() == 1 && gradTensor.Dims() == 2 { 475 err := sum.Reshape(sum.Shape()[0], 1) 476 if err != nil { 477 return nil, err 478 } 479 480 sum, err = tensor.Repeat(sum, 1, gradTensor.Shape()[1]) 481 if err != nil { 482 panic(err) 483 } 484 } 485 486 sub, err := tensor.Sub(gradTensor, sum) 487 if err != nil { 488 return nil, err 489 } 490 491 result, err := op.mul(nonZeros, sub) 492 if err != nil { 493 return nil, err 494 } 495 496 return result, nil 497 } 498 499 // ensure it complies with the Op interface 500 var ( 501 _ Op = &sparsemaxDiffOp{} 502 503 _ Op = &sparsemaxOp{} 504 _ SDOp = &sparsemaxOp{} 505 _ ADOp = &sparsemaxOp{} 506 )