Source code for arthropod_describer.plugins.pekar_unet.regions.unet_regions

import logging
from datetime import datetime

import os
import subprocess
import tempfile
import xml.etree.ElementTree as ET
from pathlib import Path
from typing import Set, Optional, List

import numpy as np
from skimage import io

from arthropod_describer.common.photo import Photo, LabelImg
from arthropod_describer.common.plugin import RegionComputation


class UNetRegions(RegionComputation):
    """
    NAME: Body parts segmenter
    DESCRIPTION: Labels the parts of an insect.
    """
    # NOTE: the class docstring above keeps its NAME:/DESCRIPTION: layout; it looks
    # like metadata read by the plugin framework, so further documentation is kept
    # in comments and method docstrings instead.

    def __init__(self):
        RegionComputation.__init__(self, None)
        # Paths to the external plugin executables bundled next to this module.
        self.bin_path = Path(__file__).parent / 'pekar_unet' / 'bin'
        self.unet_path = self.bin_path / 'UNetSegmentationPlugin.exe'
        self.leg_segment_path = self.bin_path / 'LegSegmentsPlugin.exe'
        self.std_reflection_path = self.bin_path / 'StandardDeviationReflectionsPlugin.exe'

    def _run_plugin(self, args: List[str]) -> subprocess.CompletedProcess:
        """Echo, log and run one external plugin executable.

        The command runs with the working directory set to the parent of
        ``self.bin_path``.

        Raises:
            subprocess.CalledProcessError: if the executable exits with a
                non-zero status. Failing here reports the broken stage directly,
                instead of letting a later step fail on a missing output file.
        """
        print(f"Running plugin:\n{args}\n")
        logging.info(f"{datetime.now()} Running plugin:\n{args}\n")
        return subprocess.run(args, cwd=str(self.bin_path.parent), check=True)

    def __call__(self, photo: Photo, labels: Optional[Set[int]] = None, storage=None) -> List[LabelImg]:
        """Segment ``photo`` into body-part regions and reflections.

        Three external executables run in sequence inside a temporary
        directory:

        1. UNet segmentation writes a mask image.
        2. The reflections plugin writes a binary reflections image.
        3. The leg-segments plugin writes a region image plus an XML file
           whose ``identifier`` elements list the region values used.

        Each identifier string is turned into a label via
        ``lab_hier.label(lab_hier.sep.join(list(identifier)))``. The region
        image is then recoloured with those labels.

        Args:
            photo: Photo whose ``image_path`` is processed. Its ``'Labels'``
                and ``'Reflections'`` label images supply the label
                hierarchies.
            labels: Unused by this implementation.
            storage: Unused by this implementation.

        Returns:
            ``[new_labels, reflections]``, where:

            - ``new_labels`` is a clone of ``photo['Labels']`` carrying the
              computed label image.
            - ``reflections`` is ``photo['Reflections']`` itself, updated in
              place.

        Side effects:
            Writes the computed label image to
            ``<image dir>/../Labels/<image name>.tif``.

        Raises:
            subprocess.CalledProcessError: if any plugin executable fails.
        """
        img_loc = photo.image_path
        with tempfile.TemporaryDirectory() as out_d:
            temp_out_d = Path(out_d)
            mask_dir = temp_out_d / 'maskfolder'
            refl_dir = temp_out_d / 'reflectionsfolder'
            out_reg = temp_out_d / 'regionsfolder'
            for folder in (mask_dir, refl_dir, out_reg):
                folder.mkdir()

            out_mask = mask_dir / f'unapprovedmask-temp____{img_loc.name}'
            out_refl = refl_dir / f'unapprovedreflections-temp____{img_loc.name}'
            reg_img_path = out_reg / f'unapprovedregions____{img_loc.name}'
            reg_xml_path = out_reg / f'unapprovedregions____{img_loc.name}.xml'

            # segment the specimen into a mask image
            self._run_plugin([
                str(self.unet_path), "r", str(img_loc), str(out_mask),
            ])
            # compute reflections binary image
            self._run_plugin([
                str(self.std_reflection_path), "r", str(img_loc),
                str(out_mask), str(out_refl),
            ])
            # finally, divide legs into sections (region image + XML description)
            self._run_plugin([
                str(self.leg_segment_path), "r", str(img_loc),
                str(out_mask), str(out_refl), str(reg_img_path), str(reg_xml_path),
            ])

            reg_img = io.imread(str(reg_img_path))
            root = ET.parse(reg_xml_path)
            used_labels = [ident.text for ident in root.iter('identifier')]

            lab_hier = photo['Labels'].label_hierarchy
            # Each identifier's characters are joined with the hierarchy separator
            # to form the label's hierarchical code.
            mapping = {int(label): lab_hier.label(lab_hier.sep.join(list(label))) for label in used_labels}

            result_img = np.zeros_like(reg_img, dtype=np.uint32)
            for region_value, label_code in mapping.items():
                result_img[reg_img == region_value] = label_code

            io.imsave(str(img_loc.parent.parent / 'Labels' / f'{img_loc.name}.tif'), result_img, check_contrast=False)

            new_lab = photo['Labels'].clone()
            new_lab.label_image = result_img

            ref_img = io.imread(str(out_refl))
            ref_lab_hier = photo['Reflections'].label_hierarchy
            # NOTE(review): unlike 'Labels' (cloned above), the photo's own
            # 'Reflections' label image is modified in place -- confirm intended.
            ref_lab_img = photo['Reflections']
            # First positive label of the reflections hierarchy marks reflective pixels.
            reflection_label = [lab for lab in ref_lab_hier.labels if lab > 0][0]
            ref_lab_img.label_image = np.where(ref_img > 0, reflection_label, 0).astype(np.uint32)
            return [new_lab, ref_lab_img]