Coretex
instance_extractor.py
1 # Copyright (C) 2023 Coretex LLC
2 
3 # This file is part of Coretex.ai
4 
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License as
7 # published by the Free Software Foundation, either version 3 of the
8 # License, or (at your option) any later version.
9 
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU Affero General Public License for more details.
14 
15 # You should have received a copy of the GNU Affero General Public License
16 # along with this program. If not, see <https://www.gnu.org/licenses/>.
17 
18 from typing import Any, Optional, List, Dict
19 
20 import os
21 import xml.etree.ElementTree as ET
22 
23 from PIL import Image
24 from PIL.Image import Image as PILImage
25 from skimage import measure
26 from shapely.geometry import Polygon
27 
28 import numpy as np
29 
30 from .shared import getTag, getBoxes
31 from ....annotation import CoretexSegmentationInstance, BBox
32 from ....dataset import ImageDataset
33 
34 
35 ContourPoints = List[List[int]]
36 SegmentationPolygon = List[ContourPoints]
37 
38 
class ObjectCandidate:

    """
    Represents the groundtruth object for contour matching
    """

    def __init__(self, object: ET.Element):
        # NOTE: the parameter name shadows the "object" builtin, it is kept
        # unchanged so that existing keyword callers continue to work

        # Annotated XML element of the object
        self.object = object
        # Set to True once a contour has been assigned to this object
        self.matched = False
48 
49 
class ContourCandidate:

    """
    Represents the extracted contour from segmentation mask
    """

    def __init__(self, contours: ContourPoints, iou: Optional[float] = None):
        # Contours extracted from a single submask
        self.contours = contours
        # Overlap with the most recently compared object box, 0.0 when not provided
        self.iou = 0.0 if iou is None else iou
        # Set to True once this contour has been assigned to an object
        self.matched = False
63 
64 
class InstanceExtractor:

    """
    Creates CoretexSegmentationInstance objects out of annotated XML objects
    and, for segmented images, out of the contours found in the
    color-coded segmentation mask
    """

    def __init__(self, dataset: ImageDataset) -> None:
        # Dataset is used for resolving class names into classes
        self.__dataset = dataset

    def createSubmasks(self, maskImage: PILImage) -> Dict[str, Any]:
        """
        Creates a binary submask for every distinct color of the segmentation mask

        Parameters
        ----------
        maskImage : Image
            Segmentation mask

        Returns
        -------
        Dict[str, Any] -> Dictionary whose keys are str(RGB tuple) and whose
        values are 1-bit images padded by one pixel on every side

        Raises
        ------
        TypeError -> if the pixel value is not a tuple
        ValueError -> if the pixel has less than 3 channels
        """

        width, height = maskImage.size

        # Sub-masks indexed by the string representation of the RGB color
        subMasks: Dict[str, Any] = {}
        for x in range(width):
            for y in range(height):
                pixel = maskImage.getpixel((x, y))

                # Ensure pixel is a tuple and has at least 3 elements (RGB)
                if not isinstance(pixel, tuple):
                    raise TypeError(f"Expected \"tuple\" recieved \"{type(pixel)}\"")

                if len(pixel) < 3:
                    raise ValueError(f"Expected pixel to has at least 3 channels (RGB).")

                pixelStr = str(pixel[:3])

                subMask = subMasks.get(pixelStr)
                if subMask is None:
                    # One pixel of padding on each side of the submask
                    subMask = Image.new('1', (width + 2, height + 2))
                    subMasks[pixelStr] = subMask

                # Set the pixel value to 1 (default is 0), accounting for padding
                subMask.putpixel((x + 1, y + 1), 1)

        return subMasks

    def reshapeContour(self, candidate: ContourCandidate) -> ContourPoints:
        """
        Converts flat contours into a single list of points

        Parameters
        ----------
        candidate : ContourCandidate
            candidate whose contours are flat lists [x0, y0, x1, y1, ...]

        Returns
        -------
        ContourPoints -> List of [x, y] points of all the contours
        """

        listOfPoints: ContourPoints = []

        for segmentContour in candidate.contours:
            # Values are paired in order, a trailing unpaired value is dropped
            listOfPoints.extend(
                [x, y] for x, y in zip(segmentContour[0::2], segmentContour[1::2])
            )

        return listOfPoints

    def calculateIoU(self, boxA: Dict[str, float], boxB: List[float]) -> float:
        """
        Calculates intersection over union of object box and contour box

        Parameters
        ----------
        boxA : Dict[str, float]
            annotated object box with keys "top_left_x", "top_left_y",
            "width" and "height"
        boxB : List[float]
            box extracted from contour in format [xmin, ymin, xmax, ymax]

        Returns
        -------
        float -> Intersection area divided by the union area of the boxes
        """

        xmin = boxA['top_left_x']
        ymin = boxA['top_left_y']
        xmax = xmin + boxA['width']
        ymax = ymin + boxA['height']

        # Corners of the intersection rectangle
        xA = max(xmin, boxB[0])
        yA = max(ymin, boxB[1])
        xB = min(xmax, boxB[2])
        yB = min(ymax, boxB[3])

        # Area of the intersection rectangle, 0 if boxes do not overlap
        interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)

        # Areas of both boxes (+ 1 because the bounds are inclusive)
        boxAArea = (xmax - xmin + 1) * (ymax - ymin + 1)
        boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)

        return interArea / float(boxAArea + boxBArea - interArea)

    def boxesOverlap(self, bbox: Dict[str, Any], listOfPoints: ContourPoints) -> float:
        """
        Extracts bounding box from contour points and calculates its IoU
        with the annotated box

        Parameters
        ----------
        bbox : Dict[str, Any]
            annotated object boxes
        listOfPoints : ContourPoints
            contour points

        Returns
        -------
        float -> Calculated overlap of boxes area, 0.0 if there are no points
        """

        # A contour without points has no bounding box and overlaps nothing
        if len(listOfPoints) == 0:
            return 0.0

        xValues = [point[0] for point in listOfPoints]
        yValues = [point[1] for point in listOfPoints]

        contourBox = [min(xValues), min(yValues), max(xValues), max(yValues)]
        return self.calculateIoU(bbox, contourBox)

    def matchContour(self, bbox: Dict[str, Any], objectCandidate: ObjectCandidate, contourCandidates: List[ContourCandidate]) -> ContourPoints:
        """
        Matches object and contour with max area of overlap

        Parameters
        ----------
        bbox : Dict[str, Any]
            annotated object boxes
        objectCandidate : ObjectCandidate
            groundtruth object
        contourCandidates : List[ContourCandidate]
            list of ContourCandidate objects

        Returns
        -------
        ContourPoints -> The corresponding contour which matches given object,
        empty list if there is no unmatched contour left
        """

        bestCandidate: Optional[ContourCandidate] = None

        for candidate in contourCandidates:
            if candidate.matched:
                continue

            listOfPoints = self.reshapeContour(candidate)
            candidate.iou = self.boxesOverlap(bbox, listOfPoints)

            # Strict comparison keeps the first candidate in case of a tie
            if bestCandidate is None or bestCandidate.iou < candidate.iou:
                bestCandidate = candidate

        if bestCandidate is None:
            # Nothing left to match, neither candidate is marked as matched
            return []

        objectCandidate.matched = True
        bestCandidate.matched = True

        return bestCandidate.contours

    def extractSubmaskContours(self, subMask: PILImage) -> ContourPoints:
        """
        Extracts contours from submask image

        Parameters
        ----------
        subMask : Image
            binary image, padded by one pixel on every side

        Returns
        -------
        ContourPoints -> List of contours, each one flattened
        into [x0, y0, x1, y1, ...]
        """

        subMaskArray = np.asarray(subMask)
        contours = measure.find_contours(subMaskArray, 0.5)

        segmentations: ContourPoints = []
        for contour in contours:
            # Less than 3 points can not form a polygon (line or point)
            if len(contour) < 3:
                continue

            # Swap (row, col) into (x, y) and remove the padding offset
            points = contour[:, ::-1] - 1

            poly = Polygon(points)

            if poly.geom_type == 'MultiPolygon':
                # If MultiPolygon, take the smallest convex Polygon containing all the points in the object
                poly = poly.convex_hull

            # Ignore if still not a Polygon (could be a line or point)
            if poly.geom_type == 'Polygon':
                segmentation = np.array(poly.exterior.coords).ravel().tolist()
                segmentations.append(segmentation)

        return segmentations

    def getSegmentationInstance(
        self,
        objectCandidate: ObjectCandidate,
        contourCandidates: Optional[List[ContourCandidate]]=None
    ) -> Optional[CoretexSegmentationInstance]:
        """
        Creates segmentation instance for the annotated object

        Parameters
        ----------
        objectCandidate : ObjectCandidate
            groundtruth object
        contourCandidates : Optional[List[ContourCandidate]]
            contours extracted from the segmentation mask, if None
            the polygon of the bounding box is used

        Returns
        -------
        Optional[CoretexSegmentationInstance] -> The created instance, None if
        the object has no name, no matching dataset class or no valid "bndbox"
        """

        label = getTag(objectCandidate.object, "name")
        if label is None:
            return None

        clazz = self.__dataset.classByName(label)
        if clazz is None:
            return None

        bndbox = objectCandidate.object.find('bndbox')
        if bndbox is None:
            return None

        boxes = getBoxes(bndbox)
        if boxes is None:
            return None

        polygon: ContourPoints = []
        if contourCandidates is not None:
            polygon = self.matchContour(boxes, objectCandidate, contourCandidates)

        bbox = BBox.decode(boxes)

        # Fallback to the bounding box if no contour was matched
        if len(polygon) == 0:
            polygon = [bbox.polygon]

        return CoretexSegmentationInstance.create(clazz.classIds[0], bbox, polygon)

    def __extractNonSegmentedInstances(self, objectCandidates: List[ObjectCandidate]) -> List[CoretexSegmentationInstance]:
        """
        Creates instances from the annotated boxes only

        Parameters
        ----------
        objectCandidates : List[ObjectCandidate]
            groundtruth objects

        Returns
        -------
        List[CoretexSegmentationInstance] -> Instances of the objects
        for which an instance could be created
        """

        coretexInstances: List[CoretexSegmentationInstance] = []

        for objectCandidate in objectCandidates:
            instance = self.getSegmentationInstance(objectCandidate)
            if instance is not None:
                coretexInstances.append(instance)

        return coretexInstances

    def __extractSegmentedInstances(
        self,
        segmentationPath: str,
        filename: str,
        objectCandidates: List[ObjectCandidate]
    ) -> List[CoretexSegmentationInstance]:
        """
        Creates instances whose polygons are extracted from the segmentation mask

        Parameters
        ----------
        segmentationPath : str
            directory which contains the segmentation mask
        filename : str
            name of the segmentation mask file
        objectCandidates : List[ObjectCandidate]
            groundtruth objects

        Returns
        -------
        List[CoretexSegmentationInstance] -> Instances of the objects
        for which an instance could be created
        """

        # Context manager releases the file handle, the converted copy is
        # fully loaded so it stays usable after the file is closed
        with Image.open(os.path.join(segmentationPath, filename)) as imageFile:
            maskImage = imageFile.convert("RGB")

        subMasks = self.createSubmasks(maskImage)

        contourCandidates: List[ContourCandidate] = []
        for color, subMask in subMasks.items():
            # Ignore border contours
            if color == '(224, 224, 192)':
                continue

            annotation = self.extractSubmaskContours(subMask)
            contourCandidates.append(ContourCandidate(annotation))

        coretexInstances: List[CoretexSegmentationInstance] = []

        for objectCandidate in objectCandidates:
            if objectCandidate.matched:
                continue

            instance = self.getSegmentationInstance(objectCandidate, contourCandidates)
            if instance is not None:
                coretexInstances.append(instance)

        return coretexInstances

    def extractInstances(
        self,
        root: ET.Element,
        filename: str,
        segmentationPath: str
    ) -> List[CoretexSegmentationInstance]:
        """
        Extracts polygons from segmentation masks and creates instances

        Parameters
        ----------
        root : ET.Element
            root element of the annotation XML, its "object" children
            are the annotated objects
        filename : str
            name of the segmentation mask file
        segmentationPath : str
            directory which contains the segmentation mask

        Returns
        -------
        List[CoretexSegmentationInstance] -> Extracted instances

        Raises
        ------
        RuntimeError -> if the "segmented" tag is missing
        """

        objectCandidates: List[ObjectCandidate] = [
            ObjectCandidate(obj) for obj in root.findall("object")
        ]

        segmented = getTag(root, "segmented")
        if segmented is None:
            # TODO: Raise error or fallback to extracting non-segmented instances?
            raise RuntimeError(">> [Coretex] (segmented) XML tag missing")

        if bool(int(segmented)):
            return self.__extractSegmentedInstances(segmentationPath, filename, objectCandidates)

        return self.__extractNonSegmentedInstances(objectCandidates)