feature: add gif pretreatment.
kerlomz committed Dec 22, 2019
1 parent a55c95d commit e3cd049
Showing 7 changed files with 107 additions and 39 deletions.
16 changes: 14 additions & 2 deletions config.py
@@ -178,14 +178,20 @@ def __init__(self, conf: Config, model_conf_path: str):
self.image_width: int = self.field_root.get('ImageWidth')
self.image_height: int = self.field_root.get('ImageHeight')
self.resize: list = self.field_root.get('Resize')
-        self.replace_transparent: bool = self.field_root.get("ReplaceTransparent")
-        self.horizontal_stitching: bool = self.field_root.get("HorizontalStitching")
self.output_split = self.field_root.get('OutputSplit')
self.output_split = self.output_split if self.output_split else ""
self.corp_params = self.field_root.get('CorpParams')
self.output_coord = self.field_root.get('OutputCoord')
self.batch_model = self.field_root.get('BatchModel')

"""PRETREATMENT"""
self.pretreatment_root = self.model_conf.get('Pretreatment')
self.pre_binaryzation = self.get_var(self.pretreatment_root, 'Binaryzation', -1)
self.pre_replace_transparent = self.get_var(self.pretreatment_root, 'ReplaceTransparent', True)
self.pre_horizontal_stitching = self.get_var(self.pretreatment_root, 'HorizontalStitching', False)
self.pre_concat_frames = self.get_var(self.pretreatment_root, 'ConcatFrames', -1)
self.pre_blend_frames = self.get_var(self.pretreatment_root, 'BlendFrames', -1)

"""COMPILE_MODEL"""
self.compile_model_path = os.path.join(self.graph_path, '{}.pb'.format(self.model_name))
if not os.path.exists(self.compile_model_path):
@@ -208,6 +214,12 @@ def param_convert(source, param_map: dict, text, code, default=None):
def size_match(self, size_str):
return size_str == self.size_string

@staticmethod
def get_var(src: dict, name: str, default=None):
if not src:
return default
return src.get(name)

@property
def size_string(self):
return "{}x{}".format(self.image_width, self.image_height)
18 changes: 11 additions & 7 deletions demo.py
@@ -22,7 +22,7 @@
# DEFAULT_HOST = "127.0.0.1"


-def _image(_path, model_type=None, model_site=None, need_color=None):
+def _image(_path, model_type=None, model_site=None, need_color=None, fpath=None):
with open(_path, "rb") as f:
img_bytes = f.read()
# data_stream = io.BytesIO(img_bytes)
@@ -38,6 +38,7 @@ def _image(_path, model_type=None, model_site=None, need_color=None):
'model_type': model_type,
'model_site': model_site,
'need_color': need_color,
'path': fpath
}


@@ -102,7 +103,9 @@ def __init__(self, host: str, server_type: ServerType, port=None, url=None):

def request(self, params):
import json
-        print(json.dumps(params))
+        # print(params)
+        # print(params['fpath'])
+        # print(json.dumps(params))
# return post(self._url, data=base64.b64decode(params.get("image").encode())).json()
return post(self._url, json=params).json()

@@ -114,8 +117,8 @@ def local_iter(self, image_list: dict):
if _true:
self.true_count += 1
self.total_count += 1
-                print('result: {}, label: {}, flag: {}, acc_rate: {}'.format(
-                    code, k, _true, self.true_count / self.total_count
+                print('result: {}, label: {}, flag: {}, acc_rate: {}, {}'.format(
+                    code, k, _true, self.true_count / self.total_count, v.get('path')
))
except Exception as e:
print(e)
@@ -241,24 +244,25 @@ def press_testing(self, image_list: dict, model_type=None, model_site=None):
# pass

# API by gRPC - The fastest way, Local batch version, only for self testing.
path = r"C:\Users\kerlomz\Desktop\New folder (2)"
path = r"C:\Users\kerlomz\Desktop\New folder (6)"
path_list = os.listdir(path)
import random

-    random.shuffle(path_list)
+    # random.shuffle(path_list)
print(path_list)
batch = {
_path.split('_')[0].lower(): _image(
os.path.join(path, _path),
model_type=None,
model_site=None,
need_color=None,
fpath=_path
)
for i, _path in enumerate(path_list)
if i < 10000
}
print(batch)
-    NoAuth(DEFAULT_HOST, ServerType.TORNADO, port=19982).local_iter(batch)
+    NoAuth(DEFAULT_HOST, ServerType.TORNADO, port=19952).local_iter(batch)
# NoAuth(DEFAULT_HOST, ServerType.FLASK).local_iter(batch)
# NoAuth(DEFAULT_HOST, ServerType.SANIC).local_iter(batch)
# GoogleRPC(DEFAULT_HOST).local_iter(batch, model_site=None, model_type=None)
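
For reference, a single-request sketch of what the demo now posts per image. Two assumptions not visible in this diff: the params dict also carries a base64-encoded 'image' field (the commented-out b64decode line above implies as much), and the endpoint path below is only a placeholder for whatever self._url resolves to:

import base64
from requests import post

with open("sample_8ak3n.png", "rb") as f:    # hypothetical captcha file named after its label
    b64_image = base64.b64encode(f.read()).decode()

params = {
    'image': b64_image,
    'model_type': None,
    'model_site': None,
    'need_color': None,
    'path': 'sample_8ak3n.png',              # the new field this commit threads through to local_iter's output
}
url = 'http://127.0.0.1:19952/captcha/v1'    # placeholder URL; the demo builds self._url internally
print(post(url, json=params).json())
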
60 changes: 60 additions & 0 deletions middleware/impl/gif_frames.py
@@ -0,0 +1,60 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Author: kerlomz <[email protected]>

import cv2
import numpy as np
from PIL import ImageSequence


def split_frames(image_obj, need_frame=None):
image_seq = ImageSequence.all_frames(image_obj)
image_arr_last = [np.asarray(image_seq[-1])] if -1 in need_frame and len(need_frame) > 1 else []
image_arr = [np.asarray(item) for i, item in enumerate(image_seq) if (i in need_frame or need_frame == [-1])]
image_arr += image_arr_last
return image_arr


def concat_arr(img_arr):
if len(img_arr) < 2:
return img_arr[0]
all_slice = img_arr[0]
for im_slice in img_arr[1:]:
all_slice = np.concatenate((all_slice, im_slice), axis=1)
return all_slice


def numpy_to_bytes(numpy_arr):
cv_img = cv2.imencode('.png', numpy_arr)[1]
img_bytes = bytes(bytearray(cv_img))
return img_bytes


def concat_frames(image_obj, need_frame=None):
if not need_frame:
need_frame = [0]
img_arr = split_frames(image_obj, need_frame)
img_arr = concat_arr(img_arr)
return img_arr


def blend_arr(img_arr):
if len(img_arr) < 2:
return img_arr[0]
all_slice = img_arr[0]
for im_slice in img_arr[1:]:
all_slice = cv2.addWeighted(all_slice, 0.5, im_slice, 0.5, 0)
all_slice = cv2.equalizeHist(all_slice)
return all_slice


def blend_frame(image_obj, need_frame=None):
if not need_frame:
need_frame = [-1]
img_arr = split_frames(image_obj, need_frame)
img_arr = blend_arr(img_arr)
return img_arr


if __name__ == "__main__":
pass
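
A self-contained usage sketch of the new helpers; the two synthetic frames below just stand in for an animated captcha, and the import path assumes the script runs from the project root:

import io
from PIL import Image
from middleware.impl.gif_frames import concat_frames, blend_frame, numpy_to_bytes

# Build a tiny two-frame GIF in memory instead of reading a real file.
frame_a = Image.new('L', (60, 30), 0)
frame_b = Image.new('L', (60, 30), 255)
buf = io.BytesIO()
frame_a.save(buf, format='GIF', save_all=True, append_images=[frame_b])
gif = Image.open(io.BytesIO(buf.getvalue()))

stitched = concat_frames(gif, need_frame=[0, 1])  # frames 0 and 1 side by side -> shape (30, 120)
merged = blend_frame(gif, need_frame=[-1])        # [-1] selects every frame: addWeighted, then equalizeHist
png_bytes = numpy_to_bytes(stitched)              # PNG-encoded bytes for the rest of the pipeline

Note that need_frame values other than [-1] are literal frame indices, and a -1 listed alongside other indices additionally keeps the last frame.
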
2 changes: 1 addition & 1 deletion package.py
@@ -26,7 +26,7 @@ class Version(Enum):

if __name__ == '__main__':

-    ver = Version.CPU
+    ver = Version.GPU
upload = False
server_ip = ""
username = ""
24 changes: 1 addition & 23 deletions pretreatment.py
@@ -18,33 +18,11 @@ def binarization(self, value, modify=False):
self.origin = _binarization
return _binarization

-    def median_blur(self, value, modify=False):
-        if not value:
-            return self.origin
-        value = value + 1 if value % 2 == 0 else value
-        _smooth = cv2.medianBlur(self.origin, value)
-        if modify:
-            self.origin = _smooth
-        return _smooth
-
-    def gaussian_blur(self, value, modify=False):
-        if not value:
-            return self.origin
-        value = value + 1 if value % 2 == 0 else value
-        _blur = cv2.GaussianBlur(self.origin, (value, value), 0)
-        if modify:
-            self.origin = _blur
-        return _blur
-

-def preprocessing(image, binaryzation=-1, smooth=-1, blur=-1):
+def preprocessing(image, binaryzation=-1):
pretreatment = Pretreatment(image)
if binaryzation > 0:
pretreatment.binarization(binaryzation, True)
-    if smooth != -1:
-        pretreatment.median_blur(smooth, True)
-    if blur != -1:
-        pretreatment.gaussian_blur(blur, True)
return pretreatment.get()
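
With the blur paths gone, preprocessing is now a thin wrapper around binarization; a minimal sketch (the array below is a stand-in for a decoded grayscale captcha, and the threshold value is illustrative):

import numpy as np
from pretreatment import preprocessing

gray = np.random.randint(0, 256, size=(40, 120), dtype=np.uint8)  # fake grayscale captcha
binary = preprocessing(image=gray, binaryzation=127)              # any value <= 0 leaves the image untouched
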


2 changes: 1 addition & 1 deletion tornado_server_gpu.spec
@@ -37,7 +37,7 @@ coll = COLLECT(exe,
strip=False,
upx=True,
upx_exclude=[],
-               name='app')
+               name='captcha_platform_tornado_gpu')
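
In practice this only renames the output folder that a spec-based PyInstaller build (pyinstaller tornado_server_gpu.spec) emits; the bundle contents themselves are unchanged by this commit.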



24 changes: 19 additions & 5 deletions utils.py
@@ -17,6 +17,7 @@
from constants import Response, SystemConfig
from pretreatment import preprocessing
from config import ModelConfig, Config
from middleware.impl.gif_frames import concat_frames, blend_frame


class Arithmetic(object):
@@ -120,16 +121,29 @@ def load_image(image_bytes):
rgb = pil_image.split()
size = pil_image.size

-        if len(rgb) > 3 and model.replace_transparent:
+        gif_handle = model.pre_concat_frames != -1 or model.pre_blend_frames != -1
+
+        if len(rgb) > 3 and model.pre_replace_transparent and gif_handle:
background = PIL_Image.new('RGB', pil_image.size, (255, 255, 255))
background.paste(pil_image, (0, 0, size[0], size[1]), pil_image)
pil_image = background

-        if model.image_channel == 1:
-            pil_image = pil_image.convert('L')
+        if model.pre_concat_frames != -1:
+            im = concat_frames(pil_image, model.pre_concat_frames)
+        elif model.pre_blend_frames != -1:
+            im = blend_frame(pil_image, model.pre_blend_frames)
+        else:
+            im = np.array(pil_image)
+
+        if model.image_channel == 1 and len(im.shape) == 3:
+            im = im.mean(axis=2).astype(np.float32)
+
+        im = preprocessing(
+            image=im,
+            binaryzation=model.pre_binaryzation,
+        )

-        im = np.asarray(pil_image)
-        if model.horizontal_stitching:
+        if model.pre_horizontal_stitching:
up_slice = im[0: int(size[1] / 2), 0: size[0]]
down_slice = im[int(size[1] / 2): size[1], 0: size[0]]
im = np.concatenate((up_slice, down_slice), axis=1)
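
One behavioural nuance of the rewritten path: when image_channel is 1, grayscale conversion now happens on the numpy array with an equal-weight channel mean, whereas the removed pil_image.convert('L') used the ITU-R 601 luma weights. A tiny illustration with a single pure-red pixel:

import numpy as np

rgb = np.array([[[255, 0, 0]]], dtype=np.uint8)  # one pure-red pixel
print(rgb.mean(axis=2))                          # [[85.]] -> the equal-weight mean used here
# convert('L') would have produced ~76 for the same pixel (0.299*R + 0.587*G + 0.114*B).
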
