From 5ac6ad7886297e2d84c25979e4447877c7c8866d Mon Sep 17 00:00:00 2001 From: Pete Bunting Date: Sat, 20 Jul 2024 18:25:44 +0100 Subject: [PATCH] Allow output datatype to be defined by users for clustering functions and set default to byte so is more gtiff friendly. --- .../rsgislib/classification/clustersklearn.py | 24 ++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/python/rsgislib/classification/clustersklearn.py b/python/rsgislib/classification/clustersklearn.py index 67cf122f..73795d24 100644 --- a/python/rsgislib/classification/clustersklearn.py +++ b/python/rsgislib/classification/clustersklearn.py @@ -65,6 +65,7 @@ def img_pixel_sample_cluster( ), calc_stats: bool = True, use_mean_shift_est_band_width: bool = False, + datatype: int = rsgislib.TYPE_8UINT, ): """ A function which allows a clustering to be performed using the algorithms available @@ -86,6 +87,9 @@ def img_pixel_sample_cluster( :param use_mean_shift_est_band_width: use the mean-shift algorithm as the clusterer (pass None as the clusterer) where the bandwidth is calculated from the data itself. + :param datatype: the output image data type - needs to be a + rsgislib datatype (e.g., rsgislib.TYPE_8UINT) + """ from rios import applier @@ -117,6 +121,7 @@ def img_pixel_sample_cluster( otherargs = applier.OtherInputs() otherargs.clusterer = clusterer otherargs.no_data_val = no_data_val + otherargs.np_datatype = rsgislib.get_numpy_datatype(datatype) aControls = applier.ApplierControls() aControls.progress = progress_bar aControls.creationoptions = rsgislib.imageutils.get_rios_img_creation_opts( @@ -155,7 +160,9 @@ def _apply_sk_clusterer(info, inputs, outputs, otherargs): out_cluster_vals[ID] = out_pred out_cluster_vals = out_cluster_vals.astype(numpy.int32) - outputs.output_img = out_cluster_vals.reshape([1, img_shp[1], img_shp[2]]) + outputs.output_img = out_cluster_vals.reshape( + [1, img_shp[1], img_shp[2]] + ).astype(otherargs.np_datatype) print("Applying to Whole Image") applier.apply(_apply_sk_clusterer, infiles, outfiles, otherargs, controls=aControls) @@ -189,6 +196,7 @@ def img_pixel_tiled_cluster( use_mean_shift_est_band_width: bool = False, tile_x_size: int = 200, tile_y_size: int = 200, + datatype: int = rsgislib.TYPE_8UINT, ): """ A function which allows a clustering to be performed using the algorithms available @@ -210,6 +218,9 @@ def img_pixel_tiled_cluster( bandwidth is calculated from the data itself. :param tile_x_size: tile size in the x-axis in pixels. :param tile_y_size: tile size in the y-axis in pixels. + :param datatype: the output image data type - needs to be a + rsgislib datatype (e.g., rsgislib.TYPE_8UINT) + """ from rios import applier @@ -227,6 +238,7 @@ def img_pixel_tiled_cluster( outfiles.output_img = output_img otherargs = applier.OtherInputs() otherargs.no_data_val = no_data_val + otherargs.np_datatype = rsgislib.get_numpy_datatype(datatype) otherargs.clusterer = clusterer aControls = applier.ApplierControls() aControls.progress = progress_bar @@ -275,7 +287,9 @@ def _apply_sk_tiled_clusterer(info, inputs, outputs, otherargs): out_cluster_vals[ID] = out_pred out_cluster_vals = out_cluster_vals.astype(numpy.int32) - outputs.output_img = out_cluster_vals.reshape([1, img_shp[1], img_shp[2]]) + outputs.output_img = out_cluster_vals.reshape( + [1, img_shp[1], img_shp[2]] + ).astype(otherargs.np_datatype) applier.apply( _apply_sk_tiled_clusterer, infiles, outfiles, otherargs, controls=aControls @@ -308,6 +322,7 @@ def img_pixel_cluster( ), calc_stats: bool = True, use_mean_shift_est_band_width: bool = False, + datatype: int = rsgislib.TYPE_8UINT, ): """ A function which allows a clustering to be performed using the algorithms available @@ -328,10 +343,13 @@ def img_pixel_cluster( :param use_mean_shift_est_band_width: use the mean-shift algorithm as the clusterer (pass None as the clusterer) where the bandwidth is calculated from the data itself. + :param datatype: the output image data type - needs to be a + rsgislib datatype (e.g., rsgislib.TYPE_8UINT) + """ # Create output image rsgislib.imageutils.create_copy_img( - input_img, output_img, 1, 0, gdalformat, rsgislib.TYPE_16UINT + input_img, output_img, 1, 0, gdalformat, datatype ) if use_mean_shift_est_band_width: