# DriverTrac/venv/lib/python3.12/site-packages/tests/models/test_object_detection.py
#
# 154 lines
# 5.1 KiB
# Python

import unittest
import responses
from PIL import UnidentifiedImageError
from requests.exceptions import HTTPError
from roboflow.config import OBJECT_DETECTION_URL
from roboflow.models.object_detection import ObjectDetectionModel
from roboflow.util.prediction import PredictionGroup
# Canned JSON body served by the mocked prediction endpoint: one "helmet"
# detection box plus the dimensions of the (pretend) source image.
MOCK_RESPONSE = dict(
    predictions=[
        {
            "x": 189.5,
            "y": 100,
            "width": 163,
            "height": 186,
            "class": "helmet",
            "confidence": 0.544,
        },
    ],
    image=dict(width=2048, height=1371),
)
class TestObjectDetection(unittest.TestCase):
    """Tests for ``ObjectDetectionModel`` with all HTTP traffic mocked by ``responses``."""

    api_key = "my-api-key"
    workspace = "roboflow"
    dataset_id = "test-123"
    version = "23"
    # Prediction endpoint the model is expected to POST to.
    api_url = f"{OBJECT_DETECTION_URL}/{dataset_id}/{version}"
    # Query parameters sent by predict() when no overrides are passed.
    _default_params = {
        "api_key": api_key,
        "confidence": "40",
        "format": "json",
        "labels": "false",
        "name": "YOUR_IMAGE.jpg",
        "overlap": "30",
        "stroke": "1",
    }
    # NOTE(review): relative to the current working directory, so the suite is
    # presumably meant to be run from the repository root — confirm.
    _local_image_path = "tests/images/rabbit.JPG"

    def setUp(self):
        super().setUp()
        self.version_id = f"{self.workspace}/{self.dataset_id}/{self.version}"

    def _make_model(self):
        """Build an ``ObjectDetectionModel`` from this test case's fixture identifiers."""
        return ObjectDetectionModel(self.api_key, self.version_id, version=self.version)

    def _assert_prediction_request(self, expected_params, *, has_body=True):
        """Check the first request captured by ``responses``.

        Args:
            expected_params: Exact query parameters the request must carry.
            has_body: True when image data should be uploaded in the request
                body; False for hosted images, which are sent by URL only.
        """
        request = responses.calls[0].request
        self.assertEqual(request.method, "POST")
        # Literal prefix check instead of assertRegex(rf"^{api_url}"): the URL
        # contains '.' characters that an unescaped pattern treats as wildcards.
        self.assertTrue(
            request.url.startswith(self.api_url),
            f"{request.url!r} does not start with {self.api_url!r}",
        )
        self.assertDictEqual(request.params, expected_params)
        if has_body:
            self.assertIsNotNone(request.body)
        else:
            self.assertIsNone(request.body)

    def test_init_sets_attributes(self):
        instance = self._make_model()
        self.assertEqual(instance.id, self.version_id)

    @responses.activate
    def test_predict_returns_prediction_group(self):
        instance = self._make_model()
        responses.add(responses.POST, self.api_url, json=MOCK_RESPONSE)

        group = instance.predict(self._local_image_path)

        self.assertIsInstance(group, PredictionGroup)

    @responses.activate
    def test_predict_with_local_image_request(self):
        instance = self._make_model()
        responses.add(responses.POST, self.api_url, json=MOCK_RESPONSE)

        instance.predict(self._local_image_path)

        self._assert_prediction_request(self._default_params)

    @responses.activate
    def test_predict_with_a_numpy_array_request(self):
        # Imported locally so numpy is only required by this test.
        import numpy as np

        np_array = np.ones((100, 100, 1), dtype=np.uint8)
        instance = self._make_model()
        responses.add(responses.POST, self.api_url, json=MOCK_RESPONSE)

        instance.predict(np_array)

        self._assert_prediction_request(self._default_params)

    # Activated (with nothing registered) so that any unexpected HTTP call
    # fails fast instead of reaching the real API.
    @responses.activate
    def test_predict_with_local_wrong_image_request(self):
        image_path = "tests/images/not_an_image.txt"
        instance = self._make_model()

        with self.assertRaises(UnidentifiedImageError):
            instance.predict(image_path)

    @responses.activate
    def test_predict_with_hosted_image_request(self):
        image_path = "https://example.com/racoon.JPG"
        # A hosted image is referenced by an ``image`` query parameter rather
        # than being uploaded in the request body.
        expected_params = {
            **self._default_params,
            "image": image_path,
        }
        instance = self._make_model()
        responses.add(responses.POST, self.api_url, json=MOCK_RESPONSE)

        instance.predict(image_path, hosted=True)

        self._assert_prediction_request(expected_params, has_body=False)

    @responses.activate
    def test_predict_with_confidence_request(self):
        confidence = "100"
        expected_params = {**self._default_params, "confidence": confidence}
        instance = self._make_model()
        responses.add(responses.POST, self.api_url, json=MOCK_RESPONSE)

        instance.predict(self._local_image_path, confidence=confidence)

        self._assert_prediction_request(expected_params)

    @responses.activate
    def test_predict_with_non_200_response_raises_http_error(self):
        responses.add(responses.POST, self.api_url, status=403)
        instance = self._make_model()

        with self.assertRaises(HTTPError):
            instance.predict(self._local_image_path)