Skip to content

Commit

Permalink
Allow output datatype to be defined by users for clustering functions…
Browse files Browse the repository at this point in the history
… and set default to byte so is more gtiff friendly.
  • Loading branch information
petebunting committed Jul 20, 2024
1 parent ccd4505 commit 5ac6ad7
Showing 1 changed file with 21 additions and 3 deletions.
24 changes: 21 additions & 3 deletions python/rsgislib/classification/clustersklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def img_pixel_sample_cluster(
),
calc_stats: bool = True,
use_mean_shift_est_band_width: bool = False,
datatype: int = rsgislib.TYPE_8UINT,
):
"""
A function which allows a clustering to be performed using the algorithms available
Expand All @@ -86,6 +87,9 @@ def img_pixel_sample_cluster(
:param use_mean_shift_est_band_width: use the mean-shift algorithm as the clusterer
(pass None as the clusterer) where the
bandwidth is calculated from the data itself.
:param datatype: the output image data type - needs to be a
rsgislib datatype (e.g., rsgislib.TYPE_8UINT)
"""
from rios import applier

Expand Down Expand Up @@ -117,6 +121,7 @@ def img_pixel_sample_cluster(
otherargs = applier.OtherInputs()
otherargs.clusterer = clusterer
otherargs.no_data_val = no_data_val
otherargs.np_datatype = rsgislib.get_numpy_datatype(datatype)
aControls = applier.ApplierControls()
aControls.progress = progress_bar
aControls.creationoptions = rsgislib.imageutils.get_rios_img_creation_opts(
Expand Down Expand Up @@ -155,7 +160,9 @@ def _apply_sk_clusterer(info, inputs, outputs, otherargs):
out_cluster_vals[ID] = out_pred

out_cluster_vals = out_cluster_vals.astype(numpy.int32)
outputs.output_img = out_cluster_vals.reshape([1, img_shp[1], img_shp[2]])
outputs.output_img = out_cluster_vals.reshape(
[1, img_shp[1], img_shp[2]]
).astype(otherargs.np_datatype)

print("Applying to Whole Image")
applier.apply(_apply_sk_clusterer, infiles, outfiles, otherargs, controls=aControls)
Expand Down Expand Up @@ -189,6 +196,7 @@ def img_pixel_tiled_cluster(
use_mean_shift_est_band_width: bool = False,
tile_x_size: int = 200,
tile_y_size: int = 200,
datatype: int = rsgislib.TYPE_8UINT,
):
"""
A function which allows a clustering to be performed using the algorithms available
Expand All @@ -210,6 +218,9 @@ def img_pixel_tiled_cluster(
bandwidth is calculated from the data itself.
:param tile_x_size: tile size in the x-axis in pixels.
:param tile_y_size: tile size in the y-axis in pixels.
:param datatype: the output image data type - needs to be a
rsgislib datatype (e.g., rsgislib.TYPE_8UINT)
"""
from rios import applier

Expand All @@ -227,6 +238,7 @@ def img_pixel_tiled_cluster(
outfiles.output_img = output_img
otherargs = applier.OtherInputs()
otherargs.no_data_val = no_data_val
otherargs.np_datatype = rsgislib.get_numpy_datatype(datatype)
otherargs.clusterer = clusterer
aControls = applier.ApplierControls()
aControls.progress = progress_bar
Expand Down Expand Up @@ -275,7 +287,9 @@ def _apply_sk_tiled_clusterer(info, inputs, outputs, otherargs):
out_cluster_vals[ID] = out_pred

out_cluster_vals = out_cluster_vals.astype(numpy.int32)
outputs.output_img = out_cluster_vals.reshape([1, img_shp[1], img_shp[2]])
outputs.output_img = out_cluster_vals.reshape(
[1, img_shp[1], img_shp[2]]
).astype(otherargs.np_datatype)

applier.apply(
_apply_sk_tiled_clusterer, infiles, outfiles, otherargs, controls=aControls
Expand Down Expand Up @@ -308,6 +322,7 @@ def img_pixel_cluster(
),
calc_stats: bool = True,
use_mean_shift_est_band_width: bool = False,
datatype: int = rsgislib.TYPE_8UINT,
):
"""
A function which allows a clustering to be performed using the algorithms available
Expand All @@ -328,10 +343,13 @@ def img_pixel_cluster(
:param use_mean_shift_est_band_width: use the mean-shift algorithm as the clusterer
(pass None as the clusterer) where the
bandwidth is calculated from the data itself.
:param datatype: the output image data type - needs to be a
rsgislib datatype (e.g., rsgislib.TYPE_8UINT)
"""
# Create output image
rsgislib.imageutils.create_copy_img(
input_img, output_img, 1, 0, gdalformat, rsgislib.TYPE_16UINT
input_img, output_img, 1, 0, gdalformat, datatype
)

if use_mean_shift_est_band_width:
Expand Down

0 comments on commit 5ac6ad7

Please sign in to comment.