diff --git a/otbtf/examples/tensorflow_v2x/fcnn/fcnn_model.py b/otbtf/examples/tensorflow_v2x/fcnn/fcnn_model.py
index 44285d92bb77695224288cb9d804018c90924c82..4f2377ae9d354b5d6e90728b254de864909eb7fe 100644
--- a/otbtf/examples/tensorflow_v2x/fcnn/fcnn_model.py
+++ b/otbtf/examples/tensorflow_v2x/fcnn/fcnn_model.py
@@ -5,6 +5,7 @@ import logging
 
 import tensorflow as tf
 
+import otbtf.layers
 from otbtf.model import ModelBase
 
 logging.basicConfig(
@@ -23,8 +24,9 @@ INPUT_NAME = "input_xs"
 # Name of the output in the `FCNNModel` instance
 TARGET_NAME = "predictions"
 
-# Name (prefix) of the output node in the SavedModel
-OUTPUT_SOFTMAX_NAME = "predictions_softmax_tensor"
+# Name (prefix) of the output nodes in the SavedModel
+OUTPUT_SOFTMAX_NAME = "predictions_softmax"
+OUTPUT_ARGMAX_NAME = "predictions_argmax"
 
 
 class FCNNModel(ModelBase):
@@ -115,15 +117,22 @@ class FCNNModel(ModelBase):
         # from `MyModel.get_output()`. They are two identifiers with a
         # different purpose:
         #  - the output layer name is used only at inference time, to identify
-        #    the output tensor from which generate the output image,
+        #    the output tensor from which generate the output image (e.g.
+        #    OUTPUT_SOFTMAX_NAME),
         #  - the output tensor key identifies the output tensors, mainly to
         #    fit the targets to model outputs during training process, but it
         #    can also be used to access the tensors as tf/keras objects, for
         #    instance to display previews images in TensorBoard.
+        # For convenience, since OTBTF 4.3.0, the post-processed outputs are
+        # automatically added after the output tensor keys. This avoids lazy
+        # users to explicitly name the last layers. When the name is already
+        # took by the layer, the post-processed outputs creating is skipped.
         softmax_op = tf.keras.layers.Softmax(name=OUTPUT_SOFTMAX_NAME)
         predictions = softmax_op(out_tconv4)
+        argmax_op = otbtf.layers.Argmax()
+        labels = argmax_op(predictions)
 
-        return {TARGET_NAME: predictions}
+        return {TARGET_NAME: predictions, OUTPUT_ARGMAX_NAME: labels}
 
 
 def dataset_preprocessing_fn(examples: dict):
@@ -146,9 +155,9 @@ def dataset_preprocessing_fn(examples: dict):
     """
     return {
         INPUT_NAME: examples["input_xs_patches"],
-        TARGET_NAME: tf.one_hot(
-            tf.squeeze(tf.cast(examples["labels_patches"], tf.int32), axis=-1),
-            depth=N_CLASSES
+        TARGET_NAME: otbtf.ops.one_hot(
+            labels=examples["labels_patches"],
+            nb_classes=N_CLASSES
         )
     }
 
@@ -173,12 +182,21 @@ def train(params, ds_train, ds_valid, ds_test):
         model = FCNNModel(dataset_element_spec=ds_train.element_spec)
 
         # Compile the model
+        # Here using a `dict` to explicitly name the outputs over which the
+        # losses/metrics are computed is a good practice.
         model.compile(
-            loss=tf.keras.losses.CategoricalCrossentropy(),
+            loss={
+                TARGET_NAME: tf.keras.losses.CategoricalCrossentropy()
+            },
             optimizer=tf.keras.optimizers.Adam(
                 learning_rate=params.learning_rate
             ),
-            metrics=[tf.keras.metrics.Precision(), tf.keras.metrics.Recall()]
+            metrics={
+                TARGET_NAME: [
+                    tf.keras.metrics.Precision(class_id=1),
+                    tf.keras.metrics.Recall(class_id=1)
+                ]
+            }
         )
 
         # Summarize the model (in CLI)
diff --git a/otbtf/model.py b/otbtf/model.py
index 9958510bdf32dd147df273a264714c93d2543ee1..1c648de241bcb4715e17294cd78f8f23dd5b5ab0 100644
--- a/otbtf/model.py
+++ b/otbtf/model.py
@@ -34,7 +34,7 @@ TensorsDict = Dict[str, Tensor]
 
 class ModelBase(abc.ABC):
     """
-    Base class for all models
+    Base class for all fully convolutional models
     """
 
     def __init__(
@@ -178,24 +178,35 @@ class ModelBase(abc.ABC):
             a dict of post-processed model outputs
 
         """
-
-        # Add extra outputs for inference
+        # Dict of extra outputs for inference
         extra_outputs = {}
         for out_key, out_tensor in outputs.items():
-            for crop in self.inference_cropping:
-                extra_output_key = cropped_tensor_name(out_key, crop)
-                extra_output_name = cropped_tensor_name(
-                    out_tensor._keras_history.layer.name, crop
-                )
-                logging.info(
-                    "Adding extra output for tensor %s with crop %s (%s)",
-                    out_key, crop, extra_output_name
+            layer_name = out_tensor._keras_history.layer.name
+            # extra output named after layer
+            srcs_names = [layer_name]
+            if layer_name == out_key:
+                logging.warning(
+                    "Output \"%s\" already exist from layer of the same "
+                    "name. Skipping extra-outputs creation. If you have "
+                    "any doubt, you can use the following command to see "
+                    "available model outputs: "
+                    "`saved_model_cli show --dir your_model_dir --all`",
+                    layer_name
                 )
-                cropped = out_tensor[:, crop:-crop, crop:-crop, :]
-                identity = tf.keras.layers.Activation(
-                    'linear', name=extra_output_name
-                )
-                extra_outputs[extra_output_key] = identity(cropped)
+            else:
+                # extra output named after output key
+                srcs_names += [out_key]
+            # Now for all accepted src_name, we create extra outputs
+            for crop in self.inference_cropping:
+                for src_name in srcs_names:
+                    tgt_name = cropped_tensor_name(src_name, crop)
+                    logging.info(
+                        "Adding extra output for tensor %s with crop %s: %s",
+                        out_key, crop, tgt_name
+                    )
+                    eye = tf.keras.layers.Activation("linear", name=tgt_name)
+                    extra_out = eye(out_tensor[:, crop:-crop, crop:-crop, :])
+                    extra_outputs[f"postproc_{tgt_name}"] = extra_out
 
         return extra_outputs
 
diff --git a/test/api_unittest.py b/test/api_unittest.py
index 2fe3fe38632d8739dca7a61458ec95d398b0170e..aa2f51b3a146863aab79b74ace1867db0fac7296 100644
--- a/test/api_unittest.py
+++ b/test/api_unittest.py
@@ -8,9 +8,10 @@ from otbtf.examples.tensorflow_v2x.fcnn import create_tfrecords
 from otbtf.examples.tensorflow_v2x.fcnn import train_from_patchesimages
 from otbtf.examples.tensorflow_v2x.fcnn import train_from_tfrecords
 from otbtf.examples.tensorflow_v2x.fcnn.fcnn_model import INPUT_NAME, \
-    OUTPUT_SOFTMAX_NAME
+    OUTPUT_SOFTMAX_NAME, OUTPUT_ARGMAX_NAME, TARGET_NAME
 from otbtf.model import cropped_tensor_name
-from test_utils import resolve_paths, files_exist, run_command_and_compare
+from test_utils import resolve_paths, files_exist, run_command_and_compare, \
+    run_command_and_test_exist
 
 INFERENCE_MAE_TOL = 10.0  # Dummy value: we don't really care of the mae value but rather the image size etc
 
@@ -39,42 +40,85 @@ class APITest(unittest.TestCase):
             '$TMPDIR/model_from_pimg/variables/variables.index'
         ]))
 
-    @pytest.mark.order(2)
+    @pytest.mark.order(2, 5)
     def test_model_inference1(self):
-        self.assertTrue(
-            run_command_and_compare(
-                command=
+        def _make_command(out_name, rfield, efield):
+            return (
                 "otbcli_TensorflowModelServe "
                 "-source1.il $DATADIR/fake_spot6.jp2 "
-                "-source1.rfieldx 64 "
-                "-source1.rfieldy 64 "
+                f"-source1.rfieldx {rfield} "
+                f"-source1.rfieldy {rfield} "
                 f"-source1.placeholder {INPUT_NAME} "
                 "-model.dir $TMPDIR/model_from_pimg "
                 "-model.fullyconv on "
-                f"-output.names {cropped_tensor_name(OUTPUT_SOFTMAX_NAME, 16)} "
-                "-output.efieldx 32 "
-                "-output.efieldy 32 "
-                "-out \"$TMPDIR/classif_model4_softmax.tif?&gdal:co:compress=deflate\" uint8",
-                to_compare_dict={
-                    "$DATADIR/classif_model4_softmax.tif": "$TMPDIR/classif_model4_softmax.tif"},
-                tol=INFERENCE_MAE_TOL))
-        self.assertTrue(
-            run_command_and_compare(
-                command=
-                "otbcli_TensorflowModelServe "
-                "-source1.il $DATADIR/fake_spot6.jp2 "
-                "-source1.rfieldx 128 "
-                "-source1.rfieldy 128 "
-                f"-source1.placeholder {INPUT_NAME} "
-                "-model.dir $TMPDIR/model_from_pimg "
-                "-model.fullyconv on "
-                f"-output.names {cropped_tensor_name(OUTPUT_SOFTMAX_NAME, 32)} "
-                "-output.efieldx 64 "
-                "-output.efieldy 64 "
-                "-out \"$TMPDIR/classif_model4_softmax.tif?&gdal:co:compress=deflate\" uint8",
-                to_compare_dict={
-                    "$DATADIR/classif_model4_softmax.tif": "$TMPDIR/classif_model4_softmax.tif"},
-                tol=INFERENCE_MAE_TOL))
+                f"-output.names {out_name} "
+                f"-output.efieldx {efield} "
+                f"-output.efieldy {efield} "
+                "-out \"$TMPDIR/classif_model4_softmax.tif?&gdal:co:compress=deflate\" uint8"
+            )
+
+        command_crop16_softmax = _make_command(
+            out_name=cropped_tensor_name(OUTPUT_SOFTMAX_NAME, 16),
+            rfield=64,
+            efield=32
+        )
+        command_crop32_softmax = _make_command(
+            out_name=cropped_tensor_name(OUTPUT_SOFTMAX_NAME, 32),
+            rfield=128,
+            efield=64
+        )
+        command_crop16_softmax2 = _make_command(
+            out_name=cropped_tensor_name(TARGET_NAME, 16),
+            rfield=64,
+            efield=32
+        )
+        command_crop32_softmax2 = _make_command(
+            out_name=cropped_tensor_name(TARGET_NAME, 32),
+            rfield=128,
+            efield=64
+        )
+        command_crop16_argmax = _make_command(
+            out_name=cropped_tensor_name(OUTPUT_ARGMAX_NAME, 16),
+            rfield=64,
+            efield=32
+        )
+        command_crop32_argmax = _make_command(
+            out_name=cropped_tensor_name(OUTPUT_ARGMAX_NAME, 32),
+            rfield=128,
+            efield=64
+        )
+
+        def _test_compare(command):
+            self.assertTrue(
+                run_command_and_compare(
+                    command=command,
+                    to_compare_dict={
+                        "$DATADIR/classif_model4_softmax.tif":
+                            "$TMPDIR/classif_model4_softmax.tif"
+                    },
+                    tol=INFERENCE_MAE_TOL
+                )
+            )
+
+        def _test_exist(command):
+            self.assertTrue(
+                run_command_and_test_exist(
+                    command=command,
+                    file_list=["$TMPDIR/classif_model4_softmax.tif"]
+                )
+            )
+
+        # softmax (from layer name)
+        _test_compare(command_crop16_softmax)
+        _test_compare(command_crop32_softmax)
+
+        # softmax (from target key i.e. model output name)
+        _test_compare(command_crop16_softmax2)
+        _test_compare(command_crop32_softmax2)
+
+        # argmax
+        _test_exist(command_crop16_argmax)
+        _test_exist(command_crop32_argmax)
 
     @pytest.mark.order(3)
     def test_create_tfrecords(self):
@@ -116,48 +160,6 @@ class APITest(unittest.TestCase):
             '$TMPDIR/model_from_tfrecs/variables/variables.index'
         ]))
 
-    @pytest.mark.order(5)
-    def test_model_inference2(self):
-        self.assertTrue(
-            run_command_and_compare(
-                command=
-                "otbcli_TensorflowModelServe "
-                "-source1.il $DATADIR/fake_spot6.jp2 "
-                "-source1.rfieldx 64 "
-                "-source1.rfieldy 64 "
-                f"-source1.placeholder {INPUT_NAME} "
-                "-model.dir $TMPDIR/model_from_pimg "
-                "-model.fullyconv on "
-                f"-output.names {cropped_tensor_name(OUTPUT_SOFTMAX_NAME, 16)} "
-                "-output.efieldx 32 "
-                "-output.efieldy 32 "
-                "-out \"$TMPDIR/classif_model4_softmax.tif?&gdal:co:compress=deflate\" uint8",
-                to_compare_dict={
-                    "$DATADIR/classif_model4_softmax.tif":
-                        "$TMPDIR/classif_model4_softmax.tif"
-                },
-                tol=INFERENCE_MAE_TOL))
-
-        self.assertTrue(
-            run_command_and_compare(
-                command=
-                "otbcli_TensorflowModelServe "
-                "-source1.il $DATADIR/fake_spot6.jp2 "
-                "-source1.rfieldx 128 "
-                "-source1.rfieldy 128 "
-                f"-source1.placeholder {INPUT_NAME} "
-                "-model.dir $TMPDIR/model_from_pimg "
-                "-model.fullyconv on "
-                f"-output.names {cropped_tensor_name(OUTPUT_SOFTMAX_NAME, 32)} "
-                "-output.efieldx 64 "
-                "-output.efieldy 64 "
-                "-out \"$TMPDIR/classif_model4_softmax.tif?&gdal:co:compress=deflate\" uint8",
-                to_compare_dict={
-                    "$DATADIR/classif_model4_softmax.tif":
-                        "$TMPDIR/classif_model4_softmax.tif"
-                },
-                tol=INFERENCE_MAE_TOL))
-
 
 if __name__ == '__main__':
     unittest.main()