dl4ds.app

absl.FLAGS-based command line app. To execute it, run something like this:

python -m dl4ds.app --flagfile=params.cfg

View Source
  0#!/usr/bin/env python
  1
  2"""
  3absl.FLAGS-based command line app. To be executed run something like this:
  4
  5python -m dl4ds.app --flagfile=params.cfg
  6"""
  7
  8import numpy as np
  9import xarray as xr
 10import importlib.util
 11from absl import app, flags  
 12
 13# Usign Agg MLP backend to prevent errors related to X11 unable to connect to display "localhost:10.0"
 14import matplotlib
 15matplotlib.use('Agg')
 16
 17# Attempting to import horovod
 18try:
 19    import horovod.tensorflow.keras as hvd
 20    has_horovod = True
 21    hvd.init()
 22    if hvd.rank() == 0:
 23        running_on_first_worker = True
 24    else:
 25        running_on_first_worker = False
 26except ImportError:
 27    has_horovod = False
 28    running_on_first_worker = True
 29
 30import dl4ds as dds
 31from dl4ds import BACKBONE_BLOCKS, UPSAMPLING_METHODS, INTERPOLATION_METHODS, LOSS_FUNCTIONS, DROPOUT_VARIANTS
 32
 33
 34FLAGS = flags.FLAGS
 35
 36### EXPERIMENT
 37flags.DEFINE_bool('train', True, 'Training a model')
 38flags.DEFINE_bool('test', True, 'Testing the trained model on holdout data')
 39flags.DEFINE_bool('metrics', True, 'Running vaerification metrics on the downscaled arrays')
 40flags.DEFINE_bool('debug', False, 'If True a debug training run (2 epochs by default with 6 steps) is executed') 
 41
 42### DOWNSCALING PARAMS
 43flags.DEFINE_enum('trainer', 'SupervisedTrainer', ['SupervisedTrainer', 'CGANTrainer'], 'Tainer')
 44flags.DEFINE_enum('paired_samples', 'implicit', ['implicit', 'explicit'], 'Type of learning: implicit (PerfectProg) or explicit (MOS)')
 45flags.DEFINE_string('data_module', None, 'Python module where the data pre-processing is done')
 46
 47### MODEL
 48flags.DEFINE_enum('backbone', 'resnet', BACKBONE_BLOCKS, 'Backbone section')
 49flags.DEFINE_enum('upsampling', 'spc', UPSAMPLING_METHODS, 'Upsampling method')
 50flags.DEFINE_integer('time_window', None, 'Time window for training spatio-temporal models')
 51flags.DEFINE_integer('n_filters', 8, 'Number of convolutional filters for the first convolutional block')
 52flags.DEFINE_integer('n_blocks', 6, 'Number of convolutional blocks')
 53flags.DEFINE_integer('n_disc_filters', 32, 'Number of convolutional filters per convolutional block in the discriminator')
 54flags.DEFINE_integer('n_disc_blocks', 4, 'Number of residual blocks for discriminator network')
 55flags.DEFINE_enum('normalization', None, ['bn', 'ln'], 'Normalization')
 56flags.DEFINE_float('dropout_rate', 0.2, 'Dropout rate')
 57flags.DEFINE_enum('dropout_variant', 'vanilla', DROPOUT_VARIANTS, 'Dropout variants')
 58flags.DEFINE_bool('attention', False, 'Attention block in convolutional layers')
 59flags.DEFINE_enum('activation', 'relu', ['elu', 'relu', 'gelu', 'crelu', 'leaky_relu', 'selu'], 'Activation used in intermediate convolutional blocks')
 60flags.DEFINE_enum('output_activation', None, ['elu', 'relu', 'gelu', 'crelu', 'leaky_relu', 'selu'], 'Activation used in the last convolutional block')
 61flags.DEFINE_bool('localcon_layer', False, 'Locally connected convolutional layer')
 62flags.DEFINE_enum('decoder_upsampling', 'rc', UPSAMPLING_METHODS, 'Upsampling in decoder blocks (unet backbone)')
 63flags.DEFINE_enum('rc_interpolation', 'bilinear', INTERPOLATION_METHODS, 'Interpolation used in resize convolution upsampling')
 64
 65### TRAINING PROCEDURE
 66flags.DEFINE_enum('device', 'GPU', ['GPU', 'CPU'], 'Device to be used: GPU or CPU')
 67flags.DEFINE_bool('save', True, 'Saving to disk the trained model (last epoch), metrics, run info, etc')
 68flags.DEFINE_string('save_path', './dl4ds_results/', 'Path for saving results to disk')
 69flags.DEFINE_integer('scale', 2, 'Scaling factor, positive integer')
 70flags.DEFINE_integer('epochs', 100, 'Number of training epochs')
 71flags.DEFINE_enum('loss', 'mae', LOSS_FUNCTIONS, 'Loss function')
 72flags.DEFINE_enum('interpolation', 'inter_area', INTERPOLATION_METHODS, 'Interpolation method')
 73flags.DEFINE_integer('patch_size', None, 'Patch size in number of px/gridpoints')
 74flags.DEFINE_integer('batch_size', 32, 'Batch size (of samples) used during training')
 75flags.DEFINE_multi_float('learning_rate', 1e-3, 'Learning rate')
 76flags.DEFINE_bool('gpu_memory_growth', True, 'To use GPU memory growth (gradual memory allocation)')
 77flags.DEFINE_bool('use_multiprocessing', True, 'To use multiprocessing for data generation')
 78flags.DEFINE_float('lr_decay_after', 1e5, 'Steps to tweak the learning rate using the PiecewiseConstantDecay scheduler')
 79flags.DEFINE_bool('early_stopping', False, 'Early stopping')
 80flags.DEFINE_integer('patience', 6, 'Patience in number of epochs w/o improvement for early stopping')
 81flags.DEFINE_float('min_delta', 0.0, 'Minimum delta improvement for early stopping')
 82flags.DEFINE_bool('show_plot', False, 'Show the learning curve plot on finish')
 83flags.DEFINE_bool('save_bestmodel', True, 'SupervisedTrainer - Whether to save the best model (epoch with the best val_loss)')
 84flags.DEFINE_bool('verbose', True, 'Verbosity')
 85flags.DEFINE_integer('checkpoints_frequency', 2, 'CGANTrainer - Frequency for saving checkpoints and the generator')
 86
 87### INFERENCE/TEST
 88flags.DEFINE_bool('inference_array_in_hr', False, 'Whether the inference array is in high resolution')
 89flags.DEFINE_string('inference_save_fname', None, 'Filename for saving the inference array')
 90
 91
 92
 93def dl4ds(argv):
 94    """DL4DS absl.FLAGS-based command line app.
 95    """
 96    if running_on_first_worker:
 97        print('<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< DL4DS >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n')
 98
 99    # Run mode
100    if FLAGS.debug:
101        epochs = 2
102        steps_per_epoch = test_steps = validation_steps = 6
103    else:
104        epochs = FLAGS.epochs
105        steps_per_epoch = test_steps = validation_steps = None 
106
107    if running_on_first_worker:
108        print('<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< Loading data >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n')
109    # Training data from Python script/module
110    if FLAGS.data_module is not None:
111        spec = importlib.util.spec_from_file_location("module.name", FLAGS.data_module)
112        DATA = importlib.util.module_from_spec(spec)
113        spec.loader.exec_module(DATA)
114    else:
115        raise ValueError('`data_module` flag must be provided (path to the data preprocessing module)')
116
117    # Architecture parameters
118    if FLAGS.time_window is None:
119        if FLAGS.upsampling == 'pin':
120            architecture_params = dict(
121                n_filters=FLAGS.n_filters,
122                n_blocks=FLAGS.n_blocks,
123                normalization=FLAGS.normalization,
124                dropout_rate=FLAGS.dropout_rate,
125                dropout_variant=FLAGS.dropout_variant,
126                attention=FLAGS.attention,
127                activation=FLAGS.activation,
128                localcon_layer=FLAGS.localcon_layer,
129                output_activation=FLAGS.output_activation)
130            if FLAGS.backbone == 'unet':
131                architecture_params['decoder_upsampling'] = FLAGS.decoder_upsampling
132                architecture_params['rc_interpolation'] = FLAGS.rc_interpolation
133        else:
134            architecture_params = dict(
135                n_filters=FLAGS.n_filters,
136                n_blocks=FLAGS.n_blocks,
137                normalization=FLAGS.normalization,
138                dropout_rate=FLAGS.dropout_rate,
139                dropout_variant=FLAGS.dropout_variant,
140                attention=FLAGS.attention,
141                activation=FLAGS.activation,
142                localcon_layer=FLAGS.localcon_layer,
143                output_activation=FLAGS.output_activation,
144                rc_interpolation=FLAGS.rc_interpolation)
145    else:
146        if FLAGS.upsampling == 'pin':
147            architecture_params = dict(
148                n_filters=FLAGS.n_filters,
149                n_blocks=FLAGS.n_blocks,
150                activation=FLAGS.activation,
151                normalization=FLAGS.normalization,
152                dropout_rate=FLAGS.dropout_rate,
153                dropout_variant=FLAGS.dropout_variant,
154                attention=FLAGS.attention,
155                output_activation=FLAGS.output_activation,
156                localcon_layer=FLAGS.localcon_layer)
157        else:
158            architecture_params = dict(
159                n_filters=FLAGS.n_filters,
160                activation=FLAGS.activation,
161                normalization=FLAGS.normalization,
162                dropout_rate=FLAGS.dropout_rate,
163                dropout_variant=FLAGS.dropout_variant,
164                attention=FLAGS.attention,
165                output_activation=FLAGS.output_activation,
166                localcon_layer=FLAGS.localcon_layer,
167                rc_interpolation=FLAGS.rc_interpolation)
168
169    if FLAGS.train:
170        if running_on_first_worker:
171            print('\n<<<<<<<<<<<<<<<<<<<<<<<<<<<<< DL4DS Training phase >>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n')
172        if FLAGS.trainer == 'SupervisedTrainer':
173            trainer = dds.SupervisedTrainer(
174                backbone=FLAGS.backbone, 
175                upsampling=FLAGS.upsampling,
176                data_train=DATA.data_train, 
177                data_val=DATA.data_val, 
178                data_test=DATA.data_test, 
179                data_train_lr=DATA.data_train_lr if FLAGS.paired_samples == 'explicit' else None, 
180                data_val_lr=DATA.data_val_lr if FLAGS.paired_samples == 'explicit' else None, 
181                data_test_lr=DATA.data_test_lr if FLAGS.paired_samples == 'explicit' else None, 
182                predictors_train=DATA.predictors_train, 
183                predictors_val=DATA.predictors_val, 
184                predictors_test=DATA.predictors_test, 
185                static_vars=DATA.static_vars, 
186                scale=FLAGS.scale, 
187                interpolation=FLAGS.interpolation,
188                patch_size=FLAGS.patch_size, 
189                time_window=FLAGS.time_window, 
190                batch_size=FLAGS.batch_size,
191                loss=FLAGS.loss, 
192                epochs=epochs, 
193                steps_per_epoch=steps_per_epoch, 
194                validation_steps=validation_steps, 
195                test_steps=test_steps,
196                device=FLAGS.device, 
197                gpu_memory_growth=FLAGS.gpu_memory_growth, 
198                use_multiprocessing=FLAGS.use_multiprocessing, 
199                learning_rate=FLAGS.learning_rate, 
200                lr_decay_after=FLAGS.lr_decay_after, 
201                early_stopping=FLAGS.early_stopping, 
202                patience=FLAGS.patience, 
203                min_delta=FLAGS.min_delta, 
204                show_plot=FLAGS.show_plot, 
205                save=FLAGS.save, 
206                save_path=FLAGS.save_path, 
207                save_bestmodel=FLAGS.save_bestmodel, 
208                trained_model=None, #FLAGS.trained_model, 
209                trained_epochs=0, #FLAGS.trained_epochs, 
210                verbose=FLAGS.verbose, 
211                **architecture_params)
212        elif FLAGS.trainer == 'CGANTrainer':
213            discriminator_params = dict(
214                n_filters=FLAGS.n_disc_filters,
215                n_res_blocks=FLAGS.n_disc_blocks,
216                normalization=FLAGS.normalization,
217                activation=FLAGS.activation,
218                attention=FLAGS.attention)
219
220            trainer = dds.CGANTrainer(
221                backbone=FLAGS.backbone, 
222                upsampling=FLAGS.upsampling,
223                data_train=DATA.data_train, 
224                data_test=DATA.data_test, 
225                data_train_lr=DATA.data_train_lr if FLAGS.paired_samples == 'explicit' else None,
226                data_test_lr=DATA.data_test_lr if FLAGS.paired_samples == 'explicit' else None,
227                predictors_train=DATA.predictors_train,
228                predictors_test=DATA.predictors_test,
229                scale=FLAGS.scale, 
230                patch_size=FLAGS.patch_size, 
231                time_window=FLAGS.time_window,
232                loss=FLAGS.loss,
233                epochs=epochs, 
234                batch_size=FLAGS.batch_size,
235                learning_rates=FLAGS.learning_rate, 
236                device=FLAGS.device,
237                gpu_memory_growth=FLAGS.gpu_memory_growth,
238                steps_per_epoch=steps_per_epoch,
239                interpolation=FLAGS.interpolation, 
240                static_vars=DATA.static_vars,
241                checkpoints_frequency=FLAGS.checkpoints_frequency, 
242                save=FLAGS.save,
243                save_path=FLAGS.save_path,
244                save_logs=False,
245                save_loss_history=FLAGS.save,
246                verbose=FLAGS.verbose,
247                generator_params=architecture_params,
248                discriminator_params=discriminator_params)
249
250        trainer.run()
251
252    if FLAGS.test:
253        if running_on_first_worker:
254            print('\n<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< DL4DS Test phase >>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n')
255        if DATA.inference_scaler is None:
256            inference_scaler = None
257        else:
258            inference_scaler = DATA.inference_scaler
259
260        if not has_horovod or running_on_first_worker:
261            predictor = dds.Predictor(
262                trainer=trainer,
263                array=DATA.inference_data, 
264                array_in_hr=FLAGS.inference_array_in_hr, 
265                scale=FLAGS.scale, 
266                interpolation=FLAGS.interpolation, 
267                predictors=DATA.inference_predictors, 
268                static_vars=DATA.static_vars, 
269                time_window=FLAGS.time_window, 
270                batch_size=FLAGS.batch_size,
271                scaler=inference_scaler,
272                save_path=FLAGS.save_path, 
273                save_fname=FLAGS.inference_save_fname,
274                device=FLAGS.device)
275
276            y_hat = predictor.run()
277
278            # Saving the downscaled product in netcdf format
279            y_hat_datarray = xr.DataArray(data=np.squeeze(y_hat), 
280                                          dims=('time', 'lat', 'lon'), 
281                                          coords={'time':DATA.gt_holdout_dataset.time, 
282                                                  'lon':DATA.gt_holdout_dataset.lon, 
283                                                  'lat':DATA.gt_holdout_dataset.lat})
284            
285            if FLAGS.save_path is not None:
286                y_hat_datarray.to_netcdf(f'{FLAGS.save_path}y_hat.nc')
287
288    if FLAGS.metrics:
289        if running_on_first_worker:
290            print('\n<<<<<<<<<<<<<<<<<<<<<<<<< DL4DS Metrics computation phase >>>>>>>>>>>>>>>>>>>>>>\n')
291        if not has_horovod or running_on_first_worker:
292            metrics = dds.compute_metrics(
293                y_test=DATA.gt_holdout_dataset, 
294                y_test_hat=y_hat, 
295                dpi=300, plot_size_px=1200, 
296                mask=DATA.gt_mask, 
297                save_path=FLAGS.save_path,
298                n_jobs=-1)
299
if __name__ == '__main__':
    # Delegate flag parsing and error handling to absl
    app.run(dl4ds)
#   def dl4ds(argv):
View Source
 94def dl4ds(argv):
 95    """DL4DS absl.FLAGS-based command line app.
 96    """
 97    if running_on_first_worker:
 98        print('<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< DL4DS >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n')
 99
100    # Run mode
101    if FLAGS.debug:
102        epochs = 2
103        steps_per_epoch = test_steps = validation_steps = 6
104    else:
105        epochs = FLAGS.epochs
106        steps_per_epoch = test_steps = validation_steps = None 
107
108    if running_on_first_worker:
109        print('<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< Loading data >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n')
110    # Training data from Python script/module
111    if FLAGS.data_module is not None:
112        spec = importlib.util.spec_from_file_location("module.name", FLAGS.data_module)
113        DATA = importlib.util.module_from_spec(spec)
114        spec.loader.exec_module(DATA)
115    else:
116        raise ValueError('`data_module` flag must be provided (path to the data preprocessing module)')
117
118    # Architecture parameters
119    if FLAGS.time_window is None:
120        if FLAGS.upsampling == 'pin':
121            architecture_params = dict(
122                n_filters=FLAGS.n_filters,
123                n_blocks=FLAGS.n_blocks,
124                normalization=FLAGS.normalization,
125                dropout_rate=FLAGS.dropout_rate,
126                dropout_variant=FLAGS.dropout_variant,
127                attention=FLAGS.attention,
128                activation=FLAGS.activation,
129                localcon_layer=FLAGS.localcon_layer,
130                output_activation=FLAGS.output_activation)
131            if FLAGS.backbone == 'unet':
132                architecture_params['decoder_upsampling'] = FLAGS.decoder_upsampling
133                architecture_params['rc_interpolation'] = FLAGS.rc_interpolation
134        else:
135            architecture_params = dict(
136                n_filters=FLAGS.n_filters,
137                n_blocks=FLAGS.n_blocks,
138                normalization=FLAGS.normalization,
139                dropout_rate=FLAGS.dropout_rate,
140                dropout_variant=FLAGS.dropout_variant,
141                attention=FLAGS.attention,
142                activation=FLAGS.activation,
143                localcon_layer=FLAGS.localcon_layer,
144                output_activation=FLAGS.output_activation,
145                rc_interpolation=FLAGS.rc_interpolation)
146    else:
147        if FLAGS.upsampling == 'pin':
148            architecture_params = dict(
149                n_filters=FLAGS.n_filters,
150                n_blocks=FLAGS.n_blocks,
151                activation=FLAGS.activation,
152                normalization=FLAGS.normalization,
153                dropout_rate=FLAGS.dropout_rate,
154                dropout_variant=FLAGS.dropout_variant,
155                attention=FLAGS.attention,
156                output_activation=FLAGS.output_activation,
157                localcon_layer=FLAGS.localcon_layer)
158        else:
159            architecture_params = dict(
160                n_filters=FLAGS.n_filters,
161                activation=FLAGS.activation,
162                normalization=FLAGS.normalization,
163                dropout_rate=FLAGS.dropout_rate,
164                dropout_variant=FLAGS.dropout_variant,
165                attention=FLAGS.attention,
166                output_activation=FLAGS.output_activation,
167                localcon_layer=FLAGS.localcon_layer,
168                rc_interpolation=FLAGS.rc_interpolation)
169
170    if FLAGS.train:
171        if running_on_first_worker:
172            print('\n<<<<<<<<<<<<<<<<<<<<<<<<<<<<< DL4DS Training phase >>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n')
173        if FLAGS.trainer == 'SupervisedTrainer':
174            trainer = dds.SupervisedTrainer(
175                backbone=FLAGS.backbone, 
176                upsampling=FLAGS.upsampling,
177                data_train=DATA.data_train, 
178                data_val=DATA.data_val, 
179                data_test=DATA.data_test, 
180                data_train_lr=DATA.data_train_lr if FLAGS.paired_samples == 'explicit' else None, 
181                data_val_lr=DATA.data_val_lr if FLAGS.paired_samples == 'explicit' else None, 
182                data_test_lr=DATA.data_test_lr if FLAGS.paired_samples == 'explicit' else None, 
183                predictors_train=DATA.predictors_train, 
184                predictors_val=DATA.predictors_val, 
185                predictors_test=DATA.predictors_test, 
186                static_vars=DATA.static_vars, 
187                scale=FLAGS.scale, 
188                interpolation=FLAGS.interpolation,
189                patch_size=FLAGS.patch_size, 
190                time_window=FLAGS.time_window, 
191                batch_size=FLAGS.batch_size,
192                loss=FLAGS.loss, 
193                epochs=epochs, 
194                steps_per_epoch=steps_per_epoch, 
195                validation_steps=validation_steps, 
196                test_steps=test_steps,
197                device=FLAGS.device, 
198                gpu_memory_growth=FLAGS.gpu_memory_growth, 
199                use_multiprocessing=FLAGS.use_multiprocessing, 
200                learning_rate=FLAGS.learning_rate, 
201                lr_decay_after=FLAGS.lr_decay_after, 
202                early_stopping=FLAGS.early_stopping, 
203                patience=FLAGS.patience, 
204                min_delta=FLAGS.min_delta, 
205                show_plot=FLAGS.show_plot, 
206                save=FLAGS.save, 
207                save_path=FLAGS.save_path, 
208                save_bestmodel=FLAGS.save_bestmodel, 
209                trained_model=None, #FLAGS.trained_model, 
210                trained_epochs=0, #FLAGS.trained_epochs, 
211                verbose=FLAGS.verbose, 
212                **architecture_params)
213        elif FLAGS.trainer == 'CGANTrainer':
214            discriminator_params = dict(
215                n_filters=FLAGS.n_disc_filters,
216                n_res_blocks=FLAGS.n_disc_blocks,
217                normalization=FLAGS.normalization,
218                activation=FLAGS.activation,
219                attention=FLAGS.attention)
220
221            trainer = dds.CGANTrainer(
222                backbone=FLAGS.backbone, 
223                upsampling=FLAGS.upsampling,
224                data_train=DATA.data_train, 
225                data_test=DATA.data_test, 
226                data_train_lr=DATA.data_train_lr if FLAGS.paired_samples == 'explicit' else None,
227                data_test_lr=DATA.data_test_lr if FLAGS.paired_samples == 'explicit' else None,
228                predictors_train=DATA.predictors_train,
229                predictors_test=DATA.predictors_test,
230                scale=FLAGS.scale, 
231                patch_size=FLAGS.patch_size, 
232                time_window=FLAGS.time_window,
233                loss=FLAGS.loss,
234                epochs=epochs, 
235                batch_size=FLAGS.batch_size,
236                learning_rates=FLAGS.learning_rate, 
237                device=FLAGS.device,
238                gpu_memory_growth=FLAGS.gpu_memory_growth,
239                steps_per_epoch=steps_per_epoch,
240                interpolation=FLAGS.interpolation, 
241                static_vars=DATA.static_vars,
242                checkpoints_frequency=FLAGS.checkpoints_frequency, 
243                save=FLAGS.save,
244                save_path=FLAGS.save_path,
245                save_logs=False,
246                save_loss_history=FLAGS.save,
247                verbose=FLAGS.verbose,
248                generator_params=architecture_params,
249                discriminator_params=discriminator_params)
250
251        trainer.run()
252
253    if FLAGS.test:
254        if running_on_first_worker:
255            print('\n<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< DL4DS Test phase >>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n')
256        if DATA.inference_scaler is None:
257            inference_scaler = None
258        else:
259            inference_scaler = DATA.inference_scaler
260
261        if not has_horovod or running_on_first_worker:
262            predictor = dds.Predictor(
263                trainer=trainer,
264                array=DATA.inference_data, 
265                array_in_hr=FLAGS.inference_array_in_hr, 
266                scale=FLAGS.scale, 
267                interpolation=FLAGS.interpolation, 
268                predictors=DATA.inference_predictors, 
269                static_vars=DATA.static_vars, 
270                time_window=FLAGS.time_window, 
271                batch_size=FLAGS.batch_size,
272                scaler=inference_scaler,
273                save_path=FLAGS.save_path, 
274                save_fname=FLAGS.inference_save_fname,
275                device=FLAGS.device)
276
277            y_hat = predictor.run()
278
279            # Saving the downscaled product in netcdf format
280            y_hat_datarray = xr.DataArray(data=np.squeeze(y_hat), 
281                                          dims=('time', 'lat', 'lon'), 
282                                          coords={'time':DATA.gt_holdout_dataset.time, 
283                                                  'lon':DATA.gt_holdout_dataset.lon, 
284                                                  'lat':DATA.gt_holdout_dataset.lat})
285            
286            if FLAGS.save_path is not None:
287                y_hat_datarray.to_netcdf(f'{FLAGS.save_path}y_hat.nc')
288
289    if FLAGS.metrics:
290        if running_on_first_worker:
291            print('\n<<<<<<<<<<<<<<<<<<<<<<<<< DL4DS Metrics computation phase >>>>>>>>>>>>>>>>>>>>>>\n')
292        if not has_horovod or running_on_first_worker:
293            metrics = dds.compute_metrics(
294                y_test=DATA.gt_holdout_dataset, 
295                y_test_hat=y_hat, 
296                dpi=300, plot_size_px=1200, 
297                mask=DATA.gt_mask, 
298                save_path=FLAGS.save_path,
299                n_jobs=-1)

DL4DS absl.FLAGS-based command line app.