417 lines
14 KiB
Python
417 lines
14 KiB
Python
import io
|
|
import json
|
|
import os
|
|
import time
|
|
import urllib
|
|
from typing import List, Optional, Tuple
|
|
from urllib.parse import urljoin
|
|
|
|
import requests
|
|
from PIL import Image
|
|
from requests_toolbelt.multipart.encoder import MultipartEncoder
|
|
from tqdm import tqdm
|
|
|
|
from roboflow.config import API_URL
|
|
from roboflow.util.image_utils import validate_image_path
|
|
from roboflow.util.prediction import PredictionGroup
|
|
|
|
# Values accepted for the ``prediction_type`` argument of
# ``InferenceModel.predict_video``.
SUPPORTED_ROBOFLOW_MODELS = ["batch-video"]

# Extra models that may be requested through ``additional_models`` in
# ``InferenceModel.predict_video``, keyed by the short name callers pass in.
# Each value is the model descriptor forwarded in the video inference request.
SUPPORTED_ADDITIONAL_MODELS = {
    "clip": {
        "model_id": "clip",
        "model_version": "1",
        "inference_type": "clip-embed-image",
    },
    "gaze": {
        "model_id": "gaze",
        "model_version": "1",
        "inference_type": "gaze-detection",
    },
}
|
|
|
|
|
|
class InferenceModel:
    """Client for running Roboflow inference on images and videos.

    ``predict`` reads ``self.api_url``, which this class never assigns; it must
    be provided on the instance (for example by a subclass) before ``predict``
    is called. ``predict_video`` selects the inference type from the name of
    the concrete class, so it only works on the subclasses listed in
    ``_VIDEO_INFERENCE_TYPES``.
    """

    # Concrete class name -> ``inference_type`` string sent to the video
    # inference endpoint.
    _VIDEO_INFERENCE_TYPES = {
        "ObjectDetectionModel": "object-detection",
        "ClassificationModel": "classification",
        "InstanceSegmentationModel": "instance-segmentation",
        "GazeModel": "gaze-detection",
        "CLIPModel": "clip-embed-image",
        "KeypointDetectionModel": "keypoint-detection",
    }

    # Image modes Pillow's JPEG writer accepts as-is. Anything else (RGBA, P,
    # LA, ...) is converted to RGB first, because saving it directly fails.
    _JPEG_WRITABLE_MODES = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")

    def __init__(
        self,
        api_key,
        version_id,
        colors=None,
        *args,
        **kwargs,
    ):
        """
        Create an InferenceModel object through which you can run inference.

        Args:
            api_key (str): private roboflow api key
            version_id (str): the ID of the dataset version to use for inference,
                either the literal "BASE_MODEL" or a "/"-separated string whose
                second and third segments are the dataset ID and the version
            colors (dict): optional colors forwarded to the prediction group
                built by ``predict``; defaults to an empty dict
        """

        self.__api_key = api_key
        self.id = version_id

        # "BASE_MODEL" carries no dataset/version information, so the two
        # attributes below are intentionally left unset in that case.
        if version_id != "BASE_MODEL":
            version_info = self.id.split("/")
            self.dataset_id = version_info[1]
            self.version = version_info[2]
        self.colors = {} if colors is None else colors

    @staticmethod
    def _jpeg_upload_kwargs(image):
        """Encode a PIL image as JPEG and wrap it for a multipart upload.

        Args:
            image (PIL.Image.Image): image to encode

        Returns:
            Tuple of (requests kwargs dict with ``data`` and ``headers``,
            dict with the image ``width`` and ``height`` as strings)
        """
        width, height = image.size
        image_dims = {"width": str(width), "height": str(height)}

        if image.mode not in InferenceModel._JPEG_WRITABLE_MODES:
            # JPEG has no alpha/palette support; saving such modes raises.
            image = image.convert("RGB")

        buffered = io.BytesIO()
        image.save(buffered, quality=90, format="JPEG")
        data = MultipartEncoder(fields={"file": ("imageToUpload", buffered.getvalue(), "image/jpeg")})
        return {"data": data, "headers": {"Content-Type": data.content_type}}, image_dims

    def __get_image_params(self, image_path):
        """
        Get parameters about an image (i.e. dimensions) for use in an inference request.

        Args:
            image_path (Union[str, np.ndarray]): path or http(s) URL to an image, or a numpy array

        Returns:
            Tuple of (dict of querystring params, dict of requests kwargs, dict of
            image dimensions). For hosted images the URL is passed as the ``image``
            querystring param and the dimensions are reported as "Undefined".

        Raises:
            Exception: whatever ``validate_image_path`` raises when it rejects the path
        """
        import numpy as np

        if isinstance(image_path, np.ndarray):
            request_kwargs, image_dims = self._jpeg_upload_kwargs(Image.fromarray(image_path))
            return {}, request_kwargs, image_dims

        validate_image_path(image_path)

        if urllib.parse.urlparse(image_path).scheme in ("http", "https"):
            # Hosted images are fetched server-side; nothing is uploaded.
            return {"image": image_path}, {}, {"width": "Undefined", "height": "Undefined"}

        # The context manager releases the file handle once the image is encoded.
        with Image.open(image_path) as image:
            request_kwargs, image_dims = self._jpeg_upload_kwargs(image)
        return {}, request_kwargs, image_dims

    def predict(self, image_path, prediction_type=None, **kwargs):
        """
        Infers detections based on image from a specified model and image path.

        Args:
            image_path (str): path to the image you'd like to perform prediction on
            prediction_type (str): type of prediction to perform
            **kwargs: Any additional kwargs will be turned into querystring params

        Returns:
            PredictionGroup Object

        Raises:
            Exception: Image path is not valid
            requests.HTTPError: the inference endpoint returned an error status

        Example:
            >>> import roboflow

            >>> rf = roboflow.Roboflow(api_key="")

            >>> project = rf.workspace().project("PROJECT_ID")

            >>> model = project.version("1").model

            >>> prediction = model.predict("YOUR_IMAGE.jpg")
        """
        params, request_kwargs, image_dims = self.__get_image_params(image_path)

        params["api_key"] = self.__api_key

        # Caller-supplied kwargs are applied last, so they override the defaults above.
        params.update(**kwargs)
        url = f"{self.api_url}?{urllib.parse.urlencode(params)}"  # type: ignore[attr-defined]
        response = requests.post(url, **request_kwargs)
        response.raise_for_status()

        return PredictionGroup.create_prediction_group(
            response.json(),
            image_path=image_path,
            prediction_type=prediction_type,
            image_dims=image_dims,
            colors=self.colors,
        )

    @staticmethod
    def _signed_url_expiry(signed_url):
        """Return the ``X-Goog-Expires`` query value of a signed URL, or None if absent."""
        query = urllib.parse.parse_qs(urllib.parse.urlparse(signed_url).query)
        values = query.get("X-Goog-Expires")
        return values[0] if values else None

    def _upload_video(self, video_path):
        """Request a signed upload URL for a local video and upload the file to it.

        Returns:
            The signed URL the video was uploaded to

        Raises:
            Exception: requesting the signed URL, opening the file, or the upload failed
        """
        url = urljoin(API_URL, "/video_upload_signed_url?api_key=" + self.__api_key)
        payload = json.dumps({"file_name": os.path.basename(video_path)})

        try:
            response = requests.request("POST", url, headers={"Content-Type": "application/json"}, data=payload)
        except Exception as e:
            raise Exception(f"Error uploading video: {e}") from e

        if not response.ok:
            raise Exception(f"Error uploading video: {response.text}")

        signed_url = response.json()["signed_url"]

        try:
            video_file = open(video_path, "rb")
        except Exception as e:
            raise Exception(f"Error reading video: {e}") from e

        # Hand the open file to requests so the video is streamed from disk
        # instead of being loaded into memory in one piece.
        with video_file:
            try:
                result = requests.put(
                    signed_url,
                    data=video_file,
                    headers={"Content-Type": "application/octet-stream"},
                )
            except Exception as e:
                raise Exception(f"There was an error uploading the video: {e}") from e

        if not result.ok:
            raise Exception(f"There was an error uploading the video: {result.text}")

        return signed_url

    def predict_video(
        self,
        video_path: str,
        fps: int = 5,
        additional_models: Optional[List[str]] = None,
        prediction_type: str = "batch-video",
    ) -> Tuple[str, str, Optional[str]]:
        """
        Starts a video inference job for a local video file or a hosted video URL.

        Local files are first uploaded to a signed URL; http(s) URLs are passed
        to the inference endpoint unchanged. The job ID is also stored on
        ``self.job_id`` and the resolved inference type on ``self.type``.

        Args:
            video_path (str): path or http(s) URL of the video you'd like to perform prediction on
            fps (int): frames per second to run inference, at most 120
            additional_models (List[str]): names from SUPPORTED_ADDITIONAL_MODELS to run as well
            prediction_type (str): type of the model to run, must be in SUPPORTED_ROBOFLOW_MODELS

        Returns:
            Tuple of (job id, signed or hosted video url, signed url expiry).
            The expiry is None for hosted videos or when the signed URL has no
            X-Goog-Expires parameter.

        Raises:
            Exception: an argument is not supported, the model class has no video
                inference type, or one of the HTTP requests failed

        Example:
            >>> import roboflow

            >>> rf = roboflow.Roboflow(api_key="")

            >>> project = rf.workspace().project("PROJECT_ID")

            >>> model = project.version("1").model

            >>> job_id, signed_url, signed_url_expires = model.predict_video(
            ...     "video.mp4", fps=5, prediction_type="batch-video"
            ... )
        """
        if fps > 120:
            raise Exception("FPS must be less than or equal to 120.")

        if additional_models is None:
            additional_models = []

        for model in additional_models:
            if model not in SUPPORTED_ADDITIONAL_MODELS:
                raise Exception(f"Model {model} is not supported for video inference.")

        if prediction_type not in SUPPORTED_ROBOFLOW_MODELS:
            raise Exception(f"{prediction_type} is not supported for video inference.")

        model_class = self.__class__.__name__
        if model_class not in self._VIDEO_INFERENCE_TYPES:
            raise Exception("Model type not supported for video inference.")
        self.type = self._VIDEO_INFERENCE_TYPES[model_class]

        signed_url_expires = None
        if video_path.startswith(("http://", "https://")):
            signed_url = video_path
        else:
            signed_url = self._upload_video(video_path)
            signed_url_expires = self._signed_url_expiry(signed_url)

        if model_class in ("CLIPModel", "GazeModel"):
            # CLIP and gaze are not dataset versions; use their fixed descriptors.
            spec = SUPPORTED_ADDITIONAL_MODELS["clip" if model_class == "CLIPModel" else "gaze"]
            primary_model = {
                "model_id": spec["model_id"],
                "model_version": spec["model_version"],
                "inference_type": spec["inference_type"],
            }
        else:
            primary_model = {
                "model_id": self.dataset_id,
                "model_version": self.version,
                "inference_type": self.type,
            }

        models = [primary_model]
        models.extend(SUPPORTED_ADDITIONAL_MODELS[name] for name in additional_models)

        url = urljoin(API_URL, "/videoinfer/?api_key=" + self.__api_key)
        payload = json.dumps({"input_url": signed_url, "infer_fps": fps, "models": models})
        headers = {"Content-Type": "application/json"}

        try:
            response = requests.request("POST", url, headers=headers, data=payload)
        except Exception as e:
            raise Exception(f"Error starting video inference: {e}") from e

        if not response.ok:
            raise Exception(f"Error starting video inference: {response.text}")

        job_id = response.json()["job_id"]
        self.job_id = job_id

        return job_id, signed_url, signed_url_expires

    def poll_for_video_results(self, job_id: Optional[str] = None) -> dict:
        """
        Polls the Roboflow API once to check if video inference is complete.

        Args:
            job_id (str): job to check; defaults to ``self.job_id`` as set by ``predict_video``

        Returns:
            An empty dict when the response has no status or the job is still
            running (status 1), the raw status payload when the status is greater
            than 1 (error), otherwise the inference results as a dict

        Raises:
            Exception: the status request or the results download failed

        Example:
            >>> import roboflow

            >>> rf = roboflow.Roboflow(api_key="")

            >>> project = rf.workspace().project("PROJECT_ID")

            >>> model = project.version("1").model

            >>> job_id, signed_url, expires = model.predict_video("video.mp4")

            >>> results = model.poll_for_video_results()
        """

        if job_id is None:
            job_id = self.job_id

        url = urljoin(API_URL, "/videoinfer/?api_key=" + self.__api_key + "&job_id=" + job_id)
        try:
            response = requests.get(url, headers={"Content-Type": "application/json"})
        except Exception as e:
            raise Exception(f"Error getting video inference results: {e}") from e

        if not response.ok:
            raise Exception(f"Error getting video inference results: {response.text}")

        data = response.json()
        if "status" not in data:
            return {}  # No status available

        status = data["status"]
        if status > 1:
            return data  # Error
        if status == 1:
            return {}  # Still running

        # Done: the results live behind a separate signed URL.
        inference_data = requests.get(data["output_signed_url"], headers={"Content-Type": "application/json"})
        if not inference_data.ok:
            # Fail loudly instead of trying to JSON-decode an error page.
            raise Exception(f"Error getting video inference results: {inference_data.text}")

        # frame_offset and model name are top-level keys
        return inference_data.json()

    def poll_until_video_results(self, job_id: Optional[str] = None) -> dict:
        """
        Polls the Roboflow API every 60 seconds until video inference is complete.

        When ``poll_for_video_results`` returns a non-empty dict (results or an
        error payload), that dict is returned. This call blocks with no timeout.

        Args:
            job_id (str): job to wait for; defaults to ``self.job_id`` as set by ``predict_video``

        Returns:
            Inference results as a dict

        Example:
            >>> import roboflow

            >>> rf = roboflow.Roboflow(api_key="")

            >>> project = rf.workspace().project("PROJECT_ID")

            >>> model = project.version("1").model

            >>> job_id, signed_url, expires = model.predict_video("video.mp4")

            >>> results = model.poll_until_video_results(job_id)
        """
        if job_id is None:
            job_id = self.job_id

        attempts = 0
        print(f"Checking for video inference results for job {job_id} every 60s")
        while True:
            time.sleep(60)
            # Count the sleep that just finished so the reported elapsed time is accurate.
            attempts += 1
            print(f"({attempts * 60}s): Checking for inference results")
            response = self.poll_for_video_results(job_id)

            if response != {}:
                return response

    def download(self, format="pt", location="."):
        """
        Download the weights associated with a model.

        The weights are written to ``weights.pt`` inside ``location``.

        Args:
            format (str): The format of the output.
                - 'pt': returns a PyTorch weights file
            location (str): The location to save the weights file to

        Raises:
            Exception: ``format`` is not supported
            requests.HTTPError: the weights URL lookup or the weights download failed
        """
        supported_formats = ["pt"]
        if format not in supported_formats:
            raise Exception(f"Unsupported format {format}. Must be one of {supported_formats}")

        # The third segment is ignored here; the URL below uses self.version.
        workspace, project, _ = self.id.rsplit("/")

        # get pt url
        pt_api_url = f"{API_URL}/{workspace}/{project}/{self.version}/ptFile"

        r = requests.get(pt_api_url, params={"api_key": self.__api_key})
        r.raise_for_status()

        pt_weights_url = r.json()["weightsUrl"]
        weights_path = os.path.join(location, "weights.pt")

        with requests.get(pt_weights_url, stream=True) as response:
            # Check before opening the file so an error body is never saved as weights.
            response.raise_for_status()

            # The header is absent for chunked responses; tqdm accepts total=None.
            content_length = response.headers.get("content-length")
            total_chunks = None if content_length is None else int(content_length) // 1024 + 1

            # write the weights file to the desired location
            with open(weights_path, "wb") as f:
                for chunk in tqdm(
                    response.iter_content(chunk_size=1024),
                    desc=f"Downloading weights to {location}/weights.pt",
                    total=total_chunks,
                ):
                    if chunk:
                        f.write(chunk)

        return
|