smash.Model.ann_optimize

Model.ann_optimize(net=None, optimizer='adam', learning_rate=0.003, control_vector=None, bounds=None, jobs_fun='nse', wjobs_fun=None, event_seg=None, gauge='downstream', wgauge='mean', ost=None, epochs=400, early_stopping=False, random_state=None, verbose=True, inplace=False, return_net=False)

Optimize the Model using an Artificial Neural Network.

Hint

See the User Guide and Math / Num Documentation for more.

Parameters:
net : Net or None, default None

The Net object to be trained to learn the descriptors-to-parameters mapping.

Note

If not given, a default network is used. Otherwise, the operation is performed in-place on this net.

optimizer : str, default ‘adam’

Name of the optimizer. Only used if net is not set. Should be one of:

  • ‘sgd’

  • ‘adam’

  • ‘adagrad’

  • ‘rmsprop’

learning_rate : float, default 0.003

The learning rate used to update the weights during training. Only used if net is not set.

control_vector : str, sequence or None, default None

Parameters and/or states to be optimized. The control vector argument can be any parameter or state name or any sequence of parameter and/or state names.

Note

If not given, the control vector will be composed of the parameters of the structure defined in the Model setup.

bounds : dict or None, default None

Bounds on the control vector. The bounds argument is a dictionary whose keys are the names of parameters and/or states in the control vector (may be a subset of the control vector sequence) and whose values are pairs of (min, max) values (i.e. list or tuple) with min lower than max. A None value inside the dictionary will be filled in with default bound values.

Note

If not given, the bounds will be filled in with default bound values.
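As an illustration of the expected structure, a minimal sketch of a bounds dictionary (the parameter names "cp" and "cft" are illustrative; valid names depend on the structure defined in the Model setup):

```python
# Hypothetical bounds dictionary for a subset of the control vector.
# Keys are parameter/state names; values are (min, max) pairs or None.
bounds = {
    "cp": (1e-6, 1000.0),  # explicit (min, max) pair, min lower than max
    "cft": None,           # None: filled in with default bound values
}

# Sanity check mirroring the documented constraints.
for name, pair in bounds.items():
    if pair is not None:
        low, high = pair
        assert low < high, f"invalid bounds for {name}"
```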

jobs_fun, wjobs_fun, event_seg, gauge, wgauge, ost : multiple types

Optimization settings used to run the forward hydrological model and compute the cost values. See smash.Model.optimize for more.

epochs : int, default 400

The number of epochs to train the network.

early_stopping : bool, default False

Stop updating weights and biases when the loss function stops decreasing.
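The exact stopping criterion is internal to smash; as a rough illustration of the idea only (not smash's actual implementation), a minimal sketch of stopping once the training loss no longer decreases:

```python
# Hypothetical sketch of an early-stopping check: training stops at the
# first epoch whose loss fails to improve on the previous epoch's loss.
def early_stopping_epochs(loss_history):
    """Return how many epochs are kept if training stops as soon as
    the loss stops decreasing."""
    for epoch in range(1, len(loss_history)):
        if loss_history[epoch] >= loss_history[epoch - 1]:
            return epoch  # weights from before this epoch are retained
    return len(loss_history)

losses = [1.21, 0.80, 0.45, 0.30, 0.31, 0.29]
print(early_stopping_epochs(losses))  # 4: loss stopped decreasing at epoch 4
```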

random_state : int or None, default None

Random seed used to initialize weights. Only used if net is not set.

Note

If not given and net is not set, the weights will be initialized with a random seed.

verbose : bool, default True

Display information while training.

inplace : bool, default False

If True, perform operation in-place.

return_net : bool, default False

If True and the default graph is used (net is not set), also return the trained neural network.

Returns:
Model : Model or None

Model with optimized outputs if inplace is False, otherwise None.

Net : Net or None

Net with trained weights and biases if return_net is True and the default graph is used, otherwise None.

See also

Net

Artificial Neural Network initialization.

Examples

>>> import smash
>>> setup, mesh = smash.load_dataset("cance")
>>> model = smash.Model(setup, mesh)
>>> net = model.ann_optimize(epochs=200, inplace=True, return_net=True, random_state=11)
>>> model
Structure: 'gr-a'
Spatio-Temporal dimension: (x: 28, y: 28, time: 1440)
Last update: ANN Optimization

Display a summary of the neural network

>>> net
+----------------------------------------------------------+
| Layer Type            Input/Output Shape  Num Parameters |
+----------------------------------------------------------+
| Dense                 (2,)/(18,)          54             |
| Activation (ReLU)     (18,)/(18,)         0              |
| Dense                 (18,)/(9,)          171            |
| Activation (ReLU)     (9,)/(9,)           0              |
| Dense                 (9,)/(4,)           40             |
| Activation (Sigmoid)  (4,)/(4,)           0              |
| Scale (MinMaxScale)   (4,)/(4,)           0              |
+----------------------------------------------------------+
Total parameters: 265
Trainable parameters: 265
Optimizer: (adam, lr=0.003)

Access to some training information

>>> net.history['loss_train']  # training loss
[1.2064831256866455, ..., 0.03552241995930672]
>>> net.layers[0].weight  # trained weights of the first layer
array([[-0.35024701, -0.5263885 ,  0.06432176,  0.31493864, -0.08741257,
        -0.01596381, -0.53372188, -0.01383371,  0.54957057,  0.51538232,
         0.23674032, -0.42860816,  0.53083172,  0.42429858, -0.24634816,
         0.07233667, -0.58225892, -0.34835798],
       [-0.20115953, -0.37473829,  0.43865405,  0.48463052, -0.17020534,
        -0.19849597, -0.42540381, -0.4557565 ,  0.3663841 ,  0.27515033,
        -0.50145176, -0.02213097,  0.02078811,  0.48562112,  0.40088665,
         0.12205882, -0.00624188,  0.62118917]])