Skip to content

Commit

Permalink
Move parallel support imports to .pxi so shims get correct type
Browse files Browse the repository at this point in the history
We also need to explicitly cast the comm/info arguments to the correct type
  • Loading branch information
ZedThree committed Sep 29, 2023
1 parent f9e7d11 commit a83f81b
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
4 changes: 2 additions & 2 deletions include/netCDF4.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,8 @@ cdef extern from "numpy/arrayobject.h":
void import_array()


include "parallel_support_imports.pxi"

# Compatibility shims
cdef extern from "netcdf-compat.h":
int nc_rename_grp(int grpid, char *name) nogil
Expand Down Expand Up @@ -424,8 +426,6 @@ cdef extern from "netcdf-compat.h":
unsigned* addshufflep) nogil

# Parallel shims
ctypedef int MPI_Comm
ctypedef int MPI_Info
int nc_create_par(char *path, int cmode, MPI_Comm comm, MPI_Info info, int *ncidp) nogil
int nc_open_par(char *path, int mode, MPI_Comm comm, MPI_Info info, int *ncidp) nogil
int nc_var_par_access(int ncid, int varid, int par_access) nogil
Expand Down
6 changes: 4 additions & 2 deletions include/no_parallel_support_imports.pxi.in
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Stubs for when parallel support is not enabled

ctypedef int MPI_Comm
ctypedef int MPI_Info
ctypedef int Comm
ctypedef int Info
cdef Comm MPI_COMM_WORLD
cdef Info MPI_INFO_NULL
cdef MPI_Comm MPI_COMM_WORLD
cdef MPI_Info MPI_INFO_NULL
MPI_COMM_WORLD = 0
MPI_INFO_NULL = 0
5 changes: 2 additions & 3 deletions src/netCDF4/_netCDF4.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1244,7 +1244,6 @@ from numpy import ma
from libc.string cimport memcpy, memset
from libc.stdlib cimport malloc, free
numpy.import_array()
include "parallel_support_imports.pxi"
include "membuf.pyx"
include "netCDF4.pxi"

Expand Down Expand Up @@ -2265,11 +2264,11 @@ strings.
msg='parallel mode only works with the following formats: ' + ' '.join(parallel_formats)
raise ValueError(msg)
if comm is not None:
mpicomm = comm.ob_mpi
mpicomm = (<Comm?>comm).ob_mpi
else:
mpicomm = MPI_COMM_WORLD
if info is not None:
mpiinfo = info.ob_mpi
mpiinfo = (<Info?>info).ob_mpi
else:
mpiinfo = MPI_INFO_NULL
parmode = NC_MPIIO | _cmode_dict[format]
Expand Down

0 comments on commit a83f81b

Please sign in to comment.