From 00b2f746cbd0a1d89d8a0441e0380b9cf4f86264 Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@inrae.fr> Date: Mon, 19 Sep 2022 19:39:46 +0200 Subject: [PATCH 1/3] ENH: better model plotting --- otbtf/model.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/otbtf/model.py b/otbtf/model.py index a6306e27..f7e96779 100644 --- a/otbtf/model.py +++ b/otbtf/model.py @@ -145,23 +145,24 @@ class ModelBase(abc.ABC): def summary(self, strategy=None): """ Wraps the summary printing of the model. When multiworker strategy, only prints if the worker is chief + + :param strategy: strategy """ if not strategy or _is_chief(strategy): self.model.summary(line_length=150) - def plot(self, output_path, strategy=None): + def plot(self, output_path, strategy=None, show_shapes=False): """ Enables to save a figure representing the architecture of the network. Needs pydot and graphviz to work (`pip install pydot` and https://graphviz.gitlab.io/download/) + + :param output_path: output path for the schema + :param strategy: strategy + :param strategy: annotate with shapes values + """ assert self.model, "Plot() only works if create_network() has been called beforehand" # When multiworker strategy, only plot if the worker is chief if not strategy or _is_chief(strategy): - # Build a simplified model, without normalization nor extra outputs. - # This model is only used for plotting the architecture thanks to `keras.utils.plot_model` - inputs = self.get_inputs() # inputs without normalization - outputs = self.get_outputs(inputs) # raw model outputs - model_simplified = tensorflow.keras.Model(inputs=inputs, outputs=outputs, - name=self.__class__.__name__ + '_simplified') - tensorflow.keras.utils.plot_model(model_simplified, output_path) + tensorflow.keras.utils.plot_model(self.model, output_path, show_shapes=show_shapes) -- GitLab From d235e0c6037127c33280945fa5bd90daa052427d Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@inrae.fr> Date: Mon, 19 Sep 2022 21:00:05 +0200 Subject: [PATCH 2/3] DOC: docstring --- otbtf/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/otbtf/model.py b/otbtf/model.py index f7e96779..91800dcd 100644 --- a/otbtf/model.py +++ b/otbtf/model.py @@ -158,7 +158,7 @@ class ModelBase(abc.ABC): :param output_path: output path for the schema :param strategy: strategy - :param strategy: annotate with shapes values + :param show_shapes: annotate with shapes values (True or False) """ assert self.model, "Plot() only works if create_network() has been called beforehand" -- GitLab From 4c37af7b3337745de8c62505aa7c9d4e2824656d Mon Sep 17 00:00:00 2001 From: Remi Cresson <remi.cresson@inrae.fr> Date: Mon, 19 Sep 2022 21:00:22 +0200 Subject: [PATCH 3/3] STYLE: pylint --- otbtf/tfrecords.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/otbtf/tfrecords.py b/otbtf/tfrecords.py index 254efd2a..15a23354 100644 --- a/otbtf/tfrecords.py +++ b/otbtf/tfrecords.py @@ -172,7 +172,8 @@ class TFRecords: :param num_parallel_calls: number of parallel calls for the parsing + preprocessing step :param kwargs: some keywords arguments for preprocessing_fn """ - for dic, file in zip([self.output_types, self.output_shapes], [self.output_types_file, self.output_shapes_file]): + for dic, file in zip([self.output_types, self.output_shapes], + [self.output_types_file, self.output_shapes_file]): assert dic, f"The file {file} is missing!" options = tf.data.Options() -- GitLab