forked from OCR-D/ocrd_segment
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathimport_coco_segmentation.py
236 lines (218 loc) · 10.7 KB
/
import_coco_segmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
from __future__ import absolute_import
import os.path
import json
import logging
import numpy as np
from ocrd_utils import (
getLogger,
make_file_id,
assert_file_grp_cardinality,
points_from_polygon,
MIMETYPE_PAGE,
membername
)
from ocrd_modelfactory import page_from_file
# pragma pylint: disable=unused-import
# (region types will be referenced indirectly via globals())
from ocrd_models.ocrd_page import (
CoordsType,
TextRegionType,
ImageRegionType,
MathsRegionType,
SeparatorRegionType,
NoiseRegionType,
to_xml)
from ocrd_models.ocrd_page_generateds import (
BorderType,
TableRegionType,
GraphicRegionType,
ChartRegionType,
ChemRegionType,
LineDrawingRegionType,
MusicRegionType,
UnknownRegionType,
AdvertRegionType,
CustomRegionType,
MapRegionType,
TextTypeSimpleType,
GraphicsTypeSimpleType,
ChartTypeSimpleType
)
# pragma pylint: enable=unused-import
from ocrd import Processor
from .config import OCRD_TOOL
# Key of this processor's entry under ``OCRD_TOOL['tools']``
# (used in ``ImportCOCOSegmentation.__init__`` to select its tool description).
TOOL = 'ocrd-segment-from-coco'
def _numeric_page_id(page_id):
    """Return the numeric part of ``page_id`` as an int, or None if there is none.

    The inner ``strip`` removes leading/trailing digits, leaving the
    non-numeric affix characters; the outer ``strip`` then removes those
    characters from both ends, leaving the digits (e.g. ``PHYS_0001`` -> 1).
    """
    try:
        return int(page_id.strip(page_id.strip("0123456789")))
    except ValueError:
        # no (parseable) digits in this ID: matching by COCO image id is
        # impossible, but matching by file name may still succeed
        return None


def _add_publaynet_region(page, region_id, coords, category):
    """Add a region for a PubLayNet ``category`` to ``page``.

    Raises:
        Exception: if ``category`` is none of text/title/list/table/figure.
    """
    text_types = {'text': TextTypeSimpleType.PARAGRAPH,
                  'title': TextTypeSimpleType.HEADING,  # CAPTION?
                  'list': TextTypeSimpleType.LISTLABEL}  # OTHER?
    if category in text_types:
        region_obj = TextRegionType(id=region_id, Coords=coords,
                                    type_=text_types[category])
        page.add_TextRegion(region_obj)
    elif category == 'table':
        region_obj = TableRegionType(id=region_id, Coords=coords)
        page.add_TableRegion(region_obj)
    elif category == 'figure':
        region_obj = ImageRegionType(id=region_id, Coords=coords)
        page.add_ImageRegion(region_obj)
    else:
        raise Exception('unknown region category: %s' % category)


def _add_page_region(page, region_id, coords, category, subcategory):
    """Add a region (or the border) for a PAGE ``category`` to ``page``.

    ``category`` must be the name of a PAGE element such that
    ``category + 'Type'`` is a class imported into this module
    (e.g. ``TextRegion`` -> ``TextRegionType``); ``subcategory`` (may be
    None) becomes the region's ``type`` attribute if predefined for that
    region type, or a ``custom`` attribute otherwise.

    Raises:
        Exception: if no class named ``category + 'Type'`` is known.
    """
    args = {'id': region_id,
            'Coords': coords}
    if subcategory:
        # region types which have an enumerated ``type`` attribute
        typedict = {"TextRegion": TextTypeSimpleType,
                    "GraphicRegion": GraphicsTypeSimpleType,
                    "ChartRegion": ChartTypeSimpleType}
        if category in typedict:
            subtype = membername(typedict[category], subcategory)
            if subtype == subcategory:
                # not predefined in PAGE: use other + custom
                args['custom'] = "subtype:%s" % subcategory
                args['type_'] = "other"
            else:
                args['type_'] = subcategory
        else:
            args['custom'] = "subtype:%s" % subcategory
    # region classes are referenced indirectly via the module namespace
    if category + 'Type' not in globals():
        raise Exception('unknown region category: %s' % category)
    region_type = globals()[category + 'Type']
    if region_type is BorderType:
        # the border is a singleton without ID or subtype
        page.set_Border(BorderType(Coords=coords))
    else:
        region_obj = region_type(**args)
        getattr(page, 'add_%s' % category)(region_obj)


class ImportCOCOSegmentation(Processor):
    """Processor importing region segmentation from a COCO JSON file into PAGE-XML."""

    def __init__(self, *args, **kwargs):
        """Initialise the processor with this tool's description and version."""
        kwargs['ocrd_tool'] = OCRD_TOOL['tools'][TOOL]
        kwargs['version'] = OCRD_TOOL['version']
        super(ImportCOCOSegmentation, self).__init__(*args, **kwargs)

    def process(self):
        """Performs region segmentation by reading from COCO annotations.

        Open and deserialize the COCO JSON file from the second input file group
        (either a non-page-specific file in that METS file group, or a local
        file path given in place of the group name).
        (It lists region categories/subtypes, file names and segmentations for all pages.)

        Open and deserialize each PAGE input file (or generate from image input file)
        from the first input file group. Now find this page in COCO:
        - try to match the PAGE ``imageFilename`` to some COCO ``file_name``,
          otherwise
        - try to match the numeric part of the METS physical page ID to some
          COCO ``id``, otherwise
        - skip with an error (writing no output for that page).

        Then create and add a region for each ``segmentation``, converting its polygon
        to coordinate points and its COCO category to a region type (and subtype),
        either for a PubLayNet classification or PAGE classification (as produced by
        ocrd-segment-extract-pages), as indicated by ``source``. (Categories without
        a ``source`` entry are treated as PubLayNet.)

        Produce a new output file by serialising the resulting hierarchy.

        Afterwards, if there are still COCO images left unaccounted for (i.e. without
        corresponding input files), then show a warning.
        """
        LOG = getLogger('processor.ImportCOCOSegmentation')
        # Load JSON
        assert_file_grp_cardinality(self.input_file_grp, 2, 'base and COCO')
        # pylint: disable=attribute-defined-outside-init
        self.input_file_grp, coco_grp = self.input_file_grp.split(',')
        # pylint: enable=attribute-defined-outside-init
        if not self.input_files:
            LOG.warning('No input files to process')
            return
        if coco_grp in self.workspace.mets.file_groups:
            try:
                cocofile = next(f for f in self.workspace.mets.find_files(fileGrp=coco_grp)
                                # if f.mimetype == 'application/json' and not f.pageId
                                if not f.pageId)
            except StopIteration:
                raise Exception("no non-page-specific file in second file group (COCO file)", coco_grp)
            cocofile = self.workspace.download_file(cocofile).local_filename
        elif os.path.isfile(coco_grp):
            cocofile = coco_grp
        else:
            raise Exception("file not found in second file group (COCO file)", coco_grp)

        LOG.info('Loading COCO annotations from "%s" into memory...', cocofile)
        with open(cocofile, 'r') as inp:
            coco = json.load(inp)
        LOG.info('Loaded JSON for %d images with %d regions in %d categories',
                 len(coco['images']), len(coco['annotations']), len(coco['categories']))
        coco_source = 'PubLayNet'
        # Convert to usable dicts
        # classes:
        categories = dict()
        subcategories = dict()
        for cat in coco['categories']:
            # stock PubLayNet categories carry no 'source' entry at all
            if cat.get('source') == 'PAGE':
                coco_source = 'PAGE'
            if cat.get('supercategory'):
                categories[cat['id']] = cat['supercategory']
                subcategories[cat['id']] = cat['name']
            else:
                categories[cat['id']] = cat['name']
        # images and annotations:
        images_by_id = dict()
        images_by_filename = dict()
        for image in coco['images']:
            images_by_id[image['id']] = image
            images_by_filename[image['file_name']] = image
        for annotation in coco['annotations']:
            image = images_by_id[annotation['image_id']]
            regions = image.setdefault('regions', list())
            regions.append(annotation)
        # release the raw structure (images/annotations stay referenced by the dicts)
        del coco

        LOG.info('Converting %s annotations into PAGE-XML', coco_source)
        for n, input_file in enumerate(self.input_files):
            page_id = input_file.pageId or input_file.ID
            num_page_id = _numeric_page_id(page_id)
            LOG.info("INPUT FILE %i / %s", n, page_id)
            pcgts = page_from_file(self.workspace.download_file(input_file))
            self.add_metadata(pcgts)
            page = pcgts.get_Page()

            # find COCO image
            if page.imageFilename in images_by_filename:
                image = images_by_filename[page.imageFilename]
            elif num_page_id is not None and num_page_id in images_by_id:
                image = images_by_id[num_page_id]
            else:
                LOG.error('Page "%s" / file "%s" not found in COCO',
                          page_id, page.imageFilename)
                # todo: maybe we should at least write the (unchanged) output PAGE?
                continue
            if image['width'] != page.imageWidth:
                LOG.error('Page "%s" width %d does not match annotated width %d',
                          page_id, page.imageWidth, image['width'])
            if image['height'] != page.imageHeight:
                LOG.error('Page "%s" height %d does not match annotated height %d',
                          page_id, page.imageHeight, image['height'])

            # todo: remove existing segmentation first?
            # (images without any annotation have no 'regions' entry)
            for region in image.get('regions', []):
                assert isinstance(region['segmentation'], list), "importing RLE/mask segmentation not implemented"
                # NOTE(review): the reshape below only works for a single
                # flat [x1, y1, x2, y2, ...] polygon per annotation --
                # multi-polygon segmentations are not handled
                polygon = np.array(region['segmentation'])
                polygon = np.reshape(polygon, (polygon.shape[1]//2, 2))
                coords = CoordsType(points=points_from_polygon(polygon))
                category = categories[region['category_id']]
                subcategory = subcategories.get(region['category_id'])
                region_id = 'r' + str(region['id'])
                LOG.info('Adding region %s:%s [area %d]', category, subcategory or '', region['area'])
                if coco_source == 'PubLayNet':
                    _add_publaynet_region(page, region_id, coords, category)
                else:  # 'PAGE'
                    _add_page_region(page, region_id, coords, category, subcategory)

            # remove image from dicts (by the matched image's own keys, so
            # that it disappears from both regardless of how it was found)
            images_by_id.pop(image['id'], None)
            images_by_filename.pop(image['file_name'], None)

            file_id = make_file_id(input_file, self.output_file_grp)
            self.workspace.add_file(
                ID=file_id,
                file_grp=self.output_file_grp,
                pageId=input_file.pageId,
                mimetype=MIMETYPE_PAGE,
                local_filename=os.path.join(self.output_file_grp, file_id + '.xml'),
                content=to_xml(pcgts))

        # warn of remaining COCO images
        # (but not if processing was restricted to certain pages)
        if images_by_filename and not self.page_id:
            LOG.warning('%d images remain unaccounted for after processing', len(images_by_filename))
            if LOG.isEnabledFor(logging.DEBUG):
                for filename in images_by_filename:
                    LOG.debug('not found in workspace: "%s"', filename)