dl4ds.app
absl.FLAGS-based command line app. To execute it, run something like this:
python -m dl4ds.app --flagfile=params.cfg
View Source
0#!/usr/bin/env python 1 2""" 3absl.FLAGS-based command line app. To be executed run something like this: 4 5python -m dl4ds.app --flagfile=params.cfg 6""" 7 8import numpy as np 9import xarray as xr 10import importlib.util 11from absl import app, flags 12 13# Usign Agg MLP backend to prevent errors related to X11 unable to connect to display "localhost:10.0" 14import matplotlib 15matplotlib.use('Agg') 16 17# Attempting to import horovod 18try: 19 import horovod.tensorflow.keras as hvd 20 has_horovod = True 21 hvd.init() 22 if hvd.rank() == 0: 23 running_on_first_worker = True 24 else: 25 running_on_first_worker = False 26except ImportError: 27 has_horovod = False 28 running_on_first_worker = True 29 30import dl4ds as dds 31from dl4ds import BACKBONE_BLOCKS, UPSAMPLING_METHODS, INTERPOLATION_METHODS, LOSS_FUNCTIONS, DROPOUT_VARIANTS 32 33 34FLAGS = flags.FLAGS 35 36### EXPERIMENT 37flags.DEFINE_bool('train', True, 'Training a model') 38flags.DEFINE_bool('test', True, 'Testing the trained model on holdout data') 39flags.DEFINE_bool('metrics', True, 'Running vaerification metrics on the downscaled arrays') 40flags.DEFINE_bool('debug', False, 'If True a debug training run (2 epochs by default with 6 steps) is executed') 41 42### DOWNSCALING PARAMS 43flags.DEFINE_enum('trainer', 'SupervisedTrainer', ['SupervisedTrainer', 'CGANTrainer'], 'Tainer') 44flags.DEFINE_enum('paired_samples', 'implicit', ['implicit', 'explicit'], 'Type of learning: implicit (PerfectProg) or explicit (MOS)') 45flags.DEFINE_string('data_module', None, 'Python module where the data pre-processing is done') 46 47### MODEL 48flags.DEFINE_enum('backbone', 'resnet', BACKBONE_BLOCKS, 'Backbone section') 49flags.DEFINE_enum('upsampling', 'spc', UPSAMPLING_METHODS, 'Upsampling method') 50flags.DEFINE_integer('time_window', None, 'Time window for training spatio-temporal models') 51flags.DEFINE_integer('n_filters', 8, 'Number of convolutional filters for the first convolutional block') 
52flags.DEFINE_integer('n_blocks', 6, 'Number of convolutional blocks') 53flags.DEFINE_integer('n_disc_filters', 32, 'Number of convolutional filters per convolutional block in the discriminator') 54flags.DEFINE_integer('n_disc_blocks', 4, 'Number of residual blocks for discriminator network') 55flags.DEFINE_enum('normalization', None, ['bn', 'ln'], 'Normalization') 56flags.DEFINE_float('dropout_rate', 0.2, 'Dropout rate') 57flags.DEFINE_enum('dropout_variant', 'vanilla', DROPOUT_VARIANTS, 'Dropout variants') 58flags.DEFINE_bool('attention', False, 'Attention block in convolutional layers') 59flags.DEFINE_enum('activation', 'relu', ['elu', 'relu', 'gelu', 'crelu', 'leaky_relu', 'selu'], 'Activation used in intermediate convolutional blocks') 60flags.DEFINE_enum('output_activation', None, ['elu', 'relu', 'gelu', 'crelu', 'leaky_relu', 'selu'], 'Activation used in the last convolutional block') 61flags.DEFINE_bool('localcon_layer', False, 'Locally connected convolutional layer') 62flags.DEFINE_enum('decoder_upsampling', 'rc', UPSAMPLING_METHODS, 'Upsampling in decoder blocks (unet backbone)') 63flags.DEFINE_enum('rc_interpolation', 'bilinear', INTERPOLATION_METHODS, 'Interpolation used in resize convolution upsampling') 64 65### TRAINING PROCEDURE 66flags.DEFINE_enum('device', 'GPU', ['GPU', 'CPU'], 'Device to be used: GPU or CPU') 67flags.DEFINE_bool('save', True, 'Saving to disk the trained model (last epoch), metrics, run info, etc') 68flags.DEFINE_string('save_path', './dl4ds_results/', 'Path for saving results to disk') 69flags.DEFINE_integer('scale', 2, 'Scaling factor, positive integer') 70flags.DEFINE_integer('epochs', 100, 'Number of training epochs') 71flags.DEFINE_enum('loss', 'mae', LOSS_FUNCTIONS, 'Loss function') 72flags.DEFINE_enum('interpolation', 'inter_area', INTERPOLATION_METHODS, 'Interpolation method') 73flags.DEFINE_integer('patch_size', None, 'Patch size in number of px/gridpoints') 74flags.DEFINE_integer('batch_size', 32, 'Batch size (of 
samples) used during training') 75flags.DEFINE_multi_float('learning_rate', 1e-3, 'Learning rate') 76flags.DEFINE_bool('gpu_memory_growth', True, 'To use GPU memory growth (gradual memory allocation)') 77flags.DEFINE_bool('use_multiprocessing', True, 'To use multiprocessing for data generation') 78flags.DEFINE_float('lr_decay_after', 1e5, 'Steps to tweak the learning rate using the PiecewiseConstantDecay scheduler') 79flags.DEFINE_bool('early_stopping', False, 'Early stopping') 80flags.DEFINE_integer('patience', 6, 'Patience in number of epochs w/o improvement for early stopping') 81flags.DEFINE_float('min_delta', 0.0, 'Minimum delta improvement for early stopping') 82flags.DEFINE_bool('show_plot', False, 'Show the learning curve plot on finish') 83flags.DEFINE_bool('save_bestmodel', True, 'SupervisedTrainer - Whether to save the best model (epoch with the best val_loss)') 84flags.DEFINE_bool('verbose', True, 'Verbosity') 85flags.DEFINE_integer('checkpoints_frequency', 2, 'CGANTrainer - Frequency for saving checkpoints and the generator') 86 87### INFERENCE/TEST 88flags.DEFINE_bool('inference_array_in_hr', False, 'Whether the inference array is in high resolution') 89flags.DEFINE_string('inference_save_fname', None, 'Filename for saving the inference array') 90 91 92 93def dl4ds(argv): 94 """DL4DS absl.FLAGS-based command line app. 
95 """ 96 if running_on_first_worker: 97 print('<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< DL4DS >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n') 98 99 # Run mode 100 if FLAGS.debug: 101 epochs = 2 102 steps_per_epoch = test_steps = validation_steps = 6 103 else: 104 epochs = FLAGS.epochs 105 steps_per_epoch = test_steps = validation_steps = None 106 107 if running_on_first_worker: 108 print('<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< Loading data >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n') 109 # Training data from Python script/module 110 if FLAGS.data_module is not None: 111 spec = importlib.util.spec_from_file_location("module.name", FLAGS.data_module) 112 DATA = importlib.util.module_from_spec(spec) 113 spec.loader.exec_module(DATA) 114 else: 115 raise ValueError('`data_module` flag must be provided (path to the data preprocessing module)') 116 117 # Architecture parameters 118 if FLAGS.time_window is None: 119 if FLAGS.upsampling == 'pin': 120 architecture_params = dict( 121 n_filters=FLAGS.n_filters, 122 n_blocks=FLAGS.n_blocks, 123 normalization=FLAGS.normalization, 124 dropout_rate=FLAGS.dropout_rate, 125 dropout_variant=FLAGS.dropout_variant, 126 attention=FLAGS.attention, 127 activation=FLAGS.activation, 128 localcon_layer=FLAGS.localcon_layer, 129 output_activation=FLAGS.output_activation) 130 if FLAGS.backbone == 'unet': 131 architecture_params['decoder_upsampling'] = FLAGS.decoder_upsampling 132 architecture_params['rc_interpolation'] = FLAGS.rc_interpolation 133 else: 134 architecture_params = dict( 135 n_filters=FLAGS.n_filters, 136 n_blocks=FLAGS.n_blocks, 137 normalization=FLAGS.normalization, 138 dropout_rate=FLAGS.dropout_rate, 139 dropout_variant=FLAGS.dropout_variant, 140 attention=FLAGS.attention, 141 activation=FLAGS.activation, 142 localcon_layer=FLAGS.localcon_layer, 143 output_activation=FLAGS.output_activation, 144 rc_interpolation=FLAGS.rc_interpolation) 145 else: 146 if FLAGS.upsampling == 'pin': 147 architecture_params = dict( 148 n_filters=FLAGS.n_filters, 149 
n_blocks=FLAGS.n_blocks, 150 activation=FLAGS.activation, 151 normalization=FLAGS.normalization, 152 dropout_rate=FLAGS.dropout_rate, 153 dropout_variant=FLAGS.dropout_variant, 154 attention=FLAGS.attention, 155 output_activation=FLAGS.output_activation, 156 localcon_layer=FLAGS.localcon_layer) 157 else: 158 architecture_params = dict( 159 n_filters=FLAGS.n_filters, 160 activation=FLAGS.activation, 161 normalization=FLAGS.normalization, 162 dropout_rate=FLAGS.dropout_rate, 163 dropout_variant=FLAGS.dropout_variant, 164 attention=FLAGS.attention, 165 output_activation=FLAGS.output_activation, 166 localcon_layer=FLAGS.localcon_layer, 167 rc_interpolation=FLAGS.rc_interpolation) 168 169 if FLAGS.train: 170 if running_on_first_worker: 171 print('\n<<<<<<<<<<<<<<<<<<<<<<<<<<<<< DL4DS Training phase >>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n') 172 if FLAGS.trainer == 'SupervisedTrainer': 173 trainer = dds.SupervisedTrainer( 174 backbone=FLAGS.backbone, 175 upsampling=FLAGS.upsampling, 176 data_train=DATA.data_train, 177 data_val=DATA.data_val, 178 data_test=DATA.data_test, 179 data_train_lr=DATA.data_train_lr if FLAGS.paired_samples == 'explicit' else None, 180 data_val_lr=DATA.data_val_lr if FLAGS.paired_samples == 'explicit' else None, 181 data_test_lr=DATA.data_test_lr if FLAGS.paired_samples == 'explicit' else None, 182 predictors_train=DATA.predictors_train, 183 predictors_val=DATA.predictors_val, 184 predictors_test=DATA.predictors_test, 185 static_vars=DATA.static_vars, 186 scale=FLAGS.scale, 187 interpolation=FLAGS.interpolation, 188 patch_size=FLAGS.patch_size, 189 time_window=FLAGS.time_window, 190 batch_size=FLAGS.batch_size, 191 loss=FLAGS.loss, 192 epochs=epochs, 193 steps_per_epoch=steps_per_epoch, 194 validation_steps=validation_steps, 195 test_steps=test_steps, 196 device=FLAGS.device, 197 gpu_memory_growth=FLAGS.gpu_memory_growth, 198 use_multiprocessing=FLAGS.use_multiprocessing, 199 learning_rate=FLAGS.learning_rate, 200 lr_decay_after=FLAGS.lr_decay_after, 201 
early_stopping=FLAGS.early_stopping, 202 patience=FLAGS.patience, 203 min_delta=FLAGS.min_delta, 204 show_plot=FLAGS.show_plot, 205 save=FLAGS.save, 206 save_path=FLAGS.save_path, 207 save_bestmodel=FLAGS.save_bestmodel, 208 trained_model=None, #FLAGS.trained_model, 209 trained_epochs=0, #FLAGS.trained_epochs, 210 verbose=FLAGS.verbose, 211 **architecture_params) 212 elif FLAGS.trainer == 'CGANTrainer': 213 discriminator_params = dict( 214 n_filters=FLAGS.n_disc_filters, 215 n_res_blocks=FLAGS.n_disc_blocks, 216 normalization=FLAGS.normalization, 217 activation=FLAGS.activation, 218 attention=FLAGS.attention) 219 220 trainer = dds.CGANTrainer( 221 backbone=FLAGS.backbone, 222 upsampling=FLAGS.upsampling, 223 data_train=DATA.data_train, 224 data_test=DATA.data_test, 225 data_train_lr=DATA.data_train_lr if FLAGS.paired_samples == 'explicit' else None, 226 data_test_lr=DATA.data_test_lr if FLAGS.paired_samples == 'explicit' else None, 227 predictors_train=DATA.predictors_train, 228 predictors_test=DATA.predictors_test, 229 scale=FLAGS.scale, 230 patch_size=FLAGS.patch_size, 231 time_window=FLAGS.time_window, 232 loss=FLAGS.loss, 233 epochs=epochs, 234 batch_size=FLAGS.batch_size, 235 learning_rates=FLAGS.learning_rate, 236 device=FLAGS.device, 237 gpu_memory_growth=FLAGS.gpu_memory_growth, 238 steps_per_epoch=steps_per_epoch, 239 interpolation=FLAGS.interpolation, 240 static_vars=DATA.static_vars, 241 checkpoints_frequency=FLAGS.checkpoints_frequency, 242 save=FLAGS.save, 243 save_path=FLAGS.save_path, 244 save_logs=False, 245 save_loss_history=FLAGS.save, 246 verbose=FLAGS.verbose, 247 generator_params=architecture_params, 248 discriminator_params=discriminator_params) 249 250 trainer.run() 251 252 if FLAGS.test: 253 if running_on_first_worker: 254 print('\n<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< DL4DS Test phase >>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n') 255 if DATA.inference_scaler is None: 256 inference_scaler = None 257 else: 258 inference_scaler = DATA.inference_scaler 259 260 
if not has_horovod or running_on_first_worker: 261 predictor = dds.Predictor( 262 trainer=trainer, 263 array=DATA.inference_data, 264 array_in_hr=FLAGS.inference_array_in_hr, 265 scale=FLAGS.scale, 266 interpolation=FLAGS.interpolation, 267 predictors=DATA.inference_predictors, 268 static_vars=DATA.static_vars, 269 time_window=FLAGS.time_window, 270 batch_size=FLAGS.batch_size, 271 scaler=inference_scaler, 272 save_path=FLAGS.save_path, 273 save_fname=FLAGS.inference_save_fname, 274 device=FLAGS.device) 275 276 y_hat = predictor.run() 277 278 # Saving the downscaled product in netcdf format 279 y_hat_datarray = xr.DataArray(data=np.squeeze(y_hat), 280 dims=('time', 'lat', 'lon'), 281 coords={'time':DATA.gt_holdout_dataset.time, 282 'lon':DATA.gt_holdout_dataset.lon, 283 'lat':DATA.gt_holdout_dataset.lat}) 284 285 if FLAGS.save_path is not None: 286 y_hat_datarray.to_netcdf(f'{FLAGS.save_path}y_hat.nc') 287 288 if FLAGS.metrics: 289 if running_on_first_worker: 290 print('\n<<<<<<<<<<<<<<<<<<<<<<<<< DL4DS Metrics computation phase >>>>>>>>>>>>>>>>>>>>>>\n') 291 if not has_horovod or running_on_first_worker: 292 metrics = dds.compute_metrics( 293 y_test=DATA.gt_holdout_dataset, 294 y_test_hat=y_hat, 295 dpi=300, plot_size_px=1200, 296 mask=DATA.gt_mask, 297 save_path=FLAGS.save_path, 298 n_jobs=-1) 299 300if __name__ == '__main__': 301 app.run(dl4ds)
View Source
def dl4ds(argv):
    """DL4DS absl.FLAGS-based command line app.

    Runs up to three phases, controlled by the boolean flags `train`, `test`
    and `metrics`: model training, inference on holdout data, and metrics
    computation. All configuration comes from FLAGS; `argv` (supplied by
    `absl.app.run`) is unused. Raises ValueError if `data_module` is missing.
    """
    if running_on_first_worker:
        print('<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< DL4DS >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n')

    # Run mode: debug runs a tiny training (2 epochs, 6 steps per phase)
    if FLAGS.debug:
        epochs = 2
        steps_per_epoch = test_steps = validation_steps = 6
    else:
        epochs = FLAGS.epochs
        steps_per_epoch = test_steps = validation_steps = None

    if running_on_first_worker:
        print('<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< Loading data >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n')
    # Training data from Python script/module: the file is executed as a
    # module, so its top-level names (data_train, static_vars, ...) become
    # attributes of DATA
    if FLAGS.data_module is not None:
        spec = importlib.util.spec_from_file_location("module.name", FLAGS.data_module)
        DATA = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(DATA)
    else:
        raise ValueError('`data_module` flag must be provided (path to the data preprocessing module)')

    # Architecture parameters: the keyword set depends on whether the model is
    # spatial (time_window is None) or spatio-temporal, and on the upsampling
    # method ('pin' vs the rest)
    if FLAGS.time_window is None:
        if FLAGS.upsampling == 'pin':
            architecture_params = dict(
                n_filters=FLAGS.n_filters,
                n_blocks=FLAGS.n_blocks,
                normalization=FLAGS.normalization,
                dropout_rate=FLAGS.dropout_rate,
                dropout_variant=FLAGS.dropout_variant,
                attention=FLAGS.attention,
                activation=FLAGS.activation,
                localcon_layer=FLAGS.localcon_layer,
                output_activation=FLAGS.output_activation)
            # only the unet backbone takes decoder upsampling parameters
            if FLAGS.backbone == 'unet':
                architecture_params['decoder_upsampling'] = FLAGS.decoder_upsampling
                architecture_params['rc_interpolation'] = FLAGS.rc_interpolation
        else:
            architecture_params = dict(
                n_filters=FLAGS.n_filters,
                n_blocks=FLAGS.n_blocks,
                normalization=FLAGS.normalization,
                dropout_rate=FLAGS.dropout_rate,
                dropout_variant=FLAGS.dropout_variant,
                attention=FLAGS.attention,
                activation=FLAGS.activation,
                localcon_layer=FLAGS.localcon_layer,
                output_activation=FLAGS.output_activation,
                rc_interpolation=FLAGS.rc_interpolation)
    else:
        if FLAGS.upsampling == 'pin':
            architecture_params = dict(
                n_filters=FLAGS.n_filters,
                n_blocks=FLAGS.n_blocks,
                activation=FLAGS.activation,
                normalization=FLAGS.normalization,
                dropout_rate=FLAGS.dropout_rate,
                dropout_variant=FLAGS.dropout_variant,
                attention=FLAGS.attention,
                output_activation=FLAGS.output_activation,
                localcon_layer=FLAGS.localcon_layer)
        else:
            # NOTE(review): n_blocks is absent in this branch, unlike the other
            # three -- confirm against the spatio-temporal model signatures
            architecture_params = dict(
                n_filters=FLAGS.n_filters,
                activation=FLAGS.activation,
                normalization=FLAGS.normalization,
                dropout_rate=FLAGS.dropout_rate,
                dropout_variant=FLAGS.dropout_variant,
                attention=FLAGS.attention,
                output_activation=FLAGS.output_activation,
                localcon_layer=FLAGS.localcon_layer,
                rc_interpolation=FLAGS.rc_interpolation)

    if FLAGS.train:
        if running_on_first_worker:
            print('\n<<<<<<<<<<<<<<<<<<<<<<<<<<<<< DL4DS Training phase >>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n')
        if FLAGS.trainer == 'SupervisedTrainer':
            trainer = dds.SupervisedTrainer(
                backbone=FLAGS.backbone,
                upsampling=FLAGS.upsampling,
                data_train=DATA.data_train,
                data_val=DATA.data_val,
                data_test=DATA.data_test,
                # explicit (MOS) learning supplies separate low-resolution arrays
                data_train_lr=DATA.data_train_lr if FLAGS.paired_samples == 'explicit' else None,
                data_val_lr=DATA.data_val_lr if FLAGS.paired_samples == 'explicit' else None,
                data_test_lr=DATA.data_test_lr if FLAGS.paired_samples == 'explicit' else None,
                predictors_train=DATA.predictors_train,
                predictors_val=DATA.predictors_val,
                predictors_test=DATA.predictors_test,
                static_vars=DATA.static_vars,
                scale=FLAGS.scale,
                interpolation=FLAGS.interpolation,
                patch_size=FLAGS.patch_size,
                time_window=FLAGS.time_window,
                batch_size=FLAGS.batch_size,
                loss=FLAGS.loss,
                epochs=epochs,
                steps_per_epoch=steps_per_epoch,
                validation_steps=validation_steps,
                test_steps=test_steps,
                device=FLAGS.device,
                gpu_memory_growth=FLAGS.gpu_memory_growth,
                use_multiprocessing=FLAGS.use_multiprocessing,
                learning_rate=FLAGS.learning_rate,
                lr_decay_after=FLAGS.lr_decay_after,
                early_stopping=FLAGS.early_stopping,
                patience=FLAGS.patience,
                min_delta=FLAGS.min_delta,
                show_plot=FLAGS.show_plot,
                save=FLAGS.save,
                save_path=FLAGS.save_path,
                save_bestmodel=FLAGS.save_bestmodel,
                trained_model=None, #FLAGS.trained_model,
                trained_epochs=0, #FLAGS.trained_epochs,
                verbose=FLAGS.verbose,
                **architecture_params)
        elif FLAGS.trainer == 'CGANTrainer':
            discriminator_params = dict(
                n_filters=FLAGS.n_disc_filters,
                n_res_blocks=FLAGS.n_disc_blocks,
                normalization=FLAGS.normalization,
                activation=FLAGS.activation,
                attention=FLAGS.attention)

            trainer = dds.CGANTrainer(
                backbone=FLAGS.backbone,
                upsampling=FLAGS.upsampling,
                data_train=DATA.data_train,
                data_test=DATA.data_test,
                data_train_lr=DATA.data_train_lr if FLAGS.paired_samples == 'explicit' else None,
                data_test_lr=DATA.data_test_lr if FLAGS.paired_samples == 'explicit' else None,
                predictors_train=DATA.predictors_train,
                predictors_test=DATA.predictors_test,
                scale=FLAGS.scale,
                patch_size=FLAGS.patch_size,
                time_window=FLAGS.time_window,
                loss=FLAGS.loss,
                epochs=epochs,
                batch_size=FLAGS.batch_size,
                learning_rates=FLAGS.learning_rate,
                device=FLAGS.device,
                gpu_memory_growth=FLAGS.gpu_memory_growth,
                steps_per_epoch=steps_per_epoch,
                interpolation=FLAGS.interpolation,
                static_vars=DATA.static_vars,
                checkpoints_frequency=FLAGS.checkpoints_frequency,
                save=FLAGS.save,
                save_path=FLAGS.save_path,
                save_logs=False,
                save_loss_history=FLAGS.save,
                verbose=FLAGS.verbose,
                generator_params=architecture_params,
                discriminator_params=discriminator_params)

        trainer.run()

    if FLAGS.test:
        if running_on_first_worker:
            print('\n<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< DL4DS Test phase >>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n')
        # NOTE(review): `trainer` is only bound when FLAGS.train is True --
        # running with --notrain --test raises NameError below; confirm
        if DATA.inference_scaler is None:
            inference_scaler = None
        else:
            inference_scaler = DATA.inference_scaler

        # Inference runs only on the first worker in distributed runs
        if not has_horovod or running_on_first_worker:
            predictor = dds.Predictor(
                trainer=trainer,
                array=DATA.inference_data,
                array_in_hr=FLAGS.inference_array_in_hr,
                scale=FLAGS.scale,
                interpolation=FLAGS.interpolation,
                predictors=DATA.inference_predictors,
                static_vars=DATA.static_vars,
                time_window=FLAGS.time_window,
                batch_size=FLAGS.batch_size,
                scaler=inference_scaler,
                save_path=FLAGS.save_path,
                save_fname=FLAGS.inference_save_fname,
                device=FLAGS.device)

            y_hat = predictor.run()

            # Saving the downscaled product in netcdf format, with coordinates
            # borrowed from the ground-truth holdout dataset
            y_hat_datarray = xr.DataArray(data=np.squeeze(y_hat),
                                          dims=('time', 'lat', 'lon'),
                                          coords={'time':DATA.gt_holdout_dataset.time,
                                                  'lon':DATA.gt_holdout_dataset.lon,
                                                  'lat':DATA.gt_holdout_dataset.lat})

            # NOTE(review): assumes save_path carries a trailing slash, as the
            # default './dl4ds_results/' does
            if FLAGS.save_path is not None:
                y_hat_datarray.to_netcdf(f'{FLAGS.save_path}y_hat.nc')

    if FLAGS.metrics:
        if running_on_first_worker:
            print('\n<<<<<<<<<<<<<<<<<<<<<<<<< DL4DS Metrics computation phase >>>>>>>>>>>>>>>>>>>>>>\n')
        # NOTE(review): `y_hat` is only bound when the test phase ran on this
        # worker -- running with --notest --metrics raises NameError; confirm
        if not has_horovod or running_on_first_worker:
            metrics = dds.compute_metrics(
                y_test=DATA.gt_holdout_dataset,
                y_test_hat=y_hat,
                dpi=300, plot_size_px=1200,
                mask=DATA.gt_mask,
                save_path=FLAGS.save_path,
                n_jobs=-1)
DL4DS absl.FLAGS-based command line app.