Coretex
synthetic_image_generator.py
1 # Copyright (C) 2023 Coretex LLC
2 
3 # This file is part of Coretex.ai
4 
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License as
7 # published by the Free Software Foundation, either version 3 of the
8 # License, or (at your option) any later version.
9 
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU Affero General Public License for more details.
14 
15 # You should have received a copy of the GNU Affero General Public License
16 # along with this program. If not, see <https://www.gnu.org/licenses/>.
17 
18 from typing import List, Tuple
19 from typing_extensions import Self
20 from pathlib import Path
21 from zipfile import ZipFile
22 
23 import copy
24 import json
25 
26 from PIL import Image
27 from PIL.Image import Image as PILImage
28 from numpy import ndarray
29 
30 import numpy as np
31 
32 from .base import BaseImageDataset
33 from ...sample import ImageSample, AnnotatedImageSampleData
34 from ...annotation import CoretexSegmentationInstance, CoretexImageAnnotation, BBox
35 from ...._folder_manager import folder_manager
36 
37 
38 ANNOTATION_NAME = "annotations.json"
39 
40 
class AugmentedImageSample(ImageSample):
    """
    ImageSample whose files are located under
    ``folder_manager.temp / "temp-augmented-ds" / <id>``
    """

    @property
    def path(self) -> Path:
        """
        Returns
        -------
        Path -> directory of this augmented sample inside the
        "temp-augmented-ds" temporary folder, named after the sample id
        """

        augmentedRoot = folder_manager.temp / "temp-augmented-ds"
        return augmentedRoot / str(self.id)

    @classmethod
    def createFromSample(cls, sample: ImageSample) -> Self:
        """
        Builds a new instance whose attributes are deep copies of the
        attributes of the provided sample

        Parameters
        ----------
        sample : ImageSample
            sample whose instance attributes are copied

        Returns
        -------
        Self -> new sample object, independent of ``sample``
        """

        duplicate = cls()

        # Every attribute value is deep-copied on its own, in iteration order
        duplicate.__dict__.update(
            (attributeName, copy.deepcopy(attributeValue))
            for attributeName, attributeValue in sample.__dict__.items()
        )

        return duplicate
74 
75 
def isOverlapping(
    x: int,
    y: int,
    image: PILImage,
    locations: List[Tuple[int, int, int, int]]
) -> bool:
    """
    Checks whether ``image`` placed with its top-left corner at ``(x, y)``
    intersects any already occupied rectangle

    Parameters
    ----------
    x : int
        left coordinate of the candidate placement
    y : int
        top coordinate of the candidate placement
    image : PILImage
        image being placed; only its ``width`` and ``height`` are used
    locations : List[Tuple[int, int, int, int]]
        occupied rectangles as ``(left, top, width, height)``

    Returns
    -------
    bool -> True if the intersection with at least one rectangle has
    positive area (rectangles that only touch along an edge do not count)
    """

    return any(
        loc[0] + loc[2] > x and loc[0] < x + image.width
        and loc[1] + loc[3] > y and loc[1] < y + image.height
        for loc in locations
    )
90 
91 
def generateSegmentedImage(image: np.ndarray, segmentationMask: np.ndarray) -> PILImage:
    """
    Cuts a single instance out of an image

    The image is converted to RGBA and multiplied by the mask, which zeroes
    every channel (alpha included) outside the instance. The result is then
    cropped to the bounding box of the non-zero alpha values.

    Parameters
    ----------
    image : np.ndarray
        source image data, anything accepted by ``PIL.Image.fromarray``
    segmentationMask : np.ndarray
        mask of the instance, either of shape (height, width) or already
        broadcastable against the (height, width, 4) RGBA array

    Returns
    -------
    PILImage -> cropped RGBA image containing only the masked pixels

    Raises
    ------
    TypeError -> if cropping does not produce a ``PIL.Image.Image``
    """

    rgbaArray = np.asarray(Image.fromarray(image).convert("RGBA"))

    mask = np.asarray(segmentationMask)
    if mask.ndim == 2:
        # A (height, width) mask would be broadcast against the channel axis
        # instead of the pixel grid, so add a trailing channel axis
        mask = mask[..., np.newaxis]

    # Image.fromarray only accepts uint8 data for RGBA; the product is promoted
    # to a wider dtype when the mask is neither bool nor uint8
    segmentedImage = Image.fromarray((rgbaArray * mask).astype(np.uint8))

    bbox = segmentedImage.getchannel("A").getbbox()
    # NOTE(review): getbbox() returns None for an all-zero mask, in which case
    # crop(None) yields the full-size, fully transparent image — confirm intended
    croppedImage = segmentedImage.crop(bbox)

    if not isinstance(croppedImage, PILImage):
        raise TypeError(f"Expected \"PIL.Image.Image\" recieved \"{type(croppedImage)}\"")

    return croppedImage
106 
107 
def composeImage(
    segmentedImages: List[PILImage],
    backgroundImage: np.ndarray,
    angle: int,
    scale: float,
    maxAttempts: int = 1000
) -> Tuple[PILImage, List[Tuple[int, int]]]:
    """
    Pastes every segmented image, rotated and scaled, at a random position on
    the background so that pasted images do not overlap each other

    Parameters
    ----------
    segmentedImages : List[PILImage]
        instance images; each one is also used as its own paste mask
    backgroundImage : np.ndarray
        background image data, converted with ``Image.fromarray``
    angle : int
        rotation in degrees passed to ``rotate(angle, expand = True)``
    scale : float
        scaling factor applied to the rotated image
    maxAttempts : int
        maximum number of random positions tried for each image before giving up

    Returns
    -------
    Tuple[PILImage, List[Tuple[int, int]]] -> composed image and, for each
    pasted image in input order, the center of its pasted rectangle

    Raises
    ------
    ValueError -> if a rotated/scaled image does not fit inside the background
    RuntimeError -> if no non-overlapping position is found within ``maxAttempts`` tries
    """

    centroids: List[Tuple[int, int]] = []
    locations: List[Tuple[int, int, int, int]] = []

    background = Image.fromarray(backgroundImage)

    for segmentedImage in segmentedImages:
        rotatedImage = segmentedImage.rotate(angle, expand = True)
        resizedImage = rotatedImage.resize((int(rotatedImage.width * scale), int(rotatedImage.height * scale)))

        # Largest top-left coordinates that keep the image fully inside the background
        maxX = background.width - resizedImage.width
        maxY = background.height - resizedImage.height

        if maxX < 0 or maxY < 0:
            raise ValueError(
                f"Segmented image of size {resizedImage.width}x{resizedImage.height} "
                f"does not fit into background of size {background.width}x{background.height}"
            )

        for _ in range(maxAttempts):
            # randint's upper bound is exclusive: +1 makes maxX/maxY reachable and
            # avoids a "high <= low" error when the image exactly fits (maxX == 0)
            x = int(np.random.randint(0, maxX + 1))
            y = int(np.random.randint(0, maxY + 1))

            if not isOverlapping(x, y, resizedImage, locations):
                break
        else:
            # Without a bound this search would never terminate when the
            # background has no free area left for the current image
            raise RuntimeError(f"Failed to find a non-overlapping location after {maxAttempts} attempts")

        background.paste(resizedImage, (x, y), resizedImage)

        centroids.append((x + resizedImage.width // 2, y + resizedImage.height // 2))
        locations.append((x, y, resizedImage.width, resizedImage.height))

    return background, centroids
150 
151 
def processInstance(
    sample: ImageSample,
    backgroundSampleData: AnnotatedImageSampleData,
    angle: int,
    scale: float
) -> Tuple[PILImage, List[CoretexSegmentationInstance]]:
    """
    Composes every annotated instance of ``sample`` onto the background and
    builds matching segmentation instances for the composed image

    Parameters
    ----------
    sample : ImageSample
        sample whose annotated instances are cut out
    backgroundSampleData : AnnotatedImageSampleData
        loaded background sample; its ``image`` is used as the background
    angle : int
        rotation in degrees applied to images and segmentations
    scale : float
        scaling factor applied to images and segmentations

    Returns
    -------
    Tuple[PILImage, List[CoretexSegmentationInstance]] -> composed image and
    one transformed instance per source instance, in the same order

    Raises
    ------
    RuntimeError -> if the sample has no annotation
    """

    sampleData = sample.load()
    if sampleData.annotation is None:
        raise RuntimeError(f"CTX sample dataset sample id: {sample.id} annotation doesn't exist!")

    instances = sampleData.annotation.instances
    height, width = sampleData.image.shape[0], sampleData.image.shape[1]

    segmentedImages: List[PILImage] = [
        generateSegmentedImage(sampleData.image, instance.extractBinaryMask(width, height))
        for instance in instances
    ]

    composedImage, centroids = composeImage(segmentedImages, backgroundSampleData.image, angle, scale)

    augmentedInstances: List[CoretexSegmentationInstance] = []

    for instance, centroid in zip(instances, centroids):
        # composeImage resizes every pasted image by ``scale``, so the polygons
        # must be scaled by the same factor to keep matching the pasted pixels.
        # Building new lists also keeps the source sample's segmentations from
        # being shared with (and possibly mutated through) the new instance.
        # Coordinates are left untouched when no scaling is requested.
        scaledSegmentations = [
            [coordinate * scale for coordinate in segmentation] if scale != 1 else list(segmentation)
            for segmentation in instance.segmentations
        ]
        flattenedCoordinates = [
            coordinate
            for segmentation in scaledSegmentations
            for coordinate in segmentation
        ]

        augmentedInstance = CoretexSegmentationInstance.create(
            instance.classId,
            BBox.fromPoly(flattenedCoordinates),
            scaledSegmentations
        )

        # NOTE(review): the image is rotated with PIL's rotate (counter-clockwise
        # for positive angles); confirm rotateSegmentations uses the same direction
        augmentedInstance.rotateSegmentations(angle)
        augmentedInstance.centerSegmentations(centroid)

        augmentedInstances.append(augmentedInstance)

    return composedImage, augmentedInstances
184 
185 
def processSample(
    sample: ImageSample,
    backgroundSample: ImageSample,
    angle: int,
    scale: float
) -> Tuple[PILImage, CoretexImageAnnotation]:
    """
    Composes the instances of ``sample`` onto ``backgroundSample`` and
    creates the annotation describing the composed image

    Parameters
    ----------
    sample : ImageSample
        sample providing the annotated instances
    backgroundSample : ImageSample
        sample providing the background image
    angle : int
        rotation in degrees
    scale : float
        scaling factor

    Returns
    -------
    Tuple[PILImage, CoretexImageAnnotation] -> composed image and its
    annotation, named after ``sample``
    """

    composed, instances = processInstance(sample, backgroundSample.load(), angle, scale)

    return composed, CoretexImageAnnotation.create(
        sample.name,
        composed.width,
        composed.height,
        instances
    )
199 
200 
def storeFiles(
    tempPath: Path,
    augmentedSample: ImageSample,
    augmentedImage: PILImage,
    annotation: CoretexImageAnnotation
) -> None:
    """
    Writes the augmented image and its annotation into ``<tempPath>/<sample id>.zip``

    The image is stored in the archive as ``<sample name>.jpeg`` and the
    annotation as ANNOTATION_NAME. The loose files created while building the
    archive are removed afterwards, even if building it fails.

    Parameters
    ----------
    tempPath : Path
        directory in which the archive is created
    augmentedSample : ImageSample
        sample whose ``name`` and ``id`` determine the file names
    augmentedImage : PILImage
        composed image to store
    annotation : CoretexImageAnnotation
        annotation to store, serialized from ``annotation.encode()``
    """

    imageName = f"{augmentedSample.name}.jpeg"
    imagePath = tempPath / imageName
    annotationPath = tempPath / ANNOTATION_NAME
    zipPath = tempPath / f"{augmentedSample.id}.zip"

    try:
        # JPEG cannot store an alpha channel or a palette
        if augmentedImage.mode not in ("RGB", "L"):
            augmentedImage = augmentedImage.convert("RGB")

        augmentedImage.save(imagePath)

        with annotationPath.open("w", encoding = "utf-8") as annotationFile:
            json.dump(annotation.encode(), annotationFile)

        with ZipFile(zipPath, mode = "w") as archive:
            archive.write(imagePath, imageName)
            archive.write(annotationPath, ANNOTATION_NAME)
    finally:
        # Loose files are only staging copies for the archive
        imagePath.unlink(missing_ok = True)
        annotationPath.unlink(missing_ok = True)
224 
225 
def augmentDataset(
    normalDataset: BaseImageDataset,
    backgroundDataset: BaseImageDataset,
    angle: int = 0,
    scale: float = 1.0
) -> None:
    """
    Modifies normalDataset by adding new augmented samples to it

    Every sample of normalDataset is composed onto every sample of
    backgroundDataset; each combination becomes a new AugmentedImageSample
    whose archive is written to the "temp-augmented-ds" temporary folder.

    Parameters
    ----------
    normalDataset : BaseImageDataset
        dataset providing the annotated instances; receives the new samples
    backgroundDataset : BaseImageDataset
        dataset providing the background images
    angle : int
        angle of rotation in degrees
    scale : float
        scaling factor
    """

    tempPath = folder_manager.createTempFolder("temp-augmented-ds")
    augmentedSamples: List[AugmentedImageSample] = []

    # Ids are allocated sequentially above the largest existing id. Concatenating
    # f"{i}{j}{id}" was ambiguous (i=1, j=12 and i=11, j=2 collide) and for
    # i = j = 0 reproduced the source sample's own id.
    nextId = max((int(sample.id) for sample in normalDataset.samples), default = 0) + 1

    for background in backgroundDataset.samples:
        background.unzip()

        for sample in normalDataset.samples:
            sample.unzip()
            augmentedImage, annotation = processSample(sample, background, angle, scale)

            augmentedSample = AugmentedImageSample.createFromSample(sample)
            augmentedSample.id = nextId
            nextId += 1

            storeFiles(tempPath, augmentedSample, augmentedImage, annotation)

            augmentedSamples.append(augmentedSample)

    # Added only after the loops so normalDataset.samples is not extended
    # while it is being iterated
    normalDataset.samples.extend(augmentedSamples)