smash.factory.Net.set_trainable
Net.set_trainable(trainable)
Make the weights and biases of the network’s layers trainable, or freeze them.
- Parameters:
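- trainable : ListLike
List of booleans with one entry per layer. Its length must equal the total number of the network’s layers, activation layers included. A True (or 1) entry makes the corresponding layer trainable; a False (or 0) entry freezes its weights and biases.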
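Because the list needs one entry for every layer, including the parameter-free activation layers, it can be less error-prone to build it from the layer order than to type it by hand. The following is a minimal plain-Python sketch: the `layers` list and the `freeze_layers` helper are hypothetical illustrations, not part of smash, and assume the layer order shown in the summary table of the example network.

```python
# Hypothetical layer listing in the order of the example network's summary
# table: each Dense layer is followed by its own activation layer.
layers = ["Dense", "Activation (ReLU)", "Dense", "Activation (ReLU)"]


def freeze_layers(layers, frozen):
    """Return one boolean per layer: False for indices in ``frozen``
    (frozen layers), True for all others (trainable layers)."""
    for i in frozen:
        if not 0 <= i < len(layers):
            raise IndexError(f"layer index {i} out of range for {len(layers)} layers")
    return [i not in frozen for i in range(len(layers))]


# Freeze the second Dense layer (index 2) together with its activation (index 3).
mask = freeze_layers(layers, frozen={2, 3})
print(mask)  # [True, True, False, False]
```

The resulting list would then be passed as `net.set_trainable(mask)`. Since activation layers carry no parameters, this mask should leave the same number of trainable parameters as `[1, 0, 0, 0]`.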
Examples
>>> from smash.factory import Net
>>> net = Net()
>>> net.add_dense(32, input_shape=8, activation="relu")
>>> net.add_dense(16, activation="relu")
>>> net
+-------------------------------------------------------+
| Layer Type         Input/Output Shape  Num Parameters |
+-------------------------------------------------------+
| Dense              (8,)/(32,)          288            |
| Activation (ReLU)  (32,)/(32,)         0              |
| Dense              (32,)/(16,)         528            |
| Activation (ReLU)  (16,)/(16,)         0              |
+-------------------------------------------------------+
Total parameters: 816
Trainable parameters: 816
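Freeze the parameters in the second dense layer, keeping only the first dense layer trainable (activation layers have no parameters, so their entries do not change the count):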
>>> net.set_trainable([1, 0, 0, 0])
>>> net
+-------------------------------------------------------+
| Layer Type         Input/Output Shape  Num Parameters |
+-------------------------------------------------------+
| Dense              (8,)/(32,)          288            |
| Activation (ReLU)  (32,)/(32,)         0              |
| Dense              (32,)/(16,)         528            |
| Activation (ReLU)  (16,)/(16,)         0              |
+-------------------------------------------------------+
Total parameters: 816
Trainable parameters: 288
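The parameter counts in the summary can be checked by hand. This is a minimal plain-Python sketch, assuming each dense layer stores an input-by-output weight matrix plus one bias per output unit. That assumption is consistent with the 288 and 528 shown in the table but is not taken from smash's source, and `dense_param_count` is a hypothetical helper.

```python
def dense_param_count(n_in, n_out):
    # Assumed layout: an (n_in x n_out) weight matrix plus one bias per output unit.
    return n_in * n_out + n_out


# Parameter count per layer of the example network, in summary-table order;
# activation layers carry no parameters.
layer_params = [
    dense_param_count(8, 32),   # Dense (8,)/(32,)
    0,                          # Activation (ReLU)
    dense_param_count(32, 16),  # Dense (32,)/(16,)
    0,                          # Activation (ReLU)
]

trainable = [1, 0, 0, 0]  # same flags as passed to set_trainable in the example

total = sum(layer_params)
n_trainable = sum(p for p, flag in zip(layer_params, trainable) if flag)
print(total, n_trainable)  # 816 288
```

Only the first dense layer keeps a truthy flag, so 288 of the 816 parameters remain trainable, matching the summary printed after the call.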