71 lines
2.5 KiB
Python
71 lines
2.5 KiB
Python
import json
|
|
import os
|
|
import unittest
|
|
from pathlib import Path
|
|
|
|
import responses
|
|
from dotenv import load_dotenv
|
|
|
|
from roboflow.models.keypoint_detection import KeypointDetectionModel
|
|
from roboflow.util.prediction import PredictionGroup
|
|
|
|
# Load environment variables (e.g. ROBOFLOW_API_KEY, WORKSPACE_ID, PROJECT_NAME, which the
# test class reads) from a .env file.
# NOTE(review): this relative path is resolved against the current working directory, but
# the fixture paths used in this module ("tests/...") assume the working directory is the
# repository root — in that case "../../.env" points two levels *above* the repo. Confirm
# the intended location of the .env file.
load_dotenv(Path("../../.env"))

# Canned prediction payload served by the mocked inference endpoint in the tests.
with open(
    Path("tests/annotations/keypoint-detection-annotations/MM2A_46_R_T_predictions.json"),
    encoding="utf-8",  # JSON is UTF-8 (RFC 8259); don't depend on the platform locale encoding
) as f:
    MOCK_RESPONSE = json.load(f)
|
|
|
|
|
|
class TestKeypointDetection(unittest.TestCase):
    """Tests for ``KeypointDetectionModel`` with the hosted inference endpoint mocked by ``responses``."""

    # Identifiers come from the environment when set; the fallbacks keep the suite runnable
    # without configuration (previously an unset variable produced a "None" path segment).
    api_key = os.getenv("ROBOFLOW_API_KEY", "test-api-key")
    workspace = os.getenv("WORKSPACE_ID", "test-workspace")
    dataset_id = os.getenv("PROJECT_NAME", "test-dataset")
    version = "1"

    # Endpoint the mocks are registered on (registered without a query string).
    api_url = f"https://detect.roboflow.com/{dataset_id}/{version}"

    # NOTE(review): not referenced by any test in this class — confirm whether it is still needed.
    _default_params = {"api_key": api_key, "confidence": "40", "name": "YOUR_IMAGE.jpg"}

    # Local fixture image used by the predict tests (relative to the working directory).
    image_path = "tests/images/MM2A_46_R_T.png"

    def setUp(self):
        super().setUp()
        # Model id in "<workspace>/<project>/<version>" form.
        self.version_id = f"{self.workspace}/{self.dataset_id}/{self.version}"

    def _make_model(self):
        """Build the model under test from this class's credentials and ids."""
        return KeypointDetectionModel(self.api_key, self.version_id, version=self.version)

    def test_init_sets_attributes(self):
        """The constructor stores the id, version and hosted base URL."""
        instance = self._make_model()

        self.assertEqual(instance.id, self.version_id)
        self.assertEqual(instance.version, self.version)
        self.assertEqual(instance.base_url, "https://detect.roboflow.com/")

    @responses.activate
    def test_predict_local_image(self):
        """Predicting on a local file returns a PredictionGroup built from the mocked payload."""
        instance = self._make_model()
        responses.add(responses.POST, self.api_url, json=MOCK_RESPONSE, status=200)

        result = instance.predict(self.image_path)

        self.assertIsInstance(result, PredictionGroup)
        self.assertEqual(len(result.predictions), 1)

    @responses.activate
    def test_predict_with_confidence(self):
        """A custom confidence threshold is forwarded in the request's query parameters."""
        instance = self._make_model()
        responses.add(responses.POST, self.api_url, json=MOCK_RESPONSE, status=200)

        result = instance.predict(self.image_path, confidence=30)

        self.assertIsInstance(result, PredictionGroup)
        request = responses.calls[0].request
        self.assertEqual(request.params["confidence"], "30")

    @responses.activate
    def test_predict_error_response(self):
        """A non-200 response from the API makes predict() raise."""
        instance = self._make_model()
        responses.add(responses.POST, self.api_url, json={"error": "Invalid API key"}, status=401)

        with self.assertRaises(Exception):
            instance.predict(self.image_path)

        # assertRaises(Exception) alone would also pass on the ConnectionError that
        # ``responses`` raises for an unmatched URL (or on a missing fixture image), so
        # confirm the failure really came from the mocked 401.
        self.assertEqual(len(responses.calls), 1)
        self.assertEqual(getattr(responses.calls[0].response, "status_code", None), 401)