Data Experiment #05 Feed-forward Neural Network

[EDIT on 4 Apr] I added the validation_split option, .hist and .history_table().

In the last entry we compared a feed-forward neural network with other regression algorithms, using Keras to create the neural network regression model.

The simplicity of the code is important, especially for a neural network

There are many neural network libraries available. I chose Keras for its simplicity: for example, the code I wrote to train a neural network model contains only 10 lines (including two "import" statements).

One difficulty with neural networks is the large number of tuning options. Even for a feed-forward neural network, we need to choose suitable values for the number of hidden layers, the number of hidden units, the activation functions and the training algorithm (stochastic gradient descent, etc.).
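To see how quickly these choices multiply, here is a small sketch that enumerates a hypothetical grid of tuning options with itertools.product. The values are only illustrative; they are not the settings used in this experiment.

```python
from itertools import product

# Hypothetical tuning options for a small feed-forward network.
hidden_layouts = [[5], [10], [5, 3], [10, 5]]  # hidden units per hidden layer
activations = ["tanh", "relu", "sigmoid"]      # one activation for all hidden layers
optimizers = ["sgd", "rmsprop", "adam"]        # training algorithm
epochs = [1000, 10000, 50000]                  # number of training epochs

# Every combination is one model that would have to be trained and evaluated.
grid = list(product(hidden_layouts, activations, optimizers, epochs))
print(len(grid))  # 4 * 3 * 3 * 3 = 108
```

Even this modest grid means training 108 networks, and cross-validation multiplies that number again.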

Although Keras provides a simple API for neural networks, I wanted to make the code even simpler. In particular, I wanted to use Keras through scikit-learn, so I wrote a Python module for it.

With it, the model I used to recover the sine curve can be trained in the following three lines:

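from skeras import RegressionNN
# two hidden layers (5 and 3 units); 'tanh' for both hidden layers and the output layer
rfnn = RegressionNN(hidden=[5, 3], activation=['tanh'] * 3, nb_epoch=50000, verbose=0)
rfnn.fit(X, y)  # X: predictors (one row per sample), y: target values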
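This post does not show how X and y were built for the sine-curve model. A minimal sketch of suitable data, assuming 100 evenly spaced points on [0, 2π] with a little Gaussian noise (my assumption, not necessarily the original setup), could look like this: X has one column (one predictor) and y holds one target value per row.

```python
import math
import random

random.seed(0)  # reproducible noise

n = 100
# One predictor: evenly spaced points on [0, 2*pi], one row per sample.
X = [[2 * math.pi * i / (n - 1)] for i in range(n)]
# Target: sin(x) plus small Gaussian noise (standard deviation 0.05).
y = [math.sin(row[0]) + random.gauss(0, 0.05) for row in X]
```

Keras works with NumPy arrays, so converting with numpy.array(X) and numpy.array(y) before calling fit is the safe choice.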

How to use the module

You can find the module skeras.py in my Bitbucket repository. Put the module in the same directory as your code (or anywhere on your Python path), and import it with the statement

from skeras import ClassificationNN

for a classification model and

from skeras import RegressionNN

for a regression model. After that, the basic usage is the same as for any scikit-learn estimator. Please try the test script skeras-test.py, which is in the same Bitbucket repository.
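Behaving like a scikit-learn estimator mainly means following a few conventions: the constructor only stores its keyword arguments, fit() learns from the data and returns self, learned attributes end in an underscore, and get_params()/set_params() expose the constructor arguments so that tools such as GridSearchCV can clone the estimator with new settings. The toy class below (a hypothetical MeanRegressor that just predicts the training mean plus a constant; it is not part of skeras) shows these conventions in plain Python.

```python
class MeanRegressor:
    """Toy estimator following scikit-learn conventions (hypothetical example)."""

    def __init__(self, shift=0.0):
        # The constructor only stores hyperparameters; no work is done here.
        self.shift = shift

    def get_params(self, deep=True):
        # Return the constructor arguments so the estimator can be cloned.
        return {"shift": self.shift}

    def set_params(self, **params):
        # Update hyperparameters (used by grid search) and return self.
        for name, value in params.items():
            setattr(self, name, value)
        return self

    def fit(self, X, y):
        # Learned state gets a trailing underscore, by convention.
        self.mean_ = sum(y) / len(y)
        return self

    def predict(self, X):
        # One prediction per row of X.
        return [self.mean_ + self.shift for _ in X]


model = MeanRegressor(shift=0.5).fit([[1], [2], [3]], [1.0, 2.0, 3.0])
print(model.predict([[4], [5]]))  # [2.5, 2.5]
```

In real code one would normally inherit from sklearn.base.BaseEstimator, which provides get_params() and set_params() automatically.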

Options

  • hidden : list of the numbers of units in the hidden layers. For example, [5,3] creates two hidden layers with 5 and 3 units, respectively. The default value is [p], where p is the number of predictors.
  • activation : list of activation functions of the hidden (and output) layers, for example ['relu','relu']. For ClassificationNN, give the activation functions of the hidden layers only; the output layer uses 'sigmoid' or 'softmax'. For RegressionNN, give the activation functions of both the hidden layers and the output layer.

  • init : list of initialisation methods for the hidden and output layers, for example ['uniform','normal','uniform']. By default, all layers use 'uniform'.

  • optimizer : a list [name, optimizer], where name is an arbitrary string and optimizer is an optimizer instance, for example ['sgd',SGD()]. The default value is ['rmsprop']*2, i.e. ['rmsprop','rmsprop'] (Keras also accepts an optimizer given by its name). The name is there to distinguish optimizers when we use sklearn.grid_search.GridSearchCV.

  • nb_epoch, batch_size, verbose, show_accuracy, validation_split : these options have the same meaning as those of the fit() method of a Keras Sequential model.
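To make the rules for hidden, activation and init concrete, here is a sketch of how such lists could be expanded into one (units, activation, init) triple per layer. The helper layer_specs is hypothetical and only mirrors the rules stated above; it is not the code in skeras.py. In particular, I assume the classifier uses 'sigmoid' for a single output unit and 'softmax' otherwise, which is the usual choice but is not spelled out above.

```python
def layer_specs(p, n_outputs, hidden=None, activation=None, init=None,
                classification=False):
    """Expand skeras-style option lists into (units, activation, init) per layer.

    Hypothetical helper: p is the number of predictors, n_outputs the number
    of output units (1 for regression or binary classification).
    """
    if hidden is None:
        hidden = [p]  # default: one hidden layer with p units
    n_layers = len(hidden) + 1  # hidden layers plus the output layer

    if classification:
        # Only hidden activations are given; the output one is chosen here
        # (assumption: sigmoid for one output unit, softmax otherwise).
        if activation is None or len(activation) != len(hidden):
            raise ValueError("give one activation per hidden layer")
        output = "sigmoid" if n_outputs == 1 else "softmax"
        activation = list(activation) + [output]
    elif activation is None or len(activation) != n_layers:
        raise ValueError("give one activation per hidden and output layer")

    if init is None:
        init = ["uniform"] * n_layers  # default: all 'uniform'
    elif len(init) != n_layers:
        raise ValueError("give one init per hidden and output layer")

    units = list(hidden) + [n_outputs]
    return list(zip(units, activation, init))


# The sine-curve regression model: 1 predictor, 1 output.
print(layer_specs(1, 1, hidden=[5, 3], activation=["tanh"] * 3))
# [(5, 'tanh', 'uniform'), (3, 'tanh', 'uniform'), (1, 'tanh', 'uniform')]
```

This also explains why the sine-curve model at the top of the post passes ['tanh']*3: two hidden layers plus the output layer.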

After training

After fitting a model with

fnn.fit(X,y)

the following methods and attributes are available:
  • fnn.get_weights() : the same as the get_weights() method of a Keras model instance.
  • fnn.evaluate() : the same as the evaluate() method of a Keras model instance.
  • fnn.predict() : makes predictions.
  • fnn.predict_proba() (only for ClassificationNN) : calculates predicted probabilities.
  • fnn.classes_ (only for ClassificationNN) : gives the list of classes.
  • fnn.hist : a History callback (the object returned by the fit() method of a Keras model).
  • fnn.history_table : a data frame built from the History callback.
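A Keras History callback keeps its records in the .history attribute, a dict that maps metric names such as 'loss' (and 'val_loss' when validation_split is used) to lists with one value per epoch. The sketch below shows one way to turn such a dict into a table with one row per epoch, in plain Python; the function history_rows is hypothetical and not necessarily how skeras builds history_table. With pandas, pandas.DataFrame(hist.history) gives a similar table directly.

```python
def history_rows(history):
    """Turn a Keras-style history dict into one row (dict) per epoch.

    history maps metric names (e.g. 'loss', 'val_loss') to per-epoch lists,
    like the .history attribute of a Keras History callback.
    """
    names = sorted(history)
    n_epochs = len(history[names[0]]) if names else 0
    rows = []
    for epoch in range(n_epochs):
        row = {"epoch": epoch}
        for name in names:
            row[name] = history[name][epoch]
        rows.append(row)
    return rows


# Illustrative values only, not results from this experiment.
example_history = {"loss": [0.50, 0.30, 0.20], "val_loss": [0.55, 0.35, 0.28]}
for row in history_rows(example_history):
    print(row)
```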

Tips

  • The module requires scikit-learn and Keras.
  • For anything related to Keras objects, consult the well-written Keras documentation.
  • You do not have to convert your target variable into the format Keras requires. For example, when you use Keras directly, multi-class classification requires the target as a 2-d (one-hot encoded) numpy array; the module converts the target variable automatically.
  • The module provides a scikit-learn wrapper only for a very simple feed-forward neural network. If you want a more complicated network (e.g. one using Dropout to reduce overfitting), you need to modify the module.
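The target conversion mentioned in the tips amounts to one-hot encoding: each label becomes a row with a single 1 in the column of its class. Here is a minimal plain-Python sketch of that idea; the function to_one_hot is hypothetical and not taken from skeras.py (Keras itself offers a to_categorical utility for integer labels).

```python
def to_one_hot(labels):
    """Convert class labels to (classes_, one-hot rows). Hypothetical helper."""
    classes_ = sorted(set(labels))  # sorted distinct labels, like classes_
    index = {c: i for i, c in enumerate(classes_)}
    rows = []
    for label in labels:
        row = [0] * len(classes_)
        row[index[label]] = 1  # a single 1 in the column of this label's class
        rows.append(row)
    return classes_, rows


classes_, Y = to_one_hot(["cat", "dog", "bird", "dog"])
print(classes_)  # ['bird', 'cat', 'dog']
print(Y)         # [[0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 0, 1]]
```

Going back is a reverse lookup: the predicted class for a row of probabilities is the entry of classes_ at the position of the largest probability.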
Categories: #data-mining