Implement New Models

PocketPose offers a flexible, extensible framework for integrating pose estimation models, with a focus on mobile devices. At its core is a well-defined model interface, which makes it straightforward to add new models and extend the library’s capabilities.

classDiagram
    IModel <|-- TFLiteModel
    IModel <|-- ONNModel
    TFLiteModel <|-- YourCustomModel
    IModel: +process_image()
    IModel: +predict()
    IModel: +postprocess_prediction()
    class TFLiteModel{
        +predict()
        +postprocess_prediction()
    }
    class ONNModel{
        +predict()
        +postprocess_prediction()
    }
    class YourCustomModel{
        +predict()
        +postprocess_prediction()
    }

To integrate a new model into PocketPose, follow these steps:

  1. Inherit from IModel: Your model class should inherit from the IModel interface or one of its direct subclasses. This ensures consistency and interoperability within the PocketPose framework.

    from pocketpose.models.interfaces import IModel
    
    class YourCustomModel(IModel):
        ...
    
  2. Implement Required Methods: Implement the abstract methods defined in the IModel interface. These typically include process_image, predict, and postprocess_prediction.

    • process_image: Prepare the input image for prediction.

    • predict: Run the model inference.

    • postprocess_prediction: Process the raw model output to extract meaningful information, such as keypoint coordinates.

  3. Add Custom Logic (if needed): Depending on your model’s specific requirements, you can add additional methods or override existing ones for customized behavior.

  4. Integrate with Model Factory: Optionally, integrate your model with the ModelFactory to enable easy instantiation.
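The steps above can be sketched end to end without PocketPose installed. The block below is a minimal illustration, not the library's implementation: `IModelSketch` is a hypothetical stand-in that borrows only the three method names from the interface, `DummyPoseModel` and `run` are made-up names, and the argument lists of `process_image` and `predict` are assumptions. Images are plain nested lists so the example has no dependencies.

```python
from abc import ABC, abstractmethod


class IModelSketch(ABC):
    """Hypothetical stand-in for IModel; only the method names come from the docs."""

    @abstractmethod
    def process_image(self, image):
        """Prepare the input image for prediction."""

    @abstractmethod
    def predict(self, image):
        """Run the model inference."""

    @abstractmethod
    def postprocess_prediction(self, prediction, original_size):
        """Turn raw model output into (x, y, score) keypoints."""


class DummyPoseModel(IModelSketch):
    def __init__(self, input_size=(2, 2)):
        self.input_size = input_size  # assumed to be (height, width)

    def process_image(self, image):
        # Nearest-neighbour resize of a nested-list image to input_size.
        src_h, src_w = len(image), len(image[0])
        dst_h, dst_w = self.input_size
        return [
            [image[r * src_h // dst_h][c * src_w // dst_w] for c in range(dst_w)]
            for r in range(dst_h)
        ]

    def predict(self, image):
        # Placeholder for real inference: one keypoint as normalized (y, x, score).
        return [(0.5, 0.25, 0.9)]

    def postprocess_prediction(self, prediction, original_size):
        # Scale normalized (y, x) to pixels and reorder to (x, y, score).
        height, width = original_size
        return [(int(x * width), int(y * height), s) for y, x, s in prediction]


def run(model, image):
    """Chain the three interface methods in the order the steps describe."""
    original_size = (len(image), len(image[0]))  # (height, width)
    prepared = model.process_image(image)
    raw = model.predict(prepared)
    return model.postprocess_prediction(raw, original_size)


image = [[0] * 8 for _ in range(4)]  # 4 rows x 8 columns
print(run(DummyPoseModel(), image))  # [(2, 2, 0.9)]
```

Because the stand-in base class declares all three methods abstract, forgetting to implement one of them fails at instantiation time rather than during inference.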

Example

The following example demonstrates how to integrate a TFLite model into PocketPose.

from pocketpose.models.interfaces import TFLiteModel
from pocketpose.models.registry import model_registry


@model_registry.register('model_name')
class CustomModel(TFLiteModel):
    def __init__(self,
                 model_path: str = "path/to/cache/model/file.tflite",
                 model_url: str = "https://url/to/download/model/from.tflite",
                 input_size: tuple = (192, 192, 3)):
        super().__init__(model_path, model_url, keypoints_type='coco',
                         input_size=input_size, output_type='keypoints')

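    def postprocess_prediction(self, prediction, original_size):
        # Drop singleton dimensions: (17, 3) as (y, x, score)
        keypoints = prediction.squeeze()
        # Scale (y, x) to pixels; this assumes the coordinates are normalized
        # and that original_size is ordered as (height, width)
        keypoints[:, :2] *= original_size
        # Reorder each keypoint to (x, y, score) with integer pixel coordinates
        keypoints = [(int(x), int(y), s) for y, x, s in keypoints]
        return keypoints  # list of 17 (x, y, score) tuples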

The @model_registry.register decorator registers the model with the ModelFactory under the given name, allowing it to be instantiated by name. The postprocess_prediction method converts the model-specific output into a list of (x, y, score) keypoints, which is the output format PocketPose expects.
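For readers unfamiliar with decorator-based registration, the pattern can be sketched in a few lines. This is a generic illustration, not PocketPose's actual registry or ModelFactory: `ModelRegistry`, its `create` method, `registry`, and `ToyModel` are all hypothetical names chosen for the example.

```python
class ModelRegistry:
    """Hypothetical name-to-class registry illustrating the decorator pattern."""

    def __init__(self):
        self._models = {}

    def register(self, name):
        # Returns a class decorator that stores the class under `name`.
        def decorator(cls):
            if name in self._models:
                raise ValueError(f"model '{name}' is already registered")
            self._models[name] = cls
            return cls  # the class itself is returned unchanged
        return decorator

    def create(self, name, **kwargs):
        # Look the class up by name and instantiate it.
        if name not in self._models:
            raise KeyError(f"unknown model '{name}'")
        return self._models[name](**kwargs)


registry = ModelRegistry()


@registry.register("toy_model")
class ToyModel:
    def __init__(self, input_size=(192, 192, 3)):
        self.input_size = input_size


model = registry.create("toy_model", input_size=(256, 256, 3))
print(type(model).__name__, model.input_size)  # ToyModel (256, 256, 3)
```

Registration happens as a side effect of the class definition, so in this pattern the module that defines the model must be imported before the name can be looked up.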