import re
import unittest

import responses
from requests.exceptions import HTTPError

from roboflow.config import SEMANTIC_SEGMENTATION_URL
from roboflow.models.semantic_segmentation import SemanticSegmentationModel
from roboflow.util.prediction import PredictionGroup

# Canned JSON body returned by the mocked semantic-segmentation endpoint.
# NOTE(review): the PNG in "segmentation_mask" declares 512x512 in its IHDR
# while "image" claims 800x600; the tests below do not assert on either, but
# confirm before reusing this fixture for dimension-sensitive checks.
MOCK_RESPONSE = {
    "segmentation_mask": (
        "iVBORw0KGgoAAAANSUhEUgAAAgAAAAIACAAAAADRE4smAAACjElEQVR4nO3bz"
        "XKbMBiGUanT+79ldVHXwSmmFmJGfcU5i8SZbDR8DzL4pxQAAAAAAAAAAAAAAA"
        "AAAAAAAAAAAAAAAAAAgKXUOnsFTGX+AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
        "AAAAAAAAAAAAAAMuosxewklpKm72GXj9mLyBf/etBEgGMqqVGTv5BADcngEF1"
        "51GSn7MXkO116HFXgMUOMCZ//gK4TuT8BTCkvfyKJICbE8CQ9vyRSgBDMm/9t"
        "gQwLHoDWCDhaWopj+nXkpuBHeBT9dtL/t9OndQzSQCf2j/Fa8mdfSlFAD3a3l"
        "+H20IAAXzszbN83fw3b/4COO9PEIFT3xDAeW2zJ6TeBHg7eEjLvgUsxQ4wrGX"
        "PXwDDoscvgHE1+ypQAGfU/U8CJpaQuObpvt4FeBy/9vIoigD6fR2z9nwZaPty"
        "UBQB9Ds6ZnEFuAa4OQF0O9w043ZUAdxcXLGT/eN4xV0C2AH6LDd/AVwpcP4Cu"
        "FDi/AVwocjrKQF0iTzJDwmgz3IFCKDTagUI4OYEcJ3IzUEAnSIv9Q/4VHCXd+"
        "NvsWGkrnum9xUE8hTQL3LQ7wjghJUKEMBlMrMQwBm7s868nMpc9X/iefB+fzc"
        "8cgtwGzjg8XWAyMFzOZspAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
        "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
        "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
        "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAC"
        "wil+AQDJrnsYcnwAAAABJRU5ErkJggg=="
    ),
    "class_map": {"0": "background", "1": "object"},
    "image": {"width": 800, "height": 600},
}


class TestSemanticSegmentation(unittest.TestCase):
    """Tests for ``SemanticSegmentationModel`` against a mocked HTTP endpoint."""

    api_key = "my-api-key"
    workspace = "roboflow"
    dataset_id = "test-123"
    version = "23"
    # Build the expected endpoint from the same config constant that
    # test_init_sets_attributes checks the model against, so the mocked URL
    # and the model's URL cannot drift apart.
    api_url = f"{SEMANTIC_SEGMENTATION_URL}/{dataset_id}/{version}"
    # Query parameters the model sends when predict() is called with defaults.
    _default_params = {"api_key": api_key, "confidence": "50"}
    version_id = f"{workspace}/{dataset_id}/{version}"

    def _assert_api_post(self, request, expected_params, *, expect_body):
        """Assert *request* is a POST to ``api_url`` carrying *expected_params*.

        Args:
            request: A request recorded by ``responses``.
            expected_params: Exact query-parameter dict the request must carry.
            expect_body: True for local images (uploaded in the body), False
                for hosted images (passed via the ``image`` query parameter).
        """
        self.assertEqual(request.method, "POST")
        # Escape the URL: its '.' characters are regex metacharacters and
        # would otherwise match any character.
        self.assertRegex(request.url, rf"^{re.escape(self.api_url)}")
        self.assertDictEqual(request.params, expected_params)
        if expect_body:
            self.assertIsNotNone(request.body)
        else:
            self.assertIsNone(request.body)

    def test_init_sets_attributes(self):
        """The constructor stores the version id and derives the API URL."""
        instance = SemanticSegmentationModel(self.api_key, self.version_id)

        self.assertEqual(instance.id, self.version_id)
        self.assertEqual(
            instance.api_url,
            f"{SEMANTIC_SEGMENTATION_URL}/{self.dataset_id}/{self.version}",
        )

    @responses.activate
    def test_predict_returns_prediction_group(self):
        """predict() wraps the API response in a PredictionGroup."""
        image_path = "tests/images/rabbit.JPG"
        instance = SemanticSegmentationModel(self.api_key, self.version_id)
        responses.add(responses.POST, self.api_url, json=MOCK_RESPONSE)

        group = instance.predict(image_path)

        self.assertIsInstance(group, PredictionGroup)

    @responses.activate
    def test_predict_with_local_image_request(self):
        """A local image is POSTed in the body with the default params."""
        image_path = "tests/images/rabbit.JPG"
        instance = SemanticSegmentationModel(self.api_key, self.version_id)
        responses.add(responses.POST, self.api_url, json=MOCK_RESPONSE)

        instance.predict(image_path)

        request = responses.calls[0].request
        self._assert_api_post(request, self._default_params, expect_body=True)

    @responses.activate
    def test_predict_with_hosted_image_request(self):
        """A hosted image is sent by URL in the query string, with no body."""
        image_path = "https://example.com/raccoon.JPG"
        expected_params = {
            **self._default_params,
            "image": image_path,
        }
        instance = SemanticSegmentationModel(self.api_key, self.version_id)
        # Mock the library validating that the URL is valid before sending to
        # the API; that HEAD is recorded first, so the POST is calls[1].
        responses.add(responses.HEAD, image_path)
        responses.add(responses.POST, self.api_url, json=MOCK_RESPONSE)

        instance.predict(image_path)

        request = responses.calls[1].request
        self._assert_api_post(request, expected_params, expect_body=False)

    @responses.activate
    def test_predict_with_confidence_request(self):
        """An explicit confidence overrides the default query parameter."""
        confidence = "100"
        image_path = "tests/images/rabbit.JPG"
        expected_params = {**self._default_params, "confidence": confidence}
        instance = SemanticSegmentationModel(self.api_key, self.version_id)
        responses.add(responses.POST, self.api_url, json=MOCK_RESPONSE)

        instance.predict(image_path, confidence=confidence)

        request = responses.calls[0].request
        self._assert_api_post(request, expected_params, expect_body=True)

    @responses.activate
    def test_predict_with_non_200_response_raises_http_error(self):
        """A non-2xx API response surfaces as requests' HTTPError."""
        image_path = "tests/images/rabbit.JPG"
        responses.add(responses.POST, self.api_url, status=403)
        instance = SemanticSegmentationModel(self.api_key, self.version_id)

        with self.assertRaises(HTTPError):
            instance.predict(image_path)