# Source: DriverTrac/venv/lib/python3.12/site-packages/roboflow/models/video.py
# (241 lines, 6.5 KiB, Python)
import json
import time
from typing import Optional, Tuple
from urllib.parse import urljoin
import filetype
import requests
from roboflow.config import API_URL
from roboflow.models.inference import InferenceModel
# Roboflow-hosted model types accepted as the primary model for a video job.
SUPPORTED_ROBOFLOW_MODELS = ["object-detection", "classification", "instance-segmentation", "keypoint-detection"]

# Auxiliary hosted models that may run alongside the primary model.
# Each alias maps to the payload fragment appended to the job's "models" list.
SUPPORTED_ADDITIONAL_MODELS = {
    "clip": {
        "model_id": "clip",
        "model_version": "1",
        "inference_type": "clip-embed-image",
    },
    "gaze": {
        "model_id": "gaze",
        "model_version": "1",
        "inference_type": "gaze-detection",
    },
}

# MIME types accepted by is_valid_mime() / is_valid_video().
ACCEPTED_VIDEO_FORMATS = {
    "video/mp4",
    "video/x-msvideo",  # AVI
    "video/webm",
}
def is_valid_mime(filename):
    """Return True when the file's sniffed MIME type is an accepted video format."""
    detected = filetype.guess(filename)
    # filetype.guess returns None when the type cannot be determined.
    return detected is not None and detected.mime in ACCEPTED_VIDEO_FORMATS
def is_valid_video(filename):
    """Return True when the file at *filename* passes the video MIME-type check."""
    # The only validation currently performed is the MIME sniff.
    return is_valid_mime(filename)
class VideoInferenceModel(InferenceModel):
    """
    Run inference on an object detection model hosted on Roboflow or served through Roboflow Inference.
    """  # noqa: E501 // docs

    def __init__(
        self,
        api_key,
    ):
        """
        Create a VideoDetectionModel object through which you can run inference on videos.

        Args:
            api_key (str): Your API key (obtained via your workspace API settings page).
        """  # noqa: E501 // docs
        # NOTE(review): super().__init__() is intentionally not called, matching
        # the original; InferenceModel's initializer is not visible here — confirm.
        self.__api_key = api_key

    def predict(  # type: ignore[override]
        self,
        video_path: str,
        inference_type: str,
        fps: int = 5,
        additional_models: Optional[list] = None,
    ) -> Tuple[str, str]:
        """
        Start a hosted video-inference job for the given video.

        Args:
            video_path (str): path to the video you'd like to perform prediction on
            inference_type (str): type of the model to run (one of SUPPORTED_ROBOFLOW_MODELS)
            fps (int): frames per second to run inference (must be <= 120)
            additional_models (Optional[list]): aliases from SUPPORTED_ADDITIONAL_MODELS
                to run alongside the primary model

        Returns:
            A tuple of (job id, signed url)

        Raises:
            Exception: if fps exceeds 120, a model name is unsupported, or the
                video file is not a supported format.

        Example:
            >>> import roboflow
            >>> rf = roboflow.Roboflow(api_key="")
            >>> project = rf.workspace().project("PROJECT_ID")
            >>> model = project.version("1").model
            >>> prediction = model.predict("video.mp4", fps=5, inference_type="object-detection")
        """  # noqa: E501 // docs
        # Validate all inputs before any network round trip so bad calls fail fast.
        if fps > 120:
            raise Exception("FPS must be less than or equal to 120.")
        if additional_models is None:
            additional_models = []
        for model in additional_models:
            if model not in SUPPORTED_ADDITIONAL_MODELS:
                raise Exception(f"Model {model} is not supported for video inference.")
        if inference_type not in SUPPORTED_ROBOFLOW_MODELS:
            raise Exception(f"Model {inference_type} is not supported for video inference.")
        if not is_valid_video(video_path):
            raise Exception("Video path is not valid")

        # Request a signed URL to upload the video to.
        url = urljoin(API_URL, f"/video_upload_signed_url/?api_key={self.__api_key}")
        payload = json.dumps(
            {
                "file_name": video_path,
            }
        )
        headers = {"Content-Type": "application/json"}
        response = requests.request("POST", url, headers=headers, data=payload)
        signed_url = response.json()["signed_url"]

        # Fix: actually upload the video bytes to the signed URL. The original
        # code printed "Uploaded" without performing any upload, so the
        # inference job below pointed at an empty object.
        with open(video_path, "rb") as video_file:
            requests.put(
                signed_url,
                data=video_file,
                headers={"Content-Type": "application/octet-stream"},
            )
        print("Uploaded video to signed url: " + signed_url)

        # Kick off the inference job against the uploaded video.
        url = urljoin(API_URL, f"/videoinfer/?api_key={self.__api_key}")
        models = [
            {
                "model_id": self.dataset_id,
                "model_version": self.version,
                "inference_type": inference_type,
            }
        ]
        for model in additional_models:
            models.append(SUPPORTED_ADDITIONAL_MODELS[model])
        payload = json.dumps({"input_url": signed_url, "infer_fps": fps, "models": models})
        response = requests.request("POST", url, headers=headers, data=payload)
        job_id = response.json()["job_id"]
        self.job_id = job_id
        return job_id, signed_url

    def poll_for_results(self, job_id: Optional[str] = None) -> dict:
        """
        Poll the Roboflow API once to check whether video inference is complete.

        Args:
            job_id (Optional[str]): job to poll; defaults to the job started by
                the most recent call to predict().

        Returns:
            Inference results as a dict, or {} when the job is still running.

        Raises:
            Exception: if the polling request fails or the job itself failed.

        Example:
            >>> import roboflow
            >>> rf = roboflow.Roboflow(api_key="")
            >>> project = rf.workspace().project("PROJECT_ID")
            >>> model = project.version("1").model
            >>> prediction = model.predict("video.mp4")
            >>> results = model.poll_for_results()
        """  # noqa: E501 // docs
        if job_id is None:
            job_id = self.job_id
        # Fix: poll the requested job — the original ignored the job_id
        # argument and always interpolated self.job_id into the URL.
        url = urljoin(API_URL, f"/videoinfer/?api_key={self.__api_key}&job_id={job_id}")
        try:
            response = requests.get(url, headers={"Content-Type": "application/json"})
        except Exception as e:
            print(e)
            raise Exception("Error polling for results.")
        if not response.ok:
            raise Exception("Error polling for results.")
        data = response.json()
        # NOTE(review): success == 0 is treated as "finished" and success == 1
        # as "still running" (preserved from the original); counter-intuitive —
        # confirm against the video inference API docs.
        if data["success"] == 0:
            output_signed_url = data["output_signed_url"]
            inference_data = requests.get(output_signed_url, headers={"Content-Type": "application/json"})
            # frame_offset and model name are top-level keys
            return inference_data.json()
        elif data["success"] == 1:
            print("Job not complete yet. Check back in a minute.")
            return {}
        else:
            raise Exception("Job failed.")

    def poll_until_results(self, job_id: Optional[str] = None) -> dict:
        """
        Poll the Roboflow API (once per minute) until video inference is complete.

        When inference is complete, the results are returned.

        Args:
            job_id (Optional[str]): job to poll; defaults to the job started by
                the most recent call to predict(). (Fix: now optional, matching
                the docstring example below.)

        Returns:
            Inference results as a dict

        Example:
            >>> import roboflow
            >>> rf = roboflow.Roboflow(api_key="")
            >>> project = rf.workspace().project("PROJECT_ID")
            >>> model = project.version("1").model
            >>> prediction = model.predict("video.mp4")
            >>> results = model.poll_until_results()
        """  # noqa: E501 // docs
        if job_id is None:
            job_id = self.job_id
        attempts = 0
        while True:
            # Fix: forward job_id — the original always polled self.job_id.
            response = self.poll_for_results(job_id)
            attempts += 1
            if response != {}:
                return response
            print(f"({attempts * 60}s): Checking for inference results")
            time.sleep(60)