Module facetorch.utils
Expand source code
import os
import omegaconf
import torch
import torchvision
from facetorch.datastruct import ImageData
def rgb2bgr(tensor: torch.Tensor) -> torch.Tensor:
    """Swap the first and third channels of a batched tensor (RGB <-> BGR).

    The channel swap is its own inverse, so one call converts RGB to BGR
    and a second call converts BGR back to RGB.

    Args:
        tensor (torch.Tensor): Batch of RGB (or BGR) channeled tensors
            with shape (dim0, channels, dim2, dim3); dimension 1 must have size 3.

    Returns:
        torch.Tensor: Batch of BGR (or RGB) tensors with shape
        (dim0, channels, dim2, dim3).

    Raises:
        AssertionError: If dimension 1 does not have size 3 (only while
            assertions are enabled).
    """
    assert tensor.shape[1] == 3, "Tensor must have 3 channels."
    # Indexing with a list (advanced indexing) yields a new tensor, not a view.
    reversed_channels = [2, 1, 0]
    return tensor[:, reversed_channels]
def draw_boxes_and_save(data: ImageData, path_output: str) -> None:
    """Draws boxes on an image and saves it to a file.

    The annotated image replaces ``data.img``, so ``data`` is modified in
    place. The annotated image is then written to ``path_output``.

    Args:
        data (ImageData): ImageData object containing the image tensor, detections, and faces.
        path_output (str): Path to the output file. A bare file name (no
            directory component) is accepted and written relative to the
            current working directory.

    Returns:
        None
    """
    # os.path.dirname() returns "" for a bare file name, and os.makedirs("")
    # raises FileNotFoundError, so only create a directory when one is given.
    dir_output = os.path.dirname(path_output)
    if dir_output:
        os.makedirs(dir_output, exist_ok=True)

    loc_tensor = data.aggregate_loc_tensor()
    # One label per face (its index). Assumes data.faces is ordered like the
    # rows of loc_tensor - TODO confirm against ImageData.aggregate_loc_tensor.
    labels = [str(face.indx) for face in data.faces]
    data.img = torchvision.utils.draw_bounding_boxes(
        image=data.img,
        boxes=loc_tensor,
        labels=labels,
        colors="green",
        width=3,
    )
    pil_image = torchvision.transforms.functional.to_pil_image(data.img)
    pil_image.save(path_output)
def fix_transform_list_attr(
    transform: torchvision.transforms.Compose,
) -> torchvision.transforms.Compose:
    """Fix the transform attributes by converting ListConfig values to lists.

    Attributes of type ``omegaconf.ListConfig`` are replaced with plain
    Python lists so that the transform can be optimized with TorchScript.
    The conversion is recursive: ListConfig elements nested inside a
    ListConfig are converted too, and interpolations are resolved.

    The transforms inside ``transform`` are modified in place, and the same
    Compose object is returned.

    Args:
        transform (torchvision.transforms.Compose): Transform to be fixed.

    Returns:
        torchvision.transforms.Compose: Fixed transform.
    """
    for transform_x in transform.transforms:
        attrs = transform_x.__dict__
        # Snapshot the items so the dict is not written to while being iterated.
        for key, value in list(attrs.items()):
            if isinstance(value, omegaconf.listconfig.ListConfig):
                # list(value) would only convert the outer level and leave
                # nested ListConfig elements (e.g. [[a, b], [c, d]]) in place.
                attrs[key] = omegaconf.OmegaConf.to_container(value, resolve=True)
    return transform
Functions
def rgb2bgr(tensor: torch.Tensor) -> torch.Tensor
-
Converts a batch of RGB tensors to BGR tensors or vice versa.
Args
tensor
:torch.Tensor
- Batch of RGB (or BGR) channeled tensors
with shape (dim0, channels, dim2, dim3)
Returns
torch.Tensor
- Batch of BGR (or RGB) tensors with shape (dim0, channels, dim2, dim3).
Expand source code
def rgb2bgr(tensor: torch.Tensor) -> torch.Tensor:
    """Swap the first and third channels of a batched tensor (RGB <-> BGR).

    The channel swap is its own inverse, so one call converts RGB to BGR
    and a second call converts BGR back to RGB.

    Args:
        tensor (torch.Tensor): Batch of RGB (or BGR) channeled tensors
            with shape (dim0, channels, dim2, dim3); dimension 1 must have size 3.

    Returns:
        torch.Tensor: Batch of BGR (or RGB) tensors with shape
        (dim0, channels, dim2, dim3).

    Raises:
        AssertionError: If dimension 1 does not have size 3 (only while
            assertions are enabled).
    """
    assert tensor.shape[1] == 3, "Tensor must have 3 channels."
    # Indexing with a list (advanced indexing) yields a new tensor, not a view.
    reversed_channels = [2, 1, 0]
    return tensor[:, reversed_channels]
def draw_boxes_and_save(data: ImageData, path_output: str) -> None
-
Draws boxes on an image and saves it to a file.
Args
data
:ImageData
- ImageData object containing the image tensor, detections, and faces.
path_output
:str
- Path to the output file.
Returns
None
Expand source code
def draw_boxes_and_save(data: ImageData, path_output: str) -> None:
    """Draws boxes on an image and saves it to a file.

    The annotated image replaces ``data.img``, so ``data`` is modified in
    place. The annotated image is then written to ``path_output``.

    Args:
        data (ImageData): ImageData object containing the image tensor, detections, and faces.
        path_output (str): Path to the output file. A bare file name (no
            directory component) is accepted and written relative to the
            current working directory.

    Returns:
        None
    """
    # os.path.dirname() returns "" for a bare file name, and os.makedirs("")
    # raises FileNotFoundError, so only create a directory when one is given.
    dir_output = os.path.dirname(path_output)
    if dir_output:
        os.makedirs(dir_output, exist_ok=True)

    loc_tensor = data.aggregate_loc_tensor()
    # One label per face (its index). Assumes data.faces is ordered like the
    # rows of loc_tensor - TODO confirm against ImageData.aggregate_loc_tensor.
    labels = [str(face.indx) for face in data.faces]
    data.img = torchvision.utils.draw_bounding_boxes(
        image=data.img,
        boxes=loc_tensor,
        labels=labels,
        colors="green",
        width=3,
    )
    pil_image = torchvision.transforms.functional.to_pil_image(data.img)
    pil_image.save(path_output)
def fix_transform_list_attr(transform: torchvision.transforms.transforms.Compose) -> torchvision.transforms.transforms.Compose
-
Fix the transform attributes by converting each ListConfig value to a list. This makes it possible to optimize the transform with TorchScript.
Args
transform
:torchvision.transforms.Compose
- Transform to be fixed.
Returns
torchvision.transforms.Compose
- Fixed transform.
Expand source code
def fix_transform_list_attr(
    transform: torchvision.transforms.Compose,
) -> torchvision.transforms.Compose:
    """Fix the transform attributes by converting ListConfig values to lists.

    Attributes of type ``omegaconf.ListConfig`` are replaced with plain
    Python lists so that the transform can be optimized with TorchScript.
    The conversion is recursive: ListConfig elements nested inside a
    ListConfig are converted too, and interpolations are resolved.

    The transforms inside ``transform`` are modified in place, and the same
    Compose object is returned.

    Args:
        transform (torchvision.transforms.Compose): Transform to be fixed.

    Returns:
        torchvision.transforms.Compose: Fixed transform.
    """
    for transform_x in transform.transforms:
        attrs = transform_x.__dict__
        # Snapshot the items so the dict is not written to while being iterated.
        for key, value in list(attrs.items()):
            if isinstance(value, omegaconf.listconfig.ListConfig):
                # list(value) would only convert the outer level and leave
                # nested ListConfig elements (e.g. [[a, b], [c, d]]) in place.
                attrs[key] = omegaconf.OmegaConf.to_container(value, resolve=True)
    return transform