DriverTrac/venv/lib/python3.12/site-packages/roboflow/models/classification.py

187 lines
5.7 KiB
Python

import base64
import io
import json
import os
import urllib
from typing import Optional
import requests
from PIL import Image
from roboflow.config import CLASSIFICATION_MODEL
from roboflow.models.inference import InferenceModel
from roboflow.util.image_utils import check_image_url
from roboflow.util.prediction import PredictionGroup
class ClassificationModel(InferenceModel):
"""
Run inference on a classification model hosted on Roboflow or served through
Roboflow Inference.
"""
def __init__(
self,
api_key: str,
id: str,
name: Optional[str] = None,
version: Optional[str] = None,
local: Optional[str] = None,
colors: Optional[dict] = None,
preprocessing: Optional[dict] = None,
):
"""
Create a ClassificationModel object through which you can run inference.
Args:
api_key (str): private roboflow api key
id (str): the workspace/project id
name (str): is the name of the project
version (str): version number
local (str): localhost address and port if pointing towards local inference engine
colors (dict): colors to use for the image
preprocessing (dict): preprocessing to use for the image
Returns:
ClassificationModel Object
"""
# Instantiate different API URL parameters
super().__init__(api_key, id, version=version)
self.__api_key = api_key
self.id = id
self.name = name
self.version = version
self.base_url = "https://classify.roboflow.com/"
if self.name is not None and version is not None:
self.__generate_url()
self.colors = {} if colors is None else colors
self.preprocessing = {} if preprocessing is None else preprocessing
if local:
print(f"initalizing local classification model hosted at : {local}")
self.base_url = local
def predict(self, image_path, hosted=False): # type: ignore[override]
"""
Run inference on an image.
Args:
image_path (str): path to the image you'd like to perform prediction on
hosted (bool): whether the image you're providing is hosted on Roboflow
Returns:
PredictionGroup Object
Example:
>>> import roboflow
>>> rf = roboflow.Roboflow(api_key="")
>>> project = rf.workspace().project("PROJECT_ID")
>>> model = project.version("1").model
>>> prediction = model.predict("YOUR_IMAGE.jpg")
"""
self.__generate_url()
self.__exception_check(image_path_check=image_path)
# If image is local image
if not hosted:
# Open Image in RGB Format
image = Image.open(image_path).convert("RGB")
# Create buffer
buffered = io.BytesIO()
image.save(buffered, quality=90, format="JPEG")
img_dims = image.size
# Base64 encode image
img_str = base64.b64encode(buffered.getvalue())
img_str = img_str.decode("ascii")
# Post to API and return response
resp = requests.post(
self.api_url,
data=img_str,
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
else:
# Create API URL for hosted image (slightly different)
self.api_url += "&image=" + urllib.parse.quote_plus(image_path)
# POST to the API
resp = requests.post(self.api_url)
img_dims = {"width": "0", "height": "0"}
if resp.status_code != 200:
raise Exception(resp.text)
return PredictionGroup.create_prediction_group(
resp.json(),
image_dims=img_dims,
image_path=image_path,
prediction_type=CLASSIFICATION_MODEL,
colors=self.colors,
)
def load_model(self, name, version):
"""
Load a model.
Args:
name (str): is the name of the model you'd like to load
version (int): version number
"""
# Load model based on user defined characteristics
self.name = name
self.version = version
self.__generate_url()
def __generate_url(self):
"""
Generate a Roboflow API URL on which to run inference.
Returns:
url (str): the url on which to run inference
"""
# Generates URL based on all parameters
splitted = self.id.rsplit("/")
without_workspace = splitted[1]
version = self.version
if not version and len(splitted) > 2:
version = splitted[2]
self.api_url = "".join(
[
self.base_url + without_workspace + "/" + str(version),
"?api_key=" + self.__api_key,
"&name=YOUR_IMAGE.jpg",
]
)
def __exception_check(self, image_path_check=None):
"""
Check to see if an image exists.
Args:
image_path_check (str): path to the image to check
Raises:
Exception: if image does not exist
"""
# Checks if image exists
if image_path_check is not None:
if not os.path.exists(image_path_check) and not check_image_url(image_path_check):
raise Exception("Image does not exist at " + image_path_check + "!")
def __str__(self):
"""
String representation of classification object
"""
json_value = {
"name": self.name,
"version": self.version,
"base_url": self.base_url,
}
return json.dumps(json_value, indent=2)