Enhancing OpenPose Detection Using Self-Supervised Learning


While working as an ML engineer for KatalistAI, I encountered a fascinating challenge with OpenPose detection. In some cases, the body was only visible from the waist up. The detection model (DW-pose) could not predict the points located outside the frame, but we still wanted a reasonable body pose even for the off-screen points. This would allow us to scale or move the skeleton and have a reasonable skeleton shape from the get-go.

By adding a post-processing step to the keypoint detection, we can extrapolate the points outside the frame, based on the currently visible keypoints, reducing further work when changing the pose, as demonstrated below.

The code is available at https://github.com/katalist-ai/openpose-extrapolation

Critical Role of Data Preparation

The key to overcoming this challenge was not in the complexity of the machine learning model, but rather in the extensive data preparation. I discovered a comprehensive OpenPose dataset on Kaggle, which provided various human skeletons in different poses. We aimed to train our model to predict hidden points effectively, a task that required thoughtful manipulation of the data.

We can greatly improve the dataset, and help the model generalize better, by applying several augmentation techniques (two of them are sketched in code after this list):

  • Normalising points in relation to the neck point, to introduce translation invariance.
  • Vertically flipping the skeletons.
  • Resizing the skeleton images to simulate various body sizes.
  • Rotating the skeletons in 3D space and then projecting them back into 2D.
  • Adjusting the width of the skeletons to represent different body types.
  • Altering the aspect ratios to reflect common formats like 16:9, 1:1, and 9:16.
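To make two of these augmentations concrete, here is a minimal sketch, assuming a skeleton is an (18, 2) array of (x, y) keypoints with index 1 as the neck (the OpenPose body-18 convention, consistent with the 36-value model input below). The function names are illustrative, not from the repo.

import numpy as np

NECK = 1  # index of the neck keypoint in the OpenPose body-18 layout

def normalize_to_neck(skeleton):
    # Translate every point so the neck sits at the origin; the pose then
    # no longer depends on where the person stands in the frame.
    return skeleton - skeleton[NECK]

def rotate_y_and_project(skeleton, angle_rad):
    # Treat the 2D points as lying in the z = 0 plane, rotate them around
    # the vertical (y) axis, and orthographically project back to 2D.
    # With z = 0, the rotation reduces to foreshortening the x coordinate.
    x, y = skeleton[:, 0], skeleton[:, 1]
    return np.stack([x * np.cos(angle_rad), y], axis=1)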

To generate training examples, we used the whole-body skeletons and marked different subsets of points as missing (by changing them to -10) to produce the inputs for training. For each skeleton, we randomly sampled a subset of points from all possible subsets and marked those points as missing, repeating this process 30 times.
All the preprocessing steps are included in the preprocess.py file in the repo.
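As an illustration of this masking step, here is a minimal sketch that produces one training input from a full pose, assuming a flat tensor of 36 values (18 keypoints, 2 coordinates each). The concrete subsets used in practice live in preprocess.py; this random-subset variant is for illustration only.

import torch

MISSING = -10.0
NUM_KEYPOINTS = 18

def mask_random_subset(pose):
    # Choose a random non-empty subset of keypoints and overwrite both of
    # their coordinates with the -10 sentinel value.
    n_missing = torch.randint(1, NUM_KEYPOINTS, (1,)).item()
    idx = torch.randperm(NUM_KEYPOINTS)[:n_missing]
    pose_missing = pose.clone().view(NUM_KEYPOINTS, 2)
    pose_missing[idx] = MISSING
    return pose_missing.view(-1)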

Model Architecture

We created a simple torch Dataset class, which returns matching samples of a full skeleton and the same skeleton with missing keypoints:

import torch
from torch.utils.data import Dataset
from pathlib import Path

class OpenPoseDataset(Dataset):
    def __init__(self, data_dir, data_type):
        # Load the tensors produced by preprocess.py;
        # data_type is 'train', 'valid' or 'test'
        self.poses = torch.load(Path(data_dir, f'poses_{data_type}.pt'))
        self.poses_missing = torch.load(
            Path(data_dir, f'poses_missing_{data_type}.pt')
        )
        self.data_dir = data_dir

    def __len__(self):
        return self.poses.shape[0]

    def __getitem__(self, idx):
        pose = self.poses[idx]
        pose_missing = self.poses_missing[idx]
        return pose_missing, pose


We’ll later use it to load training, validation and test datasets.

We utilized the PyTorch Lightning library for an efficient model-building process. The architecture was straightforward yet effective, comprising two hidden layers designed for fast learning and good generalization. The AdamW optimizer was chosen for its approach to weight decay, providing better regularization than the standard Adam optimizer.

from torch import nn, optim
import lightning as L

class SkeletonExtrapolator(L.LightningModule):
    def __init__(self, learning_rate=1e-3, weight_decay=1e-5, dropout=0.2):
        super().__init__()
        self.save_hyperparameters()
        # 18 keypoints * 2 coordinates = 36 inputs and outputs
        self.layers = nn.Sequential(
            nn.Linear(36, 36), nn.ReLU(),
            nn.Linear(36, 64), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(64, 36),
        )

    def forward(self, x):
        # inference: predict all 36 values, then restore the original
        # coordinates wherever the point was actually visible
        mask = x != -10.0
        x_hat = self.layers(x)
        x_hat[mask] = x[mask]
        return x_hat
    
    def general_step(self, batch, batch_idx, loss_name):
        x, y = batch
        x_hat = self.layers(x)
        loss = nn.functional.mse_loss(x_hat, y)
        self.log(loss_name, loss)
        return loss

    def training_step(self, batch, batch_idx):
        return self.general_step(batch, batch_idx, "train_loss")
    
    def test_step(self, batch, batch_idx):
        return self.general_step(batch, batch_idx, "test_loss")
    
    def validation_step(self, batch, batch_idx):
        return self.general_step(batch, batch_idx, "val_loss")

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), 
                               lr=self.hparams.learning_rate, 
                               weight_decay=self.hparams.weight_decay)
        return optimizer

While the training step uses every model output to calculate the loss and learn properly, during inference (the forward method) we only need the outputs for the missing points, i.e. those with the value -10; the visible points are copied through unchanged.
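As a usage sketch, extrapolating a pose whose lower body fell outside the frame might look like this; the checkpoint path and example values are hypothetical, while the -10.0 sentinel and 36-value layout match the code above.

import torch
from model import SkeletonExtrapolator  # the repo module used by the training script

model = SkeletonExtrapolator.load_from_checkpoint('checkpoints/best.ckpt')
model.eval()

pose = torch.full((1, 36), -10.0)  # start with every keypoint marked missing
pose[0, :22] = torch.rand(22)      # stand-in for 11 detected upper-body keypoints

with torch.no_grad():
    completed = model(pose)  # visible points pass through; missing ones are filled in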

Training

Using the PyTorch Lightning library, training is much more streamlined than with PyTorch alone. We used Dropout with p=0.5 and a weight decay of 0.001 to regularize the model. We care more about the model's generalization capabilities than pixel-perfect predictions, as we don't want the model to overfit on our heavily augmented dataset.

from torch.utils.data import DataLoader
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from dataset import OpenPoseDataset
from model import SkeletonExtrapolator
from lightning.pytorch import loggers as pl_loggers
from paths import DATA_DIR, ROOT_DIR
from pathlib import Path

model = SkeletonExtrapolator(dropout=0.5, weight_decay=0.001)
dataset_train = OpenPoseDataset(DATA_DIR, 'train')
dataset_valid = OpenPoseDataset(DATA_DIR, 'valid')
dataset_test = OpenPoseDataset(DATA_DIR, 'test')

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=3,
    filename='{epoch}-{val_loss:.6f}',
)

train_loader = DataLoader(dataset_train, batch_size=32, shuffle=True, num_workers=4, persistent_workers=True)
valid_loader = DataLoader(dataset_valid, batch_size=512, shuffle=False, num_workers=8, persistent_workers=True)
test_loader = DataLoader(dataset_test, batch_size=512, shuffle=False, num_workers=8, persistent_workers=True)

tb_logger = pl_loggers.TensorBoardLogger(save_dir=ROOT_DIR)
trainer = L.Trainer(max_epochs=8, callbacks=[checkpoint_callback], logger=tb_logger)
trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=valid_loader)
trainer.test(ckpt_path='best', dataloaders=test_loader)

Challenges and Learnings

In our initial approach to constructing the training set, we manually selected subsets of missing points based on intuition, such as the left and right ankle, or both ankles and knees, preparing around 15 different subsets. However, this method did not cover all possible combinations of missing points, so the model could not predict accurately for subsets it hadn't seen before. To improve generalization and robustness, we adopted a new strategy: for each skeleton, we sampled one subset from the entire range of possible subsets, repeating this procedure 30 times across the already augmented dataset. This resulted in a training dataset of approximately 1 million skeleton pairs, markedly improving the model's robustness.

Use Case

This model was created to extrapolate off-screen OpenPose keypoints. If you want to see it in action, check it out at Katalist.ai, where OpenPose skeletons are used to guide image generation.
