Skip to content

Commit 062b984

Browse files
authored
Merge pull request #177 from AFM-SPM/ns-rse/work-with-topostats-classes
fix: decode numpy arrays; feature: topostats versions
2 parents f6fc37d + 0b1c527 commit 062b984

5 files changed

Lines changed: 60 additions & 6 deletions

File tree

.pylintrc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ py-version=3.9
6262

6363
# When enabled, pylint would attempt to guess common misconfiguration and emit
6464
# user-friendly hints instead of false-positive error messages.
65-
suggestion-mode=yes
65+
# suggestion-mode=yes
6666

6767
# Allow loading of arbitrary C extensions. Extensions are imported into the
6868
# active Python interpreter and may run arbitrary code.

AFMReader/io.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import BinaryIO
66

77
import h5py
8+
import numpy as np
89
from loguru import logger
910
from ruamel.yaml import YAML, YAMLError
1011

@@ -255,7 +256,14 @@ def unpack_hdf5(open_hdf5_file: h5py.File, group_path: str = "/") -> dict:
255256
# Decode byte strings to utf-8. The data type "O" is a byte string.
256257
elif isinstance(item, h5py.Dataset) and item.dtype == "O":
257258
# Byte string
258-
data[key] = item[()].decode("utf-8")
259+
try:
260+
data[key] = item[()].decode("utf-8")
261+
# Numpy arrays of strings can not be directly decoded, have to iterate over each item
262+
except AttributeError as e:
263+
if isinstance(item[()], np.ndarray):
264+
data[key] = [_item.decode("utf-8") for _item in item[()]] # type: ignore
265+
else:
266+
raise e
259267
else:
260268
# Another type of dataset
261269
data[key] = item[()]

AFMReader/topostats.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import h5py
77

8+
from packaging.version import parse as parse_version
89
from AFMReader.io import unpack_hdf5
910
from AFMReader.logging import logger
1011

@@ -41,10 +42,15 @@ def load_topostats(file_path: Path | str) -> dict[str, Any]:
4142
try:
4243
with h5py.File(file_path, "r") as f:
4344
data = unpack_hdf5(open_hdf5_file=f, group_path="/")
44-
if str(data["topostats_file_version"]) >= "0.2":
45+
# Handle different names for variables holding the file version (<=0.3) or the newer topostats version
46+
version = (
47+
data["topostats_file_version"]
48+
if "topostats_file_version" in data.keys() # pylint: disable=consider-iterating-dictionary
49+
else data["topostats_version"]
50+
)
51+
if parse_version(str(version)) > parse_version("0.2"):
4552
data["img_path"] = Path(data["img_path"])
46-
file_version = data["topostats_file_version"]
47-
logger.info(f"[{filename}] TopoStats file version : {file_version}")
53+
logger.info(f"[{filename}] TopoStats file version : {version}")
4854

4955
except OSError as e:
5056
if "Unable to open file" in str(e):

tests/test_io.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,46 @@ def test_unpack_hdf5_nested_dict_group_path(tmp_path: Path) -> None:
201201
np.testing.assert_equal(result, expected)
202202

203203

204+
def test_unpack_hdf5_list_of_bytes(tmp_path: Path) -> None:
205+
"""Test loading a list of strings which are encoded to Numpy array on saving."""
206+
to_save = {
207+
"config": {
208+
"grainstats": {
209+
"class_names": np.asarray([b"DNA", b"Protein"], dtype="S7"),
210+
"edge_detection_method": "binary_erosion",
211+
"extract_height_profile": True,
212+
"run": True,
213+
}
214+
}
215+
}
216+
group_path = "/config/grainstats/"
217+
expected = {
218+
"class_names": np.asarray([b"DNA", b"Protein"], dtype="S7"),
219+
"edge_detection_method": "binary_erosion",
220+
"extract_height_profile": True,
221+
"run": True,
222+
}
223+
# Manually save the dictionary to HDF5 format
224+
with h5py.File(tmp_path / "hdf5_file_list_of_strings", "w") as f:
225+
# t_path = Path.cwd()
226+
# with h5py.File(t_path / "tmp" / "something_else", "w") as f:
227+
config = f.create_group("config")
228+
grainstats = config.create_group("grainstats")
229+
grainstats.create_dataset("class_names", data=to_save["config"]["grainstats"]["class_names"])
230+
grainstats.create_dataset(
231+
"edge_detection_method", data=to_save["config"]["grainstats"]["edge_detection_method"]
232+
)
233+
grainstats.create_dataset(
234+
"extract_height_profile", data=to_save["config"]["grainstats"]["extract_height_profile"]
235+
)
236+
grainstats.create_dataset("run", data=to_save["config"]["grainstats"]["run"])
237+
238+
# Load it back in and check if the list is the same
239+
with h5py.File(tmp_path / "hdf5_file_list_of_strings", "r") as f:
240+
result = unpack_hdf5(open_hdf5_file=f, group_path=group_path)
241+
np.testing.assert_equal(result, expected)
242+
243+
204244
def test_read_yaml() -> None:
205245
"""Test reading of YAML file."""
206246
sample_config = read_yaml(RESOURCES / "test.yaml")

tests/test_topostats.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def test_load_topostats(
9999
assert topostats_data["pixel_to_nm_scaling"] == pytest.approx(pixel_to_nm_scaling)
100100
assert topostats_data["image"].shape == image_shape
101101
assert topostats_data["image"].sum() == pytest.approx(image_sum)
102-
if version >= "0.2":
102+
if version > "0.2":
103103
assert isinstance(topostats_data["img_path"], Path)
104104

105105

0 commit comments

Comments
 (0)