smash.factory.Net.set_trainable

Net.set_trainable(trainable)

Make the weights and biases of each of the network’s layers either trainable or frozen.

Parameters:
trainable : ListLike

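List of booleans, one entry per layer, whose length equals the total number of the network’s layers (activation layers included). An entry set to True (or 1) keeps that layer’s weights and biases trainable; an entry set to False (or 0) freezes them.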
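To illustrate how the trainable list lines up with the layers, the sketch below uses a hypothetical helper, `count_trainable`, which is not part of smash. It takes the per-layer parameter counts shown in the network summary and a mask of the same length, and sums the parameters of the layers flagged as trainable. It assumes that truthy entries (True or 1) mark a trainable layer. Raising ValueError on a length mismatch is an illustrative choice, not documented smash behavior.

```python
from typing import Sequence


def count_trainable(layer_params: Sequence[int], trainable: Sequence[bool]) -> int:
    """Hypothetical helper (not smash API): sum the parameters of trainable layers."""
    # One entry per layer is expected, including parameter-free activation layers.
    if len(trainable) != len(layer_params):
        raise ValueError(
            f"expected {len(layer_params)} entries, got {len(trainable)}"
        )
    # Truthy entries (True or 1) mark layers whose weights and biases stay trainable.
    return sum(n for n, flag in zip(layer_params, trainable) if flag)


# Per-layer counts for Dense, ReLU, Dense, ReLU in the example network below.
layer_params = [288, 0, 528, 0]
print(count_trainable(layer_params, [1, 0, 0, 0]))  # 288
print(count_trainable(layer_params, [True] * 4))  # 816
```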

Examples

>>> from smash.factory import Net
>>> net = Net()
>>> net.add_dense(32, input_shape=8, activation="relu")
>>> net.add_dense(16, activation="relu")
>>> net
+-------------------------------------------------------+
| Layer Type         Input/Output Shape  Num Parameters |
+-------------------------------------------------------+
| Dense              (8,)/(32,)          288            |
| Activation (ReLU)  (32,)/(32,)         0              |
| Dense              (32,)/(16,)         528            |
| Activation (ReLU)  (16,)/(16,)         0              |
+-------------------------------------------------------+
Total parameters: 816
Trainable parameters: 816
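The counts in the table follow the usual dense-layer arithmetic, assuming each dense layer holds an n_in by n_out weight matrix plus one bias per output unit, while activation layers carry no parameters. The function name `dense_params` is hypothetical and used only for this check.

```python
def dense_params(n_in: int, n_out: int) -> int:
    """Weights (n_in * n_out) plus one bias per output unit."""
    return n_in * n_out + n_out


first = dense_params(8, 32)    # 8 * 32 + 32
second = dense_params(32, 16)  # 32 * 16 + 16
print(first, second, first + second)  # 288 528 816
```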

Freeze the parameters in the second dense layer:

>>> net.set_trainable([1, 0, 0, 0])
>>> net
+-------------------------------------------------------+
| Layer Type         Input/Output Shape  Num Parameters |
+-------------------------------------------------------+
| Dense              (8,)/(32,)          288            |
| Activation (ReLU)  (32,)/(32,)         0              |
| Dense              (32,)/(16,)         528            |
| Activation (ReLU)  (16,)/(16,)         0              |
+-------------------------------------------------------+
Total parameters: 816
Trainable parameters: 288