154 lines
5.1 KiB
Python
154 lines
5.1 KiB
Python
import unittest
|
|
|
|
import responses
|
|
from PIL import UnidentifiedImageError
|
|
from requests.exceptions import HTTPError
|
|
|
|
from roboflow.config import OBJECT_DETECTION_URL
|
|
from roboflow.models.object_detection import ObjectDetectionModel
|
|
from roboflow.util.prediction import PredictionGroup
|
|
|
|
MOCK_RESPONSE = {
|
|
"predictions": [
|
|
{
|
|
"x": 189.5,
|
|
"y": 100,
|
|
"width": 163,
|
|
"height": 186,
|
|
"class": "helmet",
|
|
"confidence": 0.544,
|
|
}
|
|
],
|
|
"image": {"width": 2048, "height": 1371},
|
|
}
|
|
|
|
|
|
class TestObjectDetection(unittest.TestCase):
|
|
api_key = "my-api-key"
|
|
workspace = "roboflow"
|
|
dataset_id = "test-123"
|
|
version = "23"
|
|
|
|
api_url = f"{OBJECT_DETECTION_URL}/{dataset_id}/{version}"
|
|
|
|
_default_params = {
|
|
"api_key": api_key,
|
|
"confidence": "40",
|
|
"format": "json",
|
|
"labels": "false",
|
|
"name": "YOUR_IMAGE.jpg",
|
|
"overlap": "30",
|
|
"stroke": "1",
|
|
}
|
|
|
|
def setUp(self):
|
|
super().setUp()
|
|
self.version_id = f"{self.workspace}/{self.dataset_id}/{self.version}"
|
|
|
|
def test_init_sets_attributes(self):
|
|
instance = ObjectDetectionModel(self.api_key, self.version_id, version=self.version)
|
|
|
|
self.assertEqual(instance.id, self.version_id)
|
|
# self.assertEqual(instance.api_url,
|
|
# f"{OBJECT_DETECTION_URL}/{self.dataset_id}/{self.version}")
|
|
|
|
@responses.activate
|
|
def test_predict_returns_prediction_group(self):
|
|
print(self.api_url)
|
|
image_path = "tests/images/rabbit.JPG"
|
|
instance = ObjectDetectionModel(self.api_key, self.version_id, version=self.version)
|
|
|
|
responses.add(responses.POST, self.api_url, json=MOCK_RESPONSE)
|
|
|
|
group = instance.predict(image_path)
|
|
|
|
self.assertIsInstance(group, PredictionGroup)
|
|
|
|
@responses.activate
|
|
def test_predict_with_local_image_request(self):
|
|
image_path = "tests/images/rabbit.JPG"
|
|
instance = ObjectDetectionModel(self.api_key, self.version_id, version=self.version)
|
|
|
|
responses.add(responses.POST, self.api_url, json=MOCK_RESPONSE)
|
|
|
|
instance.predict(image_path)
|
|
|
|
request = responses.calls[0].request
|
|
|
|
self.assertEqual(request.method, "POST")
|
|
self.assertRegex(request.url, rf"^{self.api_url}")
|
|
self.assertDictEqual(request.params, self._default_params)
|
|
self.assertIsNotNone(request.body)
|
|
|
|
@responses.activate
|
|
def test_predict_with_a_numpy_array_request(self):
|
|
import numpy as np
|
|
|
|
np_array = np.ones((100, 100, 1), dtype=np.uint8)
|
|
instance = ObjectDetectionModel(self.api_key, self.version_id, version=self.version)
|
|
|
|
responses.add(responses.POST, self.api_url, json=MOCK_RESPONSE)
|
|
|
|
instance.predict(np_array)
|
|
|
|
request = responses.calls[0].request
|
|
|
|
self.assertEqual(request.method, "POST")
|
|
self.assertRegex(request.url, rf"^{self.api_url}")
|
|
self.assertDictEqual(request.params, self._default_params)
|
|
self.assertIsNotNone(request.body)
|
|
|
|
def test_predict_with_local_wrong_image_request(self):
|
|
image_path = "tests/images/not_an_image.txt"
|
|
instance = ObjectDetectionModel(self.api_key, self.version_id, version=self.version)
|
|
self.assertRaises(UnidentifiedImageError, instance.predict, image_path)
|
|
|
|
@responses.activate
|
|
def test_predict_with_hosted_image_request(self):
|
|
image_path = "https://example.com/racoon.JPG"
|
|
expected_params = {
|
|
**self._default_params,
|
|
"image": image_path,
|
|
}
|
|
instance = ObjectDetectionModel(self.api_key, self.version_id, version=self.version)
|
|
|
|
# Mock the library validating that the URL is valid before sending to the API
|
|
responses.add(responses.POST, self.api_url, json=MOCK_RESPONSE)
|
|
|
|
instance.predict(image_path, hosted=True)
|
|
|
|
request = responses.calls[0].request
|
|
|
|
self.assertEqual(request.method, "POST")
|
|
self.assertRegex(request.url, rf"^{self.api_url}")
|
|
self.assertDictEqual(request.params, expected_params)
|
|
self.assertIsNone(request.body)
|
|
|
|
@responses.activate
|
|
def test_predict_with_confidence_request(self):
|
|
confidence = "100"
|
|
image_path = "tests/images/rabbit.JPG"
|
|
expected_params = {**self._default_params, "confidence": confidence}
|
|
instance = ObjectDetectionModel(self.api_key, self.version_id, version=self.version)
|
|
|
|
responses.add(responses.POST, self.api_url, json=MOCK_RESPONSE)
|
|
|
|
instance.predict(image_path, confidence=confidence)
|
|
|
|
request = responses.calls[0].request
|
|
|
|
self.assertEqual(request.method, "POST")
|
|
self.assertRegex(request.url, rf"^{self.api_url}")
|
|
self.assertDictEqual(request.params, expected_params)
|
|
self.assertIsNotNone(request.body)
|
|
|
|
@responses.activate
|
|
def test_predict_with_non_200_response_raises_http_error(self):
|
|
image_path = "tests/images/rabbit.JPG"
|
|
responses.add(responses.POST, self.api_url, status=403)
|
|
|
|
instance = ObjectDetectionModel(self.api_key, self.version_id, version=self.version)
|
|
|
|
with self.assertRaises(HTTPError):
|
|
instance.predict(image_path)
|