Source code for skmultilearn.ext.keras

from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.base import BaseEstimator
from copy import copy

[docs]class Keras(BaseEstimator): def __init__(self, build_function, multi_class=False, keras_params = None): if not callable(build_function): raise ValueError('Model construction function must be callable.') self.multi_class = multi_class self.build_function = build_function if keras_params is None: keras_params = {} self.keras_params = keras_params def fit(self, X, y): if self.multi_class: self.n_classes_ = len(set(y)) else: self.n_classes_ = 1 build_callable = lambda: self.build_function(X.shape[1], self.n_classes_) keras_params=copy(self.keras_params) keras_params['build_fn']=build_callable self.classifier_ = KerasClassifier(**keras_params) self.classifier_.fit(X, y) def predict(self, X): return self.classifier_.predict(X)