diff --git a/fmpose3d/inference_api/README.md b/fmpose3d/inference_api/README.md index 71a77b1..89d3761 100644 --- a/fmpose3d/inference_api/README.md +++ b/fmpose3d/inference_api/README.md @@ -97,7 +97,7 @@ Convenience constructor for the **animal** pipeline. Sets `model_type="fmpose3d_ ### Public Methods -#### `predict(source, *, camera_rotation, seed, progress)` → `Pose3DResult` +#### `predict(source, *, camera_rotation, seed, progress, return_2d)` → `Pose3DResult` End-to-end prediction: 2D estimation followed by 3D lifting in a single call. @@ -107,6 +107,7 @@ End-to-end prediction: 2D estimation followed by 3D lifting in a single call. | `camera_rotation` | `ndarray \| None` | Length-4 quaternion for camera-to-world rotation. Defaults to the official demo rotation. `None` skips the transform. Ignored for animals. | | `seed` | `int \| None` | Seed for reproducible sampling. | | `progress` | `ProgressCallback \| None` | Callback `(current_step, total_steps) -> None`. | +| `return_2d` | `bool` | When `True`, include the intermediate 2D result under `Pose3DResult.pose_2d`. Default: `False`. | **Returns:** `Pose3DResult` @@ -175,6 +176,7 @@ Source = Union[str, Path, np.ndarray, Sequence[Union[str, Path, np.ndarray]]] |---|---|---| | `poses_3d` | `ndarray` | Root-relative 3D poses, shape `(num_frames, J, 3)`. | | `poses_3d_world` | `ndarray` | Post-processed 3D poses, shape `(num_frames, J, 3)`. For humans: world-coordinate poses. For animals: limb-regularized poses. | +| `pose_2d` | `Pose2DResult \| None` | Optional 2D prediction payload, present when `predict(..., return_2d=True)` is used. | diff --git a/fmpose3d/inference_api/fmpose3d.py b/fmpose3d/inference_api/fmpose3d.py index 603277d..003693c 100644 --- a/fmpose3d/inference_api/fmpose3d.py +++ b/fmpose3d/inference_api/fmpose3d.py @@ -477,6 +477,8 @@ class Pose3DResult: ``camera_to_world``). For animal poses this contains the limb-regularised output. """ + pose_2d: Pose2DResult | None = None + """Optional 2D prediction payload from :meth:`FMPose3DInference.predict`.""" #: Accepted source types for :meth:`FMPose3DInference.predict`. @@ -659,6 +661,7 @@ def predict( camera_rotation: np.ndarray | None = _DEFAULT_CAM_ROTATION, seed: int | None = None, progress: ProgressCallback | None = None, + return_2d: bool = False, ) -> Pose3DResult: """End-to-end prediction: 2D pose estimation → 3D lifting. @@ -683,6 +686,9 @@ def predict( progress : ProgressCallback or None Optional ``(current_step, total_steps)`` callback. Forwarded to the :meth:`pose_3d` step (per-frame reporting). + return_2d : bool + When ``True``, include the intermediate :class:`Pose2DResult` + in the returned :class:`Pose3DResult` as ``pose_2d``. Returns ------- @@ -690,13 +696,16 @@ def predict( Root-relative and world-coordinate 3D poses. """ result_2d = self.prepare_2d(source) - return self.pose_3d( + result_3d = self.pose_3d( result_2d.keypoints, result_2d.image_size, camera_rotation=camera_rotation, seed=seed, progress=progress, ) + if return_2d: + result_3d.pose_2d = result_2d + return result_3d @torch.no_grad() def prepare_2d(