While working as an ML engineer at KatalistAI, I ran into a fascinating challenge with OpenPose detection. In some cases the body was only visible from the waist up, so the detection model (DW-pose) could not predict the keypoints located outside of the frame, but we still wanted a reasonable body pose even for the off-screen points. This would let us scale or move the skeleton and have a sensible skeleton shape from the get-go.
By adding a post-processing step to the keypoint detection, we can extrapolate the points outside the frame based on the currently visible keypoints, reducing further work when changing the pose, as demonstrated below.
The code is available at https://github.com/katalist-ai/openpose-extrapolation
The key to overcoming this challenge was not in the complexity of the machine learning model, but rather in the extensive data preparation. I discovered a comprehensive OpenPose dataset on Kaggle, which provided various human skeletons in different poses. We aimed to train our model to predict hidden points effectively, a task that required thoughtful manipulation of the data.
We can greatly improve the dataset with augmentation techniques that help the model generalize better.
To generate training examples, we took the whole-body skeletons and marked different subsets of keypoints as missing (by changing their values to -10) to produce the training inputs. For each skeleton, we sampled one subset at random from all possible subsets of points, marked those points as missing, and repeated this process 30 times.
All the preprocessing steps are included in the preprocess.py file in the repo.
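As a rough sketch of that masking step (my reconstruction, not the repo's code; the exact logic lives in preprocess.py, and the interleaved (x, y) layout and the coin-flip subset sampling are assumptions):
import torch

def mask_random_subsets(poses, missing_value=-10.0, repeats=30):
    # poses: (N, 36) tensor of flattened (x, y) pairs for 18 keypoints
    inputs, targets = [], []
    for _ in range(repeats):
        # Dropping each keypoint independently with probability 0.5 draws a
        # subset uniformly from all 2^18 possible subsets.
        missing = torch.rand(poses.shape[0], 18) < 0.5
        missing = missing.repeat_interleave(2, dim=1)  # mask both x and y
        masked = poses.clone()
        masked[missing] = missing_value
        inputs.append(masked)
        targets.append(poses)
    return torch.cat(inputs), torch.cat(targets)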
Model Architecture
We created a simple torch Dataset class, which returns matching pairs of a full skeleton and the same skeleton with missing keypoints.
import torch
from torch.utils.data import Dataset
from pathlib import Path

class OpenPoseDataset(Dataset):
    def __init__(self, data_dir, data_type):
        # Preprocessed tensors: full skeletons and their masked counterparts
        self.poses = torch.load(Path(data_dir, f'poses_{data_type}.pt'))
        self.poses_missing = torch.load(
            Path(data_dir, f'poses_missing_{data_type}.pt')
        )
        self.data_dir = data_dir

    def __len__(self):
        return self.poses.shape[0]

    def __getitem__(self, idx):
        # Input: skeleton with missing keypoints; target: the full skeleton
        pose = self.poses[idx]
        pose_missing = self.poses_missing[idx]
        return pose_missing, pose
We’ll later use it to load training, validation and test datasets.
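For a quick sanity check, the dataset can be indexed directly (assuming the preprocessed .pt files from preprocess.py already sit in DATA_DIR):
from paths import DATA_DIR
from dataset import OpenPoseDataset

dataset = OpenPoseDataset(DATA_DIR, 'train')
pose_missing, pose = dataset[0]
print(pose_missing.shape, pose.shape)  # both torch.Size([36]): 18 keypoints x 2 coords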
We utilized the PyTorch Lightning library for an efficient model-building process. The architecture was straightforward yet effective, comprising two hidden layers designed for fast learning and good generalization. The AdamW optimizer was chosen for its approach to weight decay, providing better regularization than the standard Adam optimizer.
from torch import nn, optim
import lightning as L

class SkeletonExtrapolator(L.LightningModule):
    def __init__(self, learning_rate=1e-3, weight_decay=1e-5, dropout=0.2):
        super().__init__()
        self.save_hyperparameters()
        # 18 keypoints x 2 coordinates = 36 inputs and outputs
        self.layers = nn.Sequential(
            nn.Linear(36, 36), nn.ReLU(),
            nn.Linear(36, 64), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(64, 36),
        )

    def forward(self, x):
        # Inference: predict all points, then copy the visible ones back
        mask = x != -10.0
        x_hat = self.layers(x)
        x_hat[mask] = x[mask]
        return x_hat

    def general_step(self, batch, batch_idx, loss_name):
        # Training/evaluation: the loss is computed on every output coordinate
        x, y = batch
        x_hat = self.layers(x)
        loss = nn.functional.mse_loss(x_hat, y)
        self.log(loss_name, loss)
        return loss

    def training_step(self, batch, batch_idx):
        return self.general_step(batch, batch_idx, "train_loss")

    def test_step(self, batch, batch_idx):
        return self.general_step(batch, batch_idx, "test_loss")

    def validation_step(self, batch, batch_idx):
        return self.general_step(batch, batch_idx, "val_loss")

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(),
                                lr=self.hparams.learning_rate,
                                weight_decay=self.hparams.weight_decay)
        return optimizer
While the training step uses every model output to calculate the loss and learn properly, during inference (the forward method) we only need the outputs for the missing points, i.e. those with the input value of -10; the visible points are copied through unchanged.
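A toy example makes the pass-through behavior concrete (an untrained model, so the filled-in values are meaningless here; the point is only the masking):
import torch

model = SkeletonExtrapolator()
x = torch.full((1, 36), -10.0)  # start with every coordinate marked missing
x[0, :12] = torch.rand(12)      # pretend the first six keypoints are visible
x_hat = model(x)
# Visible coordinates are copied through unchanged...
assert torch.equal(x_hat[0, :12], x[0, :12])
# ...while the -10 placeholders are replaced by network predictions.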
Training
Using the PyTorch Lightning library, training is much more streamlined than with PyTorch alone. We used Dropout with p=0.5 and a weight decay of 0.001 to regularize the model. We care more about the model's generalization capabilities than pixel-perfect predictions, as we don't want the model to overfit on our heavily augmented dataset.
from torch.utils.data import DataLoader
import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from dataset import OpenPoseDataset
from model import SkeletonExtrapolator
from lightning.pytorch import loggers as pl_loggers
from paths import DATA_DIR, ROOT_DIR
from pathlib import Path
model = SkeletonExtrapolator(dropout=0.5, weight_decay=0.001)
dataset_train = OpenPoseDataset(DATA_DIR, 'train')
dataset_valid = OpenPoseDataset(DATA_DIR, 'valid')
dataset_test = OpenPoseDataset(DATA_DIR, 'test')
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=3,
    filename='{epoch}-{val_loss:.6f}'
)
train_loader = DataLoader(dataset_train, batch_size=32, shuffle=True, num_workers=4, persistent_workers=True)
valid_loader = DataLoader(dataset_valid, batch_size=512, shuffle=False, num_workers=8, persistent_workers=True)
test_loader = DataLoader(dataset_test, batch_size=512, shuffle=False, num_workers=8, persistent_workers=True)
tb_logger = pl_loggers.TensorBoardLogger(save_dir=ROOT_DIR)
trainer = L.Trainer(max_epochs=8, callbacks=[checkpoint_callback], logger=tb_logger)
trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=valid_loader)
trainer.test(ckpt_path='best', dataloaders=test_loader)
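After training, the best checkpoint can be loaded back for inference. A minimal sketch (the fully masked pose below is a stand-in input; in practice it comes from DW-pose detections with off-screen points set to -10):
import torch

best_model = SkeletonExtrapolator.load_from_checkpoint(
    checkpoint_callback.best_model_path,  # filled in by ModelCheckpoint during fit
    map_location='cpu',
)
best_model.eval()

pose = torch.full((1, 36), -10.0)  # stand-in: a skeleton with all points missing
with torch.no_grad():
    completed = best_model(pose)   # -10 entries are replaced with predictions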
In our initial approach to constructing the training set, we manually selected subsets of missing points based on our intuition (pairs such as the left and right ankle, or both ankles and knees), preparing around 15 different subsets. However, this method did not cover all possible combinations of missing points, and the model could not predict accurately for subsets it hadn't seen before. To address this limitation and improve the model's generalization and robustness, we adopted a new strategy: for each skeleton we sampled one subset from the entire range of possible subsets, repeating this procedure 30 times across the already augmented dataset. This process resulted in a training dataset of approximately 1 million skeleton pairs, markedly enhancing the model's robustness.
This model was created for the purpose of extrapolating off-screen OpenPose keypoints. If you want to see it in action, check it out at Katalist.ai, where OpenPose skeletons are used to guide image generation.