Skip to content

Commit 8fb58a9

Browse files
committed
Removed the allocation as suggested by @dcu
1 parent bd80a32 commit 8fb58a9

File tree

1 file changed

+16
-15
lines changed

1 file changed

+16
-15
lines changed

defaultengine_softmax.go

+16 -15
Original file line number | Diff line number | Diff line change
@@ -264,22 +264,22 @@ func (e StdEng) softMaxBLastDimF64(inputGrad, output, grad Tensor, axis int, log
264264
dx[i] = gradArr[i] - (math.Exp(outputArr[i]) * sum)
265265
}
266266
} else {
267-
mul := make([]float64, dimSize)
268-
267+
//mul := make([]float64, dimSize)
268+
var sum float64
269269
for j := 0; j < dimSize; j++ {
270270
i := ii*dimSize + j
271271

272-
mul[j] = outputArr[i] * gradArr[i]
272+
// mul[j] = outputArr[i] * gradArr[i]
273+
sum += outputArr[i] * gradArr[i]
273274
}
274275

275-
sum := mul[0]
276-
for j := 1; j < dimSize; j++ {
277-
sum += mul[j]
278-
}
276+
//sum := mul[0]
277+
//for j := 1; j < dimSize; j++ {
278+
// sum += mul[j]
279+
//}
279280

280281
for j := 0; j < dimSize; j++ {
281282
i := ii*dimSize + j
282-
283283
dx[i] = (gradArr[i] - sum) * outputArr[i]
284284
}
285285
}
@@ -481,18 +481,19 @@ func (e StdEng) softMaxBLastDimF32(inputGrad, output, grad Tensor, axis int, log
481481
dx[i] = gradArr[i] - (math32.Exp(outputArr[i]) * sum)
482482
}
483483
} else {
484-
mul := make([]float32, dimSize)
485-
484+
// mul := make([]float32, dimSize)
485+
var sum float32
486486
for j := 0; j < dimSize; j++ {
487487
i := ii*dimSize + j
488488

489-
mul[j] = outputArr[i] * gradArr[i]
489+
//mul[j] = outputArr[i] * gradArr[i]
490+
sum += outputArr[i] * gradArr[i]
490491
}
491492

492-
sum := mul[0]
493-
for j := 1; j < dimSize; j++ {
494-
sum += mul[j]
495-
}
493+
// sum := mul[0]
494+
// for j := 1; j < dimSize; j++ {
495+
// sum += mul[j]
496+
//}
496497

497498
for j := 0; j < dimSize; j++ {
498499
i := ii*dimSize + j

0 commit comments

Comments (0)