
Commit 4a08402

Committed Oct 18, 2021
Merge remote-tracking branch 'origin/master' into optimizations
2 parents 2fcf2ee + bfdc96b commit 4a08402

6 files changed: +942 -9 lines changed
 

api_matop.go (+38 -2)
@@ -135,7 +135,7 @@ func ByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err
 	if sbi, ok := a.Engine().(ByIndiceser); ok {
 		return sbi.SelectByIndices(a, indices, axis, opts...)
 	}
-	return nil, errors.Errorf("Unable to select by indices. Egnine %T does not support that.", a.Engine())
+	return nil, errors.Errorf("Unable to select by indices. Engine %T does not support that.", a.Engine())
 }
 
 // ByIndicesB is the backpropagation of ByIndices.
@@ -146,5 +146,41 @@ func ByIndicesB(a, b, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor,
 	if sbi, ok := a.Engine().(ByIndiceser); ok {
 		return sbi.SelectByIndicesB(a, b, indices, axis, opts...)
 	}
-	return nil, errors.Errorf("Unable to select by indices. Egnine %T does not support that.", a.Engine())
+	return nil, errors.Errorf("Unable to select by indices. Engine %T does not support that.", a.Engine())
+}
+
+// LogSoftMax applies log softmax to the given tensor.
+func LogSoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
+	if sm, ok := x.Engine().(SoftMaxer); ok {
+		return sm.LogSoftMax(x, axis, opts...)
+	}
+
+	return nil, errors.Errorf("Unable to apply LogSoftMax. Engine %T does not support that.", x.Engine())
+}
+
+// SoftMax applies softmax to the given tensor.
+func SoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
+	if sm, ok := x.Engine().(SoftMaxer); ok {
+		return sm.SoftMax(x, axis, opts...)
+	}
+
+	return nil, errors.Errorf("Unable to apply SoftMax. Engine %T does not support that.", x.Engine())
+}
+
+// SoftMaxB applies the softmax backwards operation.
+func SoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
+	if sm, ok := output.Engine().(SoftMaxer); ok {
+		return sm.SoftMaxB(output, grad, axis, opts...)
+	}
+
+	return nil, errors.Errorf("Unable to apply SoftMaxB. Engine %T does not support that.", output.Engine())
+}
+
+// LogSoftMaxB applies the log softmax backwards operation.
+func LogSoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
+	if sm, ok := output.Engine().(SoftMaxer); ok {
+		return sm.LogSoftMaxB(output, grad, axis, opts...)
+	}
+
+	return nil, errors.Errorf("Unable to apply LogSoftMaxB. Engine %T does not support that.", output.Engine())
 }
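
All four new package-level functions follow the same pattern as ByIndices above: they type-assert the tensor's engine to a SoftMaxer interface and delegate the work to it, returning an error when the engine does not implement that interface. The interface itself is not part of this diff; inferred from the call sites above, it presumably looks roughly like the following sketch (the repository's actual definition may differ):

// SoftMaxer is a sketch inferred from the call sites in this diff;
// the real interface is defined elsewhere in the package and may differ.
type SoftMaxer interface {
	SoftMax(x Tensor, axis int, opts ...FuncOpt) (Tensor, error)
	LogSoftMax(x Tensor, axis int, opts ...FuncOpt) (Tensor, error)
	SoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (Tensor, error)
	LogSoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (Tensor, error)
}

A minimal usage sketch of the new SoftMax wrapper, assuming this file belongs to the gorgonia.org/tensor package and that the default engine implements SoftMaxer (otherwise the error branch above is taken):

package main

import (
	"fmt"

	"gorgonia.org/tensor" // assumed import path for the package in this diff
)

func main() {
	// A 2x3 tensor of logits.
	t := tensor.New(
		tensor.WithShape(2, 3),
		tensor.WithBacking([]float64{1, 2, 3, 1, 2, 3}),
	)

	// Softmax along axis 1 via the new package-level wrapper; it returns
	// an error if the tensor's engine does not implement SoftMaxer.
	out, err := tensor.SoftMax(t, 1)
	if err != nil {
		fmt.Println("engine does not support softmax:", err)
		return
	}
	fmt.Println(out)
}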
