Skip to content

Commit

Permalink
Fix for SPMFile instantiation
Browse files Browse the repository at this point in the history
  • Loading branch information
kogens committed Nov 11, 2023
1 parent 1d2afd4 commit 1919512
Showing 1 changed file with 13 additions and 15 deletions.
28 changes: 13 additions & 15 deletions spmpy/spmloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,14 @@
class SPMFile:
""" Representation of an entire SPM file with images and metadata """

def __new__(cls, *args, **kwargs):
""" If class is instanced with a single parameter, it is either a path or bytestring """
if len(args) == 1 and isinstance(args[0], (str, Path)):
return cls.from_path(args[0])
def __init__(self, spmfile: str | PathLike | bytes):
if isinstance(spmfile, (str, Path)):
self.path = Path(spmfile)
bytestring = self.load_from_file(spmfile)
elif isinstance(spmfile, bytes):
bytestring = spmfile
else:
return super().__new__(cls)

def __init__(self, bytestring: bytes, path: str | PathLike = None):
if path:
self.path: Path = Path(path)
raise ValueError('SPM file must be path to spm file or raw bytestring')

self.header: dict = self.parse_header(bytestring)
self.images: dict = self.extract_ciao_images(self.header, bytestring)
Expand Down Expand Up @@ -60,13 +58,13 @@ def groups(self) -> dict[int | None, dict[str]]:

return groups

@classmethod
def from_path(cls, path):
@staticmethod
def load_from_file(path):
""" Load SPM data from a file on disk """
with open(path, 'rb') as f:
bytestring = f.read()

return cls(bytestring, path=path)
return bytestring

@staticmethod
def parse_header(bytestring) -> dict:
Expand Down Expand Up @@ -251,12 +249,12 @@ def interpret_file_header(header_bytetring: bytes, encoding: str = 'latin-1') \
n_image = 0

# Walk through each line of metadata and extract sections and parameters
for line in header_lines:
if line.startswith(b'\\*File list end'):
for line_bytes in header_lines:
if line_bytes == b'\\*File list end':
# End of header, break out of loop
break

line = line.decode(encoding).lstrip('\\')
line = line_bytes.decode(encoding).lstrip('\\')
if line.startswith('*'):
# Lines starting with * indicate a new section
current_section = line.strip('*')
Expand Down

0 comments on commit 1919512

Please sign in to comment.