DriverTrac/venv/lib/python3.12/site-packages/tests/models/test_keypoint_detection.py

71 lines
2.5 KiB
Python

import json
import os
import unittest
from pathlib import Path
import responses
from dotenv import load_dotenv
from roboflow.models.keypoint_detection import KeypointDetectionModel
from roboflow.util.prediction import PredictionGroup
load_dotenv(Path("../../.env"))
with open(Path("tests/annotations/keypoint-detection-annotations/MM2A_46_R_T_predictions.json")) as f:
MOCK_RESPONSE = json.load(f)
class TestKeypointDetection(unittest.TestCase):
api_key = os.getenv("ROBOFLOW_API_KEY", "test-api-key")
workspace = os.getenv("WORKSPACE_ID")
dataset_id = os.getenv("PROJECT_NAME")
version = "1"
api_url = f"https://detect.roboflow.com/{dataset_id}/{version}"
_default_params = {"api_key": api_key, "confidence": "40", "name": "YOUR_IMAGE.jpg"}
def setUp(self):
super().setUp()
self.version_id = f"{self.workspace}/{self.dataset_id}/{self.version}"
def test_init_sets_attributes(self):
instance = KeypointDetectionModel(self.api_key, self.version_id, version=self.version)
self.assertEqual(instance.id, self.version_id)
self.assertEqual(instance.version, self.version)
self.assertEqual(instance.base_url, "https://detect.roboflow.com/")
@responses.activate
def test_predict_local_image(self):
instance = KeypointDetectionModel(self.api_key, self.version_id, version=self.version)
responses.add(responses.POST, self.api_url, json=MOCK_RESPONSE, status=200)
result = instance.predict("tests/images/MM2A_46_R_T.png")
self.assertIsInstance(result, PredictionGroup)
self.assertEqual(len(result.predictions), 1)
@responses.activate
def test_predict_with_confidence(self):
instance = KeypointDetectionModel(self.api_key, self.version_id, version=self.version)
responses.add(responses.POST, self.api_url, json=MOCK_RESPONSE, status=200)
result = instance.predict("tests/images/MM2A_46_R_T.png", confidence=30)
self.assertIsInstance(result, PredictionGroup)
request = responses.calls[0].request
self.assertEqual(request.params["confidence"], "30")
@responses.activate
def test_predict_error_response(self):
instance = KeypointDetectionModel(self.api_key, self.version_id, version=self.version)
responses.add(responses.POST, self.api_url, json={"error": "Invalid API key"}, status=401)
with self.assertRaises(Exception):
instance.predict("tests/images/MM2A_46_R_T.png")