From 0944e976bc543f9171127c83f9e0fe29c80a0c98 Mon Sep 17 00:00:00 2001 From: Jiao Lin Date: Wed, 24 Aug 2016 15:45:45 -0400 Subject: [PATCH] Refs #53. added another tilt-calculator (DirectMinimization) that uses Rick's idea; switch find_rot_center to also use direct minimization. --- python/imars3d/tilt/__init__.py | 40 +++++---- python/imars3d/tilt/direct.py | 120 +++++++++++++++++++++++++ python/imars3d/tilt/find_rot_center.py | 16 ++++ python/imars3d/tilt/use_centers.py | 4 +- tests/imars3d/tilt/test_direct.py | 47 ++++++++++ tests/imars3d/tilt/test_direct2.py | 24 +++++ tests/imars3d/tilt/test_tilt.py | 6 +- 7 files changed, 236 insertions(+), 21 deletions(-) create mode 100644 python/imars3d/tilt/direct.py create mode 100755 tests/imars3d/tilt/test_direct.py create mode 100755 tests/imars3d/tilt/test_direct2.py diff --git a/python/imars3d/tilt/__init__.py b/python/imars3d/tilt/__init__.py index a018770d..b1852d59 100644 --- a/python/imars3d/tilt/__init__.py +++ b/python/imars3d/tilt/__init__.py @@ -2,26 +2,31 @@ import os, numpy as np, warnings import logging -from . import use_centers, phasecorrelation +from . import use_centers, phasecorrelation, direct def compute(ct_series, workdir, max_npairs=10): from . import use_centers - calculator = use_centers.Calculator(sigma=15, maxshift=200) - try: - tilt = _compute( - ct_series, os.path.join(workdir, 'testrun'), max_npairs=10, - calculator=calculator) - except: - warnings.warn("Failed to use centers to determine tilt. Now try phase correlation method") - calculator = phasecorrelation.PhaseCorrelation() - return _compute( - ct_series, workdir, max_npairs=max_npairs, - calculator=calculator) - if abs(tilt) > 0.8: - calculator = phasecorrelation.PhaseCorrelation() - tilt = _compute( - ct_series, workdir, max_npairs=max_npairs, - calculator=calculator) + calculators = [ + use_centers.UseCenters(sigma=15, maxshift=200), + phasecorrelation.PhaseCorrelation(), + direct.DirectMinimization(), + ] + tilt = None + for calculator in calculators: + kind = calculator.__class__.__name__ + # print kind + try: + tilt = _compute( + ct_series, os.path.join(workdir, kind), + max_npairs=10, + calculator=calculator) + # print tilt + break + except: + warnings.warn("Failed to use %s to determine tilt" % kind) + continue + if tilt is None: + raise RuntimeError("Failed to compute tilt") return tilt def _compute(ct_series, workdir, max_npairs=10, calculator=None): @@ -45,6 +50,7 @@ def _compute(ct_series, workdir, max_npairs=10, calculator=None): logger.info("working on pair %s, %s" % (a0, a180)) logging_dir=os.path.join(workdir, "log.tilt.%s_vs_%s"%(a0, a180)) if not os.path.exists(logging_dir): + os.makedirs(logging_dir) calculator.logging_dir=logging_dir tilt, weight = calculator(img(a0), img(a180)) open(os.path.join(logging_dir, 'tilt.out'), 'wt')\ diff --git a/python/imars3d/tilt/direct.py b/python/imars3d/tilt/direct.py new file mode 100644 index 00000000..502cd8be --- /dev/null +++ b/python/imars3d/tilt/direct.py @@ -0,0 +1,120 @@ +# imars3d.tilt.direct + +""" +directly compute tilt in real space +* no fft to compute polar distribution like in phasecorrelation +* no edge detection like in use_centers + +Just simply +* find the center of rotation +* find rotation angle by doing a minimization +""" + +import os, numpy as np +from imars3d import io +from matplotlib import pyplot as plt + +class DirectMinimization: + + def __init__(self, logging_dir=None, **opts): + self.logging_dir = logging_dir + self.opts = opts + return + + def __call__(self, img0, img180): + return computeTilt(img0.data, img180.data), 1.0 + + +def computeTilt(img0, img180, workdir=None, **kwds): + flipped_img180 = np.fliplr(img180) + shift = findShift(img0, flipped_img180) + differ = lambda a,b: shift_diff(shift, a,b) + tilts = np.arange(-2., 2.1, 0.2) + tilt = _argmin_tilt(tilts, img0, flipped_img180, differ) + tilts = np.arange(tilt-0.2, tilt+0.21, 0.02) + tilt = _argmin_tilt(tilts, img0, flipped_img180, differ) + return tilt + + +def _argmin_tilt(tilts, img0, flipped_img180, differ): + nrows, ncols = img0.shape + borderY, borderX = nrows//20, ncols//20 + from skimage.transform import rotate + diffs = [] + for tilt in tilts: + a = rotate(img0/np.max(img0), -tilt)[borderY:-borderY, borderX:-borderX] + b = rotate(flipped_img180/np.max(flipped_img180), tilt)[borderY:-borderY, borderX:-borderX] + diff = differ(a,b) + print("* tilt=%s, diff=%s" % (tilt, diff)) + diffs.append(diff) + continue + return tilts[np.argmin(diffs)] + + +def shift_diff(x, img1, img2): + x = int(x) + if x>0: + left = img1[:, :-x] + right = img2[:, x:] + elif x<0: + left = img1[:, -x:] + right = img2[:, :x] + else: + left = img1 + right = img2 + return ((left-right)**2).sum()/left.size + +MAX_SHIFT = 100 +def findShift(img0, flipped_img180): + """find the shift in number of pixels + + note: the relation between rot center and shift is + rot_center = -shift/2 if 0 is center of image + """ + print("* Calculating shift...") + ncols = img0.shape[1] + def diff(x): + return shift_diff(x, img0, flipped_img180) + start = max(-ncols//2, -MAX_SHIFT) + end = min(MAX_SHIFT, ncols//2) + xs = range(start, end) + diffs = [diff(x) for x in xs] + index = np.argmin(diffs) + guess = xs[index] + return guess + assert index >=5 and index < len(xs)-6 + # around guess + x = xs[index-3: index+4] + y = diffs[index-3: index+4] + plt.plot(x,y) + plt.savefig("tilt-direct-around-guess.png") + # fit to parabolic + a = np.polyfit(x,y,2) + c = -a[1]/2/a[0] + return c + + +def findShift_byfft(img0, flipped_img180): + "compute shift from img0 to the flipped img180" + import numpy.fft as npfft, numpy as np + A = npfft.fft2(1-img0) + B = npfft.fft2(1-flipped_img180) + C = A * np.conjugate(B) + C /= np.abs(C) + D = npfft.ifft2(C) + plt.imshow(np.real(D)) + plt.savefig("D.png") + pos = np.argmax(D) + nrows, ncols = D.shape + col = pos % ncols + row = pos // ncols + # should be around zero + if row > nrows//10: + row -= nrows + if col > ncols//10: + col -= ncols + if abs(row) > nrows//10 or abs(col) > ncols//10: + msg = "computed displacement unexpectedly too large: %s, %s"% ( + col, row) + raise RuntimeError(msg) + return col, row diff --git a/python/imars3d/tilt/find_rot_center.py b/python/imars3d/tilt/find_rot_center.py index ee6b61e2..368da2b3 100644 --- a/python/imars3d/tilt/find_rot_center.py +++ b/python/imars3d/tilt/find_rot_center.py @@ -2,6 +2,22 @@ import os, numpy as np def find(ct_series, workdir=None): + img = lambda angle: ct_series.getImage(angle) + from . import _find180DegImgPairs + from .direct import findShift + pairs = _find180DegImgPairs(ct_series.identifiers) + centers = [] + for i, (a0, a180) in enumerate(pairs): + workdir1=os.path.join(workdir, "%s_vs_%s"%(a0, a180)) + shift = findShift(img(a0).data, np.fliplr(img(a180).data)) + center = -shift/2. + img(a0).data.shape[1]/2. + # print shift, center, img(a0).data.shape[1]/2. + centers.append(center) + continue + return np.median(centers) + + +def find_using_edges(ct_series, workdir=None): img = lambda angle: ct_series.getImage(angle) from . import _find180DegImgPairs from .use_centers import computeTilt diff --git a/python/imars3d/tilt/use_centers.py b/python/imars3d/tilt/use_centers.py index 7be0cd01..a14ef010 100644 --- a/python/imars3d/tilt/use_centers.py +++ b/python/imars3d/tilt/use_centers.py @@ -4,7 +4,7 @@ from imars3d import io from matplotlib import pyplot as plt -class Calculator: +class UseCenters: def __init__(self, logging_dir=None, **opts): self.logging_dir = logging_dir @@ -15,7 +15,7 @@ def __call__(self, img0, img180): slope, intercept = computeTilt(img0, img180, workdir=self.logging_dir, **self.opts) # print (slope, np.arctan(slope)) return .7 * np.arctan(slope)*180./np.pi, 1.0 - +Calculator = UseCenters def computeTilt(img0, img180, workdir=None, **kwds): centers = np.array( diff --git a/tests/imars3d/tilt/test_direct.py b/tests/imars3d/tilt/test_direct.py new file mode 100755 index 00000000..df49c919 --- /dev/null +++ b/tests/imars3d/tilt/test_direct.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import os, numpy as np +from imars3d import io +from imars3d import tilt +from imars3d.tilt import direct +dir = os.path.dirname(__file__) +datadir = os.path.join(dir, "..", "..", "iMars3D_data_set", "tilt", "16040") + +def test_Calculator(): + # ct + angles = np.arange(0, 180.5, 1.) + ct_series = io.ImageFileSeries( + os.path.join(datadir, "cropped*_%07.3f.tiff"), + identifiers = angles, + decimal_mark_replacement=".", + name = "CT", + ) + calculator = direct.DirectMinimization() + t = tilt._compute(ct_series, "_tmp/test_direct/work", calculator=calculator) + print t + return + + +def test_computeTilt(): + img0 = io.ImageFile(os.path.join(datadir, "cropped_000.000.tiff")).data + img180 = io.ImageFile(os.path.join(datadir, "cropped_180.000.tiff")).data + print direct.computeTilt(img0, img180) + return + +def test_shift(): + img0 = io.ImageFile(os.path.join(datadir, "cropped_000.000.tiff")).data + img180 = io.ImageFile(os.path.join(datadir, "cropped_180.000.tiff")).data + flipped_180 = np.fliplr(img180) + print direct.findShift(img0, flipped_180) + return + +def main(): + # test_shift() + # test_computeTilt() + test_Calculator() + return + +if __name__ == '__main__': main() + +# End of file diff --git a/tests/imars3d/tilt/test_direct2.py b/tests/imars3d/tilt/test_direct2.py new file mode 100755 index 00000000..fbdd1bfb --- /dev/null +++ b/tests/imars3d/tilt/test_direct2.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import os, numpy as np +from imars3d import io +from imars3d import tilt +from imars3d.tilt import direct +dir = os.path.dirname(__file__) +datadir = os.path.join(dir, "..", "..", "iMars3D_data_set", "turbine", 'cleaned') + +def test_computeTilt(): + img0 = io.ImageFile(os.path.join(datadir, "smoothed_000.000.tiff")).data + img180 = io.ImageFile(os.path.join(datadir, "smoothed_180.200.tiff")).data + t = direct.computeTilt(img0, img180) + assert t>-2 and t<-1 + return + +def main(): + test_computeTilt() + return + +if __name__ == '__main__': main() + +# End of file diff --git a/tests/imars3d/tilt/test_tilt.py b/tests/imars3d/tilt/test_tilt.py index 80704147..7d75d6fe 100755 --- a/tests/imars3d/tilt/test_tilt.py +++ b/tests/imars3d/tilt/test_tilt.py @@ -18,6 +18,7 @@ def test_tilt(): from imars3d.tilt.phasecorrelation import PhaseCorrelation calculator = PhaseCorrelation() t = tilt._compute(ct_series, "_tmp/test_tilt/work", calculator=calculator) + print(t) assert t>-2 and t<-1 return @@ -34,9 +35,10 @@ def test_tilt2(): name = "CT", ) from imars3d.tilt.use_centers import Calculator - calculator = Calculator(sigma=3, maxshift=200) + calculator = Calculator(sigma=15, maxshift=200) t = tilt._compute(ct_series, "_tmp/test_tilt2/work", calculator=calculator) - assert t>-1 and t<-.5 + print(t) + assert t>-1.5 and t<-.5 return