Skip to content

Commit

Permalink
Fix clippy warnings and add some code factorization
Browse files Browse the repository at this point in the history
  • Loading branch information
fxpineau committed Oct 16, 2024
1 parent 2f42376 commit 873a372
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 73 deletions.
40 changes: 20 additions & 20 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -766,18 +766,18 @@ fn cdshealpix(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {

/// Cone search
#[pyfn(m)]
fn cone_search<'a>(
py: Python<'a>,
fn cone_search(
py: Python<'_>,
depth: u8,
delta_depth: u8,
lon: f64,
lat: f64,
radius: f64,
flat: bool,
) -> (
Bound<'a, PyArray1<u64>>,
Bound<'a, PyArray1<u8>>,
Bound<'a, PyArray1<bool>>,
Bound<'_, PyArray1<u64>>,
Bound<'_, PyArray1<u8>>,
Bound<'_, PyArray1<bool>>,
) {
let bmoc = healpix::nested::cone_coverage_approx_custom(depth, delta_depth, lon, lat, radius);

Expand All @@ -796,8 +796,8 @@ fn cdshealpix(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {

/// Elliptical cone search
#[pyfn(m)]
fn elliptical_cone_search<'a>(
py: Python<'a>,
fn elliptical_cone_search(
py: Python<'_>,
depth: u8,
delta_depth: u8,
lon: f64,
Expand All @@ -807,9 +807,9 @@ fn cdshealpix(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
pa: f64,
flat: bool,
) -> (
Bound<'a, PyArray1<u64>>,
Bound<'a, PyArray1<u8>>,
Bound<'a, PyArray1<bool>>,
Bound<'_, PyArray1<u64>>,
Bound<'_, PyArray1<u8>>,
Bound<'_, PyArray1<bool>>,
) {
let bmoc =
healpix::nested::elliptical_cone_coverage_custom(depth, delta_depth, lon, lat, a, b, pa);
Expand Down Expand Up @@ -878,8 +878,8 @@ fn cdshealpix(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
/// * ``b`` - size in degrees
/// * ``pa`` -rotation angle in degrees
#[pyfn(m)]
fn box_search<'a>(
py: Python<'a>,
fn box_search(
py: Python<'_>,
depth: u8,
lon: f64,
lat: f64,
Expand All @@ -888,9 +888,9 @@ fn cdshealpix(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
pa: f64,
flat: bool,
) -> (
Bound<'a, PyArray1<u64>>,
Bound<'a, PyArray1<u8>>,
Bound<'a, PyArray1<bool>>,
Bound<'_, PyArray1<u64>>,
Bound<'_, PyArray1<u8>>,
Bound<'_, PyArray1<bool>>,
) {
let bmoc = healpix::nested::box_coverage(depth, lon, lat, a, b, pa);

Expand Down Expand Up @@ -918,18 +918,18 @@ fn cdshealpix(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
/// * ``lon_max`` - east north corner longitude
/// * ``lat_max`` - east north corner latitude
#[pyfn(m)]
fn zone_search<'a>(
py: Python<'a>,
fn zone_search(
py: Python<'_>,
depth: u8,
lon_min: f64,
lat_min: f64,
lon_max: f64,
lat_max: f64,
flat: bool,
) -> (
Bound<'a, PyArray1<u64>>,
Bound<'a, PyArray1<u8>>,
Bound<'a, PyArray1<bool>>,
Bound<'_, PyArray1<u64>>,
Bound<'_, PyArray1<u8>>,
Bound<'_, PyArray1<bool>>,
) {
let bmoc = healpix::nested::zone_coverage(depth, lon_min, lat_min, lon_max, lat_max);

Expand Down
98 changes: 45 additions & 53 deletions src/skymap_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@ extern crate healpix;
extern crate mapproj;

use std::fs::File;
use std::i64;
use std::io::BufWriter;

use numpy::{IntoPyArray, Ix3, PyArray1, PyArray3, PyArrayMethods, PyReadonlyArray1};
use numpy::{
IntoPyArray, Ix3, NotContiguousError, PyArray1, PyArray3, PyArrayMethods, PyReadonlyArray1,
};
use pyo3::{
exceptions::{PyIOError, PyValueError},
prelude::*,
types::{PyAny, PyModule},
Bound, PyErr, PyResult,
};

use healpix::nested::map::skymap::SkyMapValue;
use healpix::{
depth_from_n_hash_unsafe,
nested::map::{
Expand All @@ -27,42 +30,42 @@ pub fn read_skymap<'py>(
module: &Bound<'py, PyModule>,
path: String,
) -> PyResult<Bound<'py, PyAny>> {
SkyMapEnum::from_fits_file(path.to_string())
SkyMapEnum::from_fits_file(&path)
.map_err(|err| PyIOError::new_err(err.to_string()))
.map(|sky_map_enum| match sky_map_enum {
SkyMapEnum::ImplicitU64U8(s) => s
.values()
.map(|v| *v)
.copied()
.collect::<Vec<u8>>()
.into_pyarray_bound(module.py())
.into_any(),
SkyMapEnum::ImplicitU64I16(s) => s
.values()
.map(|v| *v)
.copied()
.collect::<Vec<i16>>()
.into_pyarray_bound(module.py())
.into_any(),
SkyMapEnum::ImplicitU64I32(s) => s
.values()
.map(|v| *v)
.copied()
.collect::<Vec<i32>>()
.into_pyarray_bound(module.py())
.into_any(),
SkyMapEnum::ImplicitU64I64(s) => s
.values()
.map(|v| *v)
.copied()
.collect::<Vec<i64>>()
.into_pyarray_bound(module.py())
.into_any(),
SkyMapEnum::ImplicitU64F32(s) => s
.values()
.map(|v| *v)
.copied()
.collect::<Vec<f32>>()
.into_pyarray_bound(module.py())
.into_any(),
SkyMapEnum::ImplicitU64F64(s) => s
.values()
.map(|v| *v)
.copied()
.collect::<Vec<f64>>()
.into_pyarray_bound(module.py())
.into_any(),
Expand All @@ -79,37 +82,41 @@ pub enum SupportedArray<'py> {
I16(PyReadonlyArray1<'py, i16>),
U8(PyReadonlyArray1<'py, u8>),
}
impl<'py> SupportedArray<'py> {
fn n_hash(&self) -> u64 {
let n = match self {
SupportedArray::F64(values) => values.as_array().shape()[0],
SupportedArray::I64(values) => values.as_array().shape()[0],
SupportedArray::F32(values) => values.as_array().shape()[0],
SupportedArray::I32(values) => values.as_array().shape()[0],
SupportedArray::I16(values) => values.as_array().shape()[0],
SupportedArray::U8(values) => values.as_array().shape()[0],
};
n as u64
}
}

#[pyfunction]
pub fn write_skymap<'py>(values: SupportedArray<'py>, path: String) -> Result<(), PyErr> {
let writer = File::create(path).map_err(|err| PyIOError::new_err(err.to_string()))?;
pub fn write_skymap(values: SupportedArray<'_>, path: String) -> Result<(), PyErr> {
let writer: BufWriter<File> =
BufWriter::new(File::create(path).map_err(|err| PyIOError::new_err(err.to_string()))?);
match values {
SupportedArray::F64(values) => values.as_slice().map_err(|e| e.into()).and_then(|slice| {
write_implicit_skymap_fits(writer, slice)
.map_err(|err| PyIOError::new_err(err.to_string()).into())
}),
SupportedArray::I64(values) => values.as_slice().map_err(|e| e.into()).and_then(|slice| {
write_implicit_skymap_fits(writer, slice)
.map_err(|err| PyIOError::new_err(err.to_string()).into())
}),
SupportedArray::F32(values) => values.as_slice().map_err(|e| e.into()).and_then(|slice| {
write_implicit_skymap_fits(writer, slice)
.map_err(|err| PyIOError::new_err(err.to_string()).into())
}),
SupportedArray::I32(values) => values.as_slice().map_err(|e| e.into()).and_then(|slice| {
write_implicit_skymap_fits(writer, slice)
.map_err(|err| PyIOError::new_err(err.to_string()).into())
}),
SupportedArray::I16(values) => values.as_slice().map_err(|e| e.into()).and_then(|slice| {
write_implicit_skymap_fits(writer, slice)
.map_err(|err| PyIOError::new_err(err.to_string()).into())
}),
SupportedArray::U8(values) => values.as_slice().map_err(|e| e.into()).and_then(|slice| {
write_implicit_skymap_fits(writer, slice)
.map_err(|err| PyIOError::new_err(err.to_string()).into())
}),
SupportedArray::F64(values) => write_skymap_gen(writer, values.as_slice()),
SupportedArray::I64(values) => write_skymap_gen(writer, values.as_slice()),
SupportedArray::F32(values) => write_skymap_gen(writer, values.as_slice()),
SupportedArray::I32(values) => write_skymap_gen(writer, values.as_slice()),
SupportedArray::I16(values) => write_skymap_gen(writer, values.as_slice()),
SupportedArray::U8(values) => write_skymap_gen(writer, values.as_slice()),
}
}
fn write_skymap_gen<T: SkyMapValue>(
writer: BufWriter<File>,
as_slice_res: Result<&[T], NotContiguousError>,
) -> Result<(), PyErr> {
as_slice_res.map_err(move |e| e.into()).and_then(|slice| {
write_implicit_skymap_fits(writer, slice).map_err(|err| PyIOError::new_err(err.to_string()))
})
}

#[pyfunction]
#[pyo3(pass_module)]
Expand All @@ -119,15 +126,8 @@ pub fn pixels_skymap<'py>(
image_size: u16,
convert_to_gal: bool,
) -> PyResult<Bound<'py, PyArray3<u8>>> {
let n_hash: u64 = match &values {
SupportedArray::F64(values) => values.as_array().shape()[0] as u64,
SupportedArray::I64(values) => values.as_array().shape()[0] as u64,
SupportedArray::F32(values) => values.as_array().shape()[0] as u64,
SupportedArray::I32(values) => values.as_array().shape()[0] as u64,
SupportedArray::I16(values) => values.as_array().shape()[0] as u64,
SupportedArray::U8(values) => values.as_array().shape()[0] as u64,
};
let depth: u8 = depth_from_n_hash_unsafe(n_hash);
let n_hash = values.n_hash();
let depth = depth_from_n_hash_unsafe(n_hash);
// we have to use https://github.com/cds-astro/cds-healpix-rust/blob/847ae35945708efb6b949c3d15b3726ab7adeb2f/src/nested/map/img.rs#L391
match values {
SupportedArray::F64(values) => values.as_slice().map_err(|e| e.into()).and_then(|v| {
Expand Down Expand Up @@ -228,13 +228,5 @@ where

#[pyfunction]
pub fn depth_skymap(values: SupportedArray) -> u8 {
let n_hash: u64 = match &values {
SupportedArray::F64(values) => values.as_array().shape()[0] as u64,
SupportedArray::I64(values) => values.as_array().shape()[0] as u64,
SupportedArray::F32(values) => values.as_array().shape()[0] as u64,
SupportedArray::I32(values) => values.as_array().shape()[0] as u64,
SupportedArray::I16(values) => values.as_array().shape()[0] as u64,
SupportedArray::U8(values) => values.as_array().shape()[0] as u64,
};
depth_from_n_hash_unsafe(n_hash)
depth_from_n_hash_unsafe(values.n_hash())
}

0 comments on commit 873a372

Please sign in to comment.