
Python: lack of out-of-the-box support for tensorflow models #326


Closed
pfraczek opened this issue Sep 12, 2020 · 1 comment
Labels
Python 🐍 Related to Python

Comments

@pfraczek

It turns out that a tensorflow model requires a custom predict_function in order to create an Explainer object for it. There are two reasons for this:

It would be great to have dalex handle this natively, as it does for xgboost models. Until then, I believe a short instruction is needed.

# tensorflow==2.0.0
import dalex as dx
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.preprocessing import StandardScaler


# load the FIFA data and standardize the numeric features
data = dx.datasets.load_fifa()
X = data.drop(columns=['nationality', 'value_eur'])
y = data['value_eur']

X[X.columns] = StandardScaler().fit_transform(X[X.columns])

# a small regression network with a single output unit;
# it is left untrained because only the shapes of the predictions matter here
model = keras.models.Sequential([
    layers.Dense(128, activation='relu', input_shape=[X.values.shape[1]]),
    layers.Dense(64, activation='relu'),
    layers.Dense(1)
])

optimizer = tf.keras.optimizers.RMSprop(0.001)

model.compile(loss='mse',
              optimizer=optimizer,
              metrics=['mse'])


# pass a NumPy array instead of a DataFrame; the output keeps shape (n, 1)
def predict_function_without_reshaping(model, data):
    X = data.values
    return model.predict(X)

# same as above, but flatten the output to shape (n,)
def predict_function_with_reshaping(model, data):
    X = data.values
    return model.predict(X).reshape((data.shape[0],))


# Explainer with the default predict_function
exp_default = dx.Explainer(model, X, y, verbose=True, model_type='regression')


# custom predict_function without reshaping: compare the shapes of y_hat and residuals
exp_no_reshape = dx.Explainer(model, X, y, verbose=True, model_type='regression', predict_function=predict_function_without_reshaping)
print(exp_no_reshape.y_hat.shape)
print(exp_no_reshape.residuals.shape)


# custom predict_function with reshaping
exp_reshape = dx.Explainer(model, X, y, verbose=True, model_type='regression', predict_function=predict_function_with_reshaping)
print(exp_reshape.y_hat.shape)
print(exp_reshape.residuals.shape)
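Why the reshape matters can be seen without TensorFlow at all. The sketch below assumes that residuals are computed as a plain `y - y_hat` difference on NumPy arrays; under that assumption, a `(n,)` target minus a `(n, 1)` prediction broadcasts to an `(n, n)` matrix instead of `n` residuals. `StubModel` is a hypothetical stand-in for the Keras model: like `model.predict` on a single-output network, it returns an array of shape `(n, 1)`.

```python
import numpy as np
import pandas as pd


class StubModel:
    """Hypothetical stand-in for a single-output Keras regression model:
    predict() returns an array of shape (n, 1)."""

    def predict(self, X):
        X = np.asarray(X, dtype=float)
        return X.sum(axis=1, keepdims=True)


def predict_function_with_reshaping(model, data):
    # pass a NumPy array and flatten the (n, 1) output to (n,)
    return model.predict(data.values).reshape(-1)


data = pd.DataFrame({"a": [1.0, 2.0, 3.0], "b": [0.5, 0.5, 0.5]})
y = np.array([1.0, 3.0, 3.0])
model = StubModel()

# without reshaping: (3,) - (3, 1) broadcasts to a (3, 3) matrix
y_hat_2d = model.predict(data.values)
residuals_2d = y - y_hat_2d
print(y_hat_2d.shape, residuals_2d.shape)  # (3, 1) (3, 3)

# with reshaping: one residual per observation
y_hat = predict_function_with_reshaping(model, data)
residuals = y - y_hat
print(y_hat.shape, residuals.shape)  # (3,) (3,)
```

Using `.reshape(-1)` lets NumPy infer the length, which for a single-output model is equivalent to `.reshape((data.shape[0],))` in the issue's code.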
@hbaniecki
Member

Indeed, this is connected with the changes to predict for tensorflow in #321, and it should be addressed in the next version.

@hbaniecki hbaniecki added the Python 🐍 Related to Python label Sep 12, 2020