Skip to content

Commit

Permalink
cover more tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
z0gSh1u committed Jun 5, 2024
1 parent dfb1db0 commit f193365
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 48 deletions.
79 changes: 31 additions & 48 deletions crip/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ def averageProjections(projections: ThreeD) -> TwoD:
def correctFlatDarkField(projections: TwoOrThreeD,
flat: Or[TwoD, ThreeD, float, None] = None,
dark: Or[TwoD, ThreeD, float] = 0,
fillNaN: Or[float, None] = 0,
fillInf: Or[float, None] = 0) -> TwoOrThreeD:
fillNaN: Or[float, None] = 0) -> TwoOrThreeD:
''' Perform flat field (air) and dark field correction to get post-log projections.
I.e., `- log [(X - D) / (F - D)]`.
Usually `flat` and `dark` are 2D.
Expand All @@ -42,7 +41,13 @@ def correctFlatDarkField(projections: TwoOrThreeD,
'''
if flat is None:
cripWarning(False, '`flat` is None. Use the maximum value of each view instead.')
flat = np.max(projections, axis=0) * np.ones_like(projections)
flat = np.max(projections.reshape((projections.shape[0], -1)), axis=1)
flat = np.ones_like(projections) * flat[:, np.newaxis, np.newaxis]

if not isType(flat, np.ndarray):
flat = np.ones_like(projections) * flat
if not isType(dark, np.ndarray):
dark = np.ones_like(projections) * dark

sampleProjection = projections if is2D(projections) else projections[0]

Expand All @@ -56,15 +61,13 @@ def checkShape(haystack, needle, needleName):

numerator = projections - dark
denominator = flat - dark
cripAssert(np.min(numerator > 0), 'Some `projections` values are not greater than zero after canceling `dark`.')
cripAssert(np.min(denominator > 0), 'Some `flat` values are not greater than zero after canceling `dark`.')
cripWarning(np.min(numerator > 0), 'Some `projections` values are not greater than zero after canceling `dark`.')
cripWarning(np.min(denominator > 0), 'Some `flat` values are not greater than zero after canceling `dark`.')

res = -np.log(numerator / denominator)

if fillInf is not None:
res[res == np.inf] = fillInf
if fillNaN is not None:
res[res == np.nan] = fillNaN
res = np.nan_to_num(res, nan=fillNaN)

return res

Expand Down Expand Up @@ -102,39 +105,35 @@ def sinogramsToProjections(sinograms: ThreeD):
@ConvertListNDArray
def padImage(img: TwoOrThreeD,
padding: Tuple[int, int, int, int],
mode: str = 'symmetric',
decay: Or[str, None] = None):
mode: str = 'reflect',
decay: bool = False,
cval: float = 0):
'''
Pad the image on four directions using symmetric `padding` (Up, Right, Down, Left). \\
`mode` determines the border value, can be `symmetric`, `edge`, `constant` (zero), `reflect`. \\
`decay` can be None, `cosine`, `smoothCosine` to perform a decay on padded border.
Pad each 2D image on four directions using symmetric `padding` (Up, Right, Down, Left).
`mode` determines the border value, can be `edge`, `constant` (with `cval`), `reflect`.
`decay` can be True to smooth padded border.
'''
cripAssert(mode in ['symmetric', 'edge', 'constant', 'reflect'], f'Invalid mode: {mode}.')
cripAssert(decay in [None, 'cosine', 'smoothCosine'], f'Invalid decay: {decay}.')
cripAssert(mode in ['edge', 'constant', 'reflect'], f'Invalid mode: {mode}.')

decays = {
'cosine':
lambda ascend, dot: np.cos(np.linspace(-np.pi / 2, 0, dot) if ascend else np.linspace(0, np.pi / 2, dot)),
'smoothCosine':
lambda ascend, dot: 0.5 * np.cos(np.linspace(-np.pi, 0, dot)) + 0.5
if ascend else 0.5 * np.cos(np.linspace(0, np.pi, dot)) + 0.5
}
cosineDecay = lambda ascend, dot: np.cos(
np.linspace(-np.pi / 2, 0, dot) if ascend else np.linspace(0, np.pi / 2, dot)),

h, w = getHnW(img)
nPadU, nPadR, nPadD, nPadL = padding
padH = h + nPadU + nPadD
padW = w + nPadL + nPadR

def decayLR(xPad, w, nPadL, nPadR, decay):
xPad[:, 0:nPadL] *= decay(True, nPadL)[:]
xPad[:, w - nPadR:w] *= decay(False, nPadR)[:]
def decayLR(xPad, w, nPadL, nPadR):
xPad[:, 0:nPadL] *= cosineDecay(True, nPadL)[:]
xPad[:, w - nPadR:w] *= cosineDecay(False, nPadR)[:]
return xPad

def procOne(img):
xPad = np.pad(img, ((nPadU, nPadD), (nPadL, nPadR)), mode=mode)
if decay is not None:
xPad = decayLR(xPad, padW, nPadL, nPadR, decays[decay])
xPad = decayLR(xPad.T, padH, nPadU, nPadD, decays[decay])
kwargs = {'constant_values': cval} if mode == 'constant' else {}
xPad = np.pad(img, ((nPadU, nPadD), (nPadL, nPadR)), mode=mode, **kwargs)
if decay:
xPad = decayLR(xPad, padW, nPadL, nPadR)
xPad = decayLR(xPad.T, padH, nPadU, nPadD)
xPad = xPad.T
return xPad

Expand All @@ -146,23 +145,7 @@ def procOne(img):


@ConvertListNDArray
def padSinogram(sgms: TwoOrThreeD, padding: Or[int, Tuple[int, int]], mode='symmetric', decay='smoothCosine'):
'''
Pad sinograms in width direction (same line detector elements) using `mode` and `decay`\\
with `padding` (single int, or (right, left)).
@see padImage for parameter details.
'''
if isType(padding, int):
padding = (padding, padding)

l, r = padding

return padImage(sgms, (0, r, 0, l), mode, decay)


@ConvertListNDArray
def correctRingArtifactInProj(sgm: TwoOrThreeD, sigma: float, ksize: Or[int, None] = None):
def correctRingArtifactProjLi(sgm: TwoOrThreeD, sigma: float, ksize: Or[int, None] = None):
'''
Apply the ring artifact correction method in projection domain (input postlog sinogram),
using gaussian filter in sinogram detector direction [1].
Expand Down Expand Up @@ -193,8 +176,8 @@ def fanToPara(sgm: TwoD, gammas: NDArray, betas: NDArray, sid: float, oThetas: T
`gammas` is fan angles from min to max [RAD], computed by `arctan(elementOffcenter / SDD)` for each element.
`betas` is system rotation angles from min to max [RAD].
`sid` is Source-Isocenter-Distance [mm].
`oThetas` is output rotation angle range (min, delta, max) tuple [RAD]
`oLines` is output detector element physical locations range (min, delta, max) tuple [mm], e.g., `elementOffcenter` array
`oThetas` is output rotation angle range as (min, delta, max) tuple [RAD] (excluding max)
`oLines` is output detector element physical locations range as (min, delta, max) tuple [mm] (excluding max)
```
/| <- gamma for detector element X
/ | <- SID
Expand Down
144 changes: 144 additions & 0 deletions test/test_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,149 @@ def test_twoD(self):
def test_threeD(self):
# a 3d array should return the average of the first axis
res = averageProjections(self.threeD)
assert res.ndim == 2
assert np.allclose(res, self.twoD)


def test_correctFlatDarkField():
projections = np.array([np.ones((3, 3)), np.ones((3, 3)) * 2])

# provide flat
flat = np.ones((3, 3)) * 3
res1 = correctFlatDarkField(projections, flat)
assert np.allclose(res1, -np.log(projections / flat))

# no flat
projections[0, 0, 0] = 10
res2 = correctFlatDarkField(projections)
assert res2[0, 0, 0] == pytest.approx(0)
assert res2[0, 0, 1] == pytest.approx(-np.log(1 / 10))
assert res2[1, 0, 0] == pytest.approx(0)

# nan
projections[0, 0, 0] = np.nan
res3 = correctFlatDarkField(projections, fillNaN=-1)
assert res3[0, 0, 0] == -1


def test_projectionsToSinograms():
projections = np.ones((1, 2, 3))
res = projectionsToSinograms(projections)
assert res.shape == (2, 1, 3)


def test_sinogramsToProjections():
sinograms = np.ones((2, 1, 3))
res = sinogramsToProjections(sinograms)
assert res.shape == (1, 2, 3)


def test_padImage():
image = np.ones((3, 3))
res1 = padImage(image, (2, 2, 2, 2), mode='constant', cval=10)
assert res1.shape == (7, 7)
assert res1[0, 0] == 10

res2 = padImage(image, (2, 2, 2, 2), mode='reflect')
assert res2[0, 0] == 1

image3D = np.ones((2, 3, 3))
res3 = padImage(image3D, (1, 1, 1, 1))
assert res3.shape == (2, 5, 5)


def test_correctRingArtifactProjLi():
pass


def test_fanToPara():
pass


import numpy as np
import pytest
from crip.preprocess import *
from crip.utils import CripException


class Test_averageProjections:
twoD = np.array([
[1, 2],
[3, 4],
])
threeD = np.array([twoD, twoD])

def test_twoD(self):
# a 2d array should raise an error
with pytest.raises(CripException):
averageProjections(self.twoD)

def test_threeD(self):
# a 3d array should return the average of the first axis
res = averageProjections(self.threeD)
assert res.ndim == 2
assert np.allclose(res, self.twoD)


def test_correctFlatDarkField():
projections = np.array([np.ones((3, 3)), np.ones((3, 3)) * 2])

# provide flat
flat = np.ones((3, 3)) * 3
res1 = correctFlatDarkField(projections, flat)
assert np.allclose(res1, -np.log(projections / flat))

# no flat
projections[0, 0, 0] = 10
res2 = correctFlatDarkField(projections)
assert res2[0, 0, 0] == pytest.approx(0)
assert res2[0, 0, 1] == pytest.approx(-np.log(1 / 10))
assert res2[1, 0, 0] == pytest.approx(0)

# nan
projections[0, 0, 0] = np.nan
res3 = correctFlatDarkField(projections, fillNaN=-1)
assert res3[0, 0, 0] == -1


def test_projectionsToSinograms():
projections = np.ones((1, 2, 3))
res = projectionsToSinograms(projections)
assert res.shape == (2, 1, 3)


def test_sinogramsToProjections():
sinograms = np.ones((2, 1, 3))
res = sinogramsToProjections(sinograms)
assert res.shape == (1, 2, 3)


def test_padImage():
image = np.ones((3, 3))
res1 = padImage(image, (2, 2, 2, 2), mode='constant', cval=10)
assert res1.shape == (7, 7)
assert res1[0, 0] == 10

res2 = padImage(image, (2, 2, 2, 2), mode='reflect')
assert res2[0, 0] == 1

image3D = np.ones((2, 3, 3))
res3 = padImage(image3D, (1, 1, 1, 1))
assert res3.shape == (2, 5, 5)


def test_correctRingArtifactProjLi():
pass


def test_fanToPara():
sgm = np.ones((4, 3))
gammas = np.array([0.1, 0.2, 0.3])
betas = np.array([0.4, 0.5, 0.6, 0.7])
sid = 100.0
oThetas = (0.0, 0.1, 0.2)
oLines = (-50.0, 10.0, 50.0)

res = fanToPara(sgm, gammas, betas, sid, oThetas, oLines)
assert res.shape == (2, 10)
assert set(list(res.flatten())) == set([0, 1])

0 comments on commit f193365

Please sign in to comment.