@@ -135,7 +135,7 @@ func ByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err
135
135
if sbi , ok := a .Engine ().(ByIndiceser ); ok {
136
136
return sbi .SelectByIndices (a , indices , axis , opts ... )
137
137
}
138
- return nil , errors .Errorf ("Unable to select by indices. Egnine %T does not support that." , a .Engine ())
138
+ return nil , errors .Errorf ("Unable to select by indices. Engine %T does not support that." , a .Engine ())
139
139
}
140
140
141
141
// ByIndicesB is the backpropagation of ByIndices.
@@ -146,5 +146,41 @@ func ByIndicesB(a, b, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor,
146
146
if sbi , ok := a .Engine ().(ByIndiceser ); ok {
147
147
return sbi .SelectByIndicesB (a , b , indices , axis , opts ... )
148
148
}
149
- return nil , errors .Errorf ("Unable to select by indices. Egnine %T does not support that." , a .Engine ())
149
+ return nil , errors .Errorf ("Unable to select by indices. Engine %T does not support that." , a .Engine ())
150
+ }
151
+
152
+ // LogSoftMax applies log softmax to the given tensor.
153
+ func LogSoftMax (x Tensor , axis int , opts ... FuncOpt ) (retVal Tensor , err error ) {
154
+ if sm , ok := x .Engine ().(SoftMaxer ); ok {
155
+ return sm .LogSoftMax (x , axis , opts ... )
156
+ }
157
+
158
+ return nil , errors .Errorf ("Unable to apply LogSoftMax. Engine %T does not support that." , x .Engine ())
159
+ }
160
+
161
+ // SoftMax applies softmax to the given tensor.
162
+ func SoftMax (x Tensor , axis int , opts ... FuncOpt ) (retVal Tensor , err error ) {
163
+ if sm , ok := x .Engine ().(SoftMaxer ); ok {
164
+ return sm .SoftMax (x , axis , opts ... )
165
+ }
166
+
167
+ return nil , errors .Errorf ("Unable to apply SoftMax. Engine %T does not support that." , x .Engine ())
168
+ }
169
+
170
+ // SoftMaxB applies softmax backwards operation
171
+ func SoftMaxB (output , grad Tensor , axis int , opts ... FuncOpt ) (retVal Tensor , err error ) {
172
+ if sm , ok := output .Engine ().(SoftMaxer ); ok {
173
+ return sm .SoftMaxB (output , grad , axis , opts ... )
174
+ }
175
+
176
+ return nil , errors .Errorf ("Unable to apply SoftMaxB. Engine %T does not support that." , output .Engine ())
177
+ }
178
+
179
+ // LogSoftMaxB applies softmax backwards operation
180
+ func LogSoftMaxB (output , grad Tensor , axis int , opts ... FuncOpt ) (retVal Tensor , err error ) {
181
+ if sm , ok := output .Engine ().(SoftMaxer ); ok {
182
+ return sm .LogSoftMaxB (output , grad , axis , opts ... )
183
+ }
184
+
185
+ return nil , errors .Errorf ("Unable to apply SoftMaxB. Engine %T does not support that." , output .Engine ())
150
186
}
0 commit comments