smash.Model.ann_optimize
Model.ann_optimize(net=None, optimizer='adam', learning_rate=0.003, control_vector=None, bounds=None, jobs_fun='nse', wjobs_fun=None, event_seg=None, gauge='downstream', wgauge='mean', ost=None, epochs=400, early_stopping=False, random_state=None, verbose=True, inplace=False, return_net=False)
Optimize the Model using an Artificial Neural Network (ANN) that learns a mapping from descriptors to model parameters.
Hint
See the User Guide and Math / Num Documentation for more details.
- Parameters:
- net : Net or None, default None
The Net object to be trained to learn the descriptors-to-parameters mapping.
Note
If not given, a default network will be used. Otherwise, training is performed in-place on this net.
- optimizer : str, default ‘adam’
Name of the optimizer. Only used if net is not set. Should be one of:
‘sgd’
‘adam’
‘adagrad’
‘rmsprop’
- learning_rate : float, default 0.003
The learning rate used to update the weights during training. Only used if net is not set.
- control_vector : str, sequence or None, default None
Parameters and/or states to be optimized. The control vector argument can be any parameter or state name or any sequence of parameter and/or state names.
Note
If not given, the control vector will be composed of the parameters of the structure defined in the Model setup.
- bounds : dict or None, default None
Bounds on the control vector. The bounds argument is a dictionary whose keys are the names of the parameters and/or states in the control vector (this can be a subset of the control vector sequence) and whose values are pairs of (min, max) values (i.e. list or tuple) with min lower than max. None values inside the dictionary will be filled in with default bound values.
Note
If not given, the bounds will be filled in with default bound values.
- jobs_fun, wjobs_fun, event_seg, gauge, wgauge, ost : multiple types
Optimization settings used to run the forward hydrological model and compute the cost values. See smash.Model.optimize for more.
- epochs : int, default 400
The number of epochs to train the network.
- early_stopping : bool, default False
Stop updating weights and biases when the loss function stops decreasing.
- random_state : int or None, default None
Random seed used to initialize weights. Only used if net is not set.
Note
If not given and net is not set, the weights will be initialized with a random seed.
- verbose : bool, default True
Display information while training.
- inplace : bool, default False
If True, perform operation in-place.
- return_net : bool, default False
If True and the default graph is used (net is not set), also return the trained neural network.
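The bounds rules above (keys may cover only a subset of the control vector, None values fall back to defaults, min must be lower than max) can be sketched as a small standalone helper. This is a minimal sketch only. `fill_bounds` and `DEFAULT_BOUNDS` are hypothetical names. The parameter names are assumed to be those of the 'gr-a' structure. The numeric defaults are placeholders and not the values smash actually uses.

```python
from typing import Dict, Iterable, Mapping, Optional, Sequence, Tuple

Bound = Tuple[float, float]

# HYPOTHETICAL placeholder defaults, for illustration only: these are NOT the
# default bounds used by smash. Parameter names are assumed ('gr-a'-like).
DEFAULT_BOUNDS: Dict[str, Bound] = {
    "cp": (1e-6, 1000.0),
    "cft": (1e-6, 1000.0),
    "exc": (-50.0, 50.0),
    "lr": (1e-6, 1000.0),
}


def fill_bounds(
    control_vector: Iterable[str],
    bounds: Optional[Mapping[str, Optional[Sequence[float]]]] = None,
    defaults: Mapping[str, Bound] = DEFAULT_BOUNDS,
) -> Dict[str, Bound]:
    """Return one (min, max) pair per control-vector name (hypothetical helper)."""
    # A single name is accepted as well as a sequence of names.
    names = [control_vector] if isinstance(control_vector, str) else list(control_vector)
    given = dict(bounds or {})
    unknown = set(given) - set(names)
    if unknown:
        raise ValueError(f"bounds given for names outside the control vector: {sorted(unknown)}")
    filled: Dict[str, Bound] = {}
    for name in names:
        pair = given.get(name)
        if pair is None:  # key absent or explicitly None -> fall back to the default
            pair = defaults[name]
        low, high = pair
        if not low < high:
            raise ValueError(f"invalid bounds for {name!r}: min must be lower than max")
        filled[name] = (float(low), float(high))
    return filled


print(fill_bounds(["cp", "exc"], {"cp": (1, 500), "exc": None}))
# {'cp': (1.0, 500.0), 'exc': (-50.0, 50.0)}
```

Here `exc` is given explicitly as None, so it takes the placeholder default. Any control-vector name missing from the dictionary is treated the same way.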
- Returns:
- Model : Model or None
Model with optimization outputs if inplace is False.
- Net : Net or None
Net with trained weights and biases if return_net is True and the default graph is used.
See also
Net : Artificial Neural Network initialization.
Examples
>>> import smash
>>> setup, mesh = smash.load_dataset("cance")
>>> model = smash.Model(setup, mesh)
>>> net = model.ann_optimize(epochs=200, inplace=True, return_net=True, random_state=11)
>>> model
Structure: 'gr-a'
Spatio-Temporal dimension: (x: 28, y: 28, time: 1440)
Last update: ANN Optimization
Display a summary of the neural network:
>>> net
+----------------------------------------------------------+
| Layer Type            Input/Output Shape  Num Parameters |
+----------------------------------------------------------+
| Dense                 (2,)/(18,)          54             |
| Activation (ReLU)     (18,)/(18,)         0              |
| Dense                 (18,)/(9,)          171            |
| Activation (ReLU)     (9,)/(9,)           0              |
| Dense                 (9,)/(4,)           40             |
| Activation (Sigmoid)  (4,)/(4,)           0              |
| Scale (MinMaxScale)   (4,)/(4,)           0              |
+----------------------------------------------------------+
Total parameters: 265
Trainable parameters: 265
Optimizer: (adam, lr=0.003)
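The summary table can be checked by hand. Each Dense layer holds n_in * n_out weights plus n_out biases: 2 * 18 + 18 = 54, 18 * 9 + 9 = 171, and 9 * 4 + 4 = 40, for a total of 265. The NumPy sketch below reproduces those counts and runs a forward pass with the same layer stack. It is a minimal sketch under stated assumptions, not smash's implementation:
- `init_layers` is a hypothetical uniform initializer.
- Weights are stored as (n_in, n_out) arrays, consistent with the (2, 18) shape of the first layer's weights in the example that follows.
- The Scale (MinMaxScale) layer is assumed to map the sigmoid output s into the bounds as lower + (upper - lower) * s.

```python
import numpy as np

# Layer widths read from the summary table: 2 descriptors -> 18 -> 9 -> 4 parameters.
LAYER_SIZES = [2, 18, 9, 4]


def count_dense_parameters(sizes):
    """Weights (n_in * n_out) plus biases (n_out) for each Dense layer."""
    return [n_in * n_out + n_out for n_in, n_out in zip(sizes[:-1], sizes[1:])]


def init_layers(sizes, seed=None):
    """HYPOTHETICAL initializer: uniform weights in +/- 1/sqrt(n_in), zero biases."""
    rng = np.random.default_rng(seed)
    layers = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        limit = 1.0 / np.sqrt(n_in)
        weight = rng.uniform(-limit, limit, size=(n_in, n_out))
        bias = np.zeros(n_out)
        layers.append((weight, bias))
    return layers


def forward(descriptors, layers, lower, upper):
    """Map descriptors (n_cells, n_in) to parameters (n_cells, n_out) within [lower, upper]."""
    h = descriptors
    for i, (weight, bias) in enumerate(layers):
        h = h @ weight + bias
        if i < len(layers) - 1:
            h = np.maximum(h, 0.0)  # ReLU on hidden layers
        else:
            h = 1.0 / (1.0 + np.exp(-h))  # Sigmoid squashes the output into (0, 1)
    return lower + (upper - lower) * h  # assumed min-max rescaling into the bounds


print(count_dense_parameters(LAYER_SIZES))  # [54, 171, 40]
print(sum(count_dense_parameters(LAYER_SIZES)))  # 265
```

For example, a 28 x 28 grid flattened into 784 cells with 2 descriptors each gives an input of shape (784, 2). `forward` then returns a (784, 4) array with one value per parameter and cell, each inside its bounds. Passing the same `seed` to `init_layers` gives identical initial weights, which is the role `random_state` plays for the default network.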
Access some training information:
>>> net.history['loss_train']  # training loss
[1.2064831256866455, ..., 0.03552241995930672]
>>> net.layers[0].weight  # trained weights of the first layer
array([[-0.35024701, -0.5263885 ,  0.06432176,  0.31493864, -0.08741257,
        -0.01596381, -0.53372188, -0.01383371,  0.54957057,  0.51538232,
         0.23674032, -0.42860816,  0.53083172,  0.42429858, -0.24634816,
         0.07233667, -0.58225892, -0.34835798],
       [-0.20115953, -0.37473829,  0.43865405,  0.48463052, -0.17020534,
        -0.19849597, -0.42540381, -0.4557565 ,  0.3663841 ,  0.27515033,
        -0.50145176, -0.02213097,  0.02078811,  0.48562112,  0.40088665,
         0.12205882, -0.00624188,  0.62118917]])
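The early_stopping option is described only as stopping weight updates when the loss stops decreasing. The sketch below shows one way such a rule could be applied to a loss sequence like net.history['loss_train']. It is an assumption for illustration, not the criterion smash actually uses. `early_stopping_epoch` and its `patience` argument are hypothetical, and the loss values in the usage note are toy numbers.

```python
from typing import Sequence, Tuple


def early_stopping_epoch(losses: Sequence[float], patience: int = 1) -> Tuple[int, int]:
    """Return (stop_epoch, best_epoch) for a per-epoch loss sequence.

    Training stops once the loss has failed to improve on its best value for
    `patience` consecutive epochs. `patience` is a HYPOTHETICAL knob, not a
    documented argument of ann_optimize.
    """
    best_loss = float("inf")
    best_epoch = 0
    stale = 0  # consecutive epochs without improvement
    for epoch, loss in enumerate(losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch, stale = loss, epoch, 0
        else:
            stale += 1
            if stale >= patience:
                return epoch, best_epoch
    return len(losses), best_epoch
```

With the toy losses [1.21, 0.50, 0.40, 0.45, 0.30], the default patience of 1 stops at epoch 4, the first epoch where the loss did not decrease. The best epoch is 3, so the function returns (4, 3). Raising `patience` to 2 lets training continue through the temporary increase and returns (5, 5).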