Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
fix(preprocessor): modify ffmpeg video pre add video cutting method
Browse files Browse the repository at this point in the history
  • Loading branch information
Larryjianfeng committed Aug 12, 2019
1 parent c81e58f commit c150ad5
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 1 deletion.
2 changes: 1 addition & 1 deletion gnes/preprocessor/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def hsv_histogram(image: 'np.ndarray') -> 'np.ndarray':


def phash_descriptor(image: 'np.ndarray'):
image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
image = Image.fromarray(image)
import imagehash
return imagehash.phash(image)

Expand Down
12 changes: 12 additions & 0 deletions gnes/preprocessor/video/ffmpeg.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def __init__(self,

def apply(self, doc: 'gnes_pb2.Document') -> None:
super().apply(doc)
from sklearn.cluster import KMeans
if doc.raw_bytes:
if self.use_image_input:
frames = split_video_frames(doc.raw_bytes, self.splitter)
Expand Down Expand Up @@ -155,6 +156,17 @@ def apply(self, doc: 'gnes_pb2.Document') -> None:
else:
sub_videos = [frames]

# cut by clustering: params required
# segment_num
elif self.segment_method == 'cut_by_clustering':
if self.segment_num >= 2:
hash_v = [phash_descriptor(_).hash for _ in frames]
label_v = KMeans(n_clusters=self.segment_num
).fit_predict(np.array(hash_v, dtype=np.int32))
sub_videos = [[frames[i] for i, j in enumerate(label_v) if j == _] for _ in range(self.segment_num)]
else:
sub_videos = [frames]

for ci, chunk in enumerate(sub_videos):
c = doc.chunks.add()
c.doc_id = doc.doc_id
Expand Down
19 changes: 19 additions & 0 deletions tests/test_video_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def setUp(self):
self.yml_path = os.path.join(self.dirname, 'yaml', 'preprocessor-ffmpeg.yml')
self.yml_path_2 = os.path.join(self.dirname, 'yaml', 'preprocessor-ffmpeg2.yml')
self.yml_path_3 = os.path.join(self.dirname, 'yaml', 'preprocessor-ffmpeg3.yml')
self.yml_path_3 = os.path.join(self.dirname, 'yaml', 'preprocessor-ffmpeg4.yml')
self.video_path = os.path.join(self.dirname, 'videos')
self.video_bytes = [open(os.path.join(self.video_path, _), 'rb').read()
for _ in os.listdir(self.video_path)]
Expand Down Expand Up @@ -87,3 +88,21 @@ def test_video_cut_by_num(self):
r = client.recv_message()
for d in r.request.index.docs:
self.assertEqual(len(d.chunks), 6)

def test_video_cut_by_clustering(self):
args = set_preprocessor_service_parser().parse_args([
'--yaml_path', self.yml_path_4
])
c_args = _set_client_parser().parse_args([
'--port_in', str(args.port_out),
'--port_out', str(args.port_in)
])

with PreprocessorService(args), ZmqClient(c_args) as client:
for req in RequestGenerator.index(self.video_bytes):
msg = gnes_pb2.Message()
msg.request.index.CopyFrom(req.index)
client.send_message(msg)
r = client.recv_message()
for d in r.request.index.docs:
self.assertEqual(len(d.chunks), 6)
8 changes: 8 additions & 0 deletions tests/yaml/preprocessor-ffmpeg4.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
!FFmpegVideoSegmentor
parameter:
segment_method: cut_by_clustering
segment_num: 6
s: "192*168"
r: 1
gnes_config:
is_trained: true

0 comments on commit c150ad5

Please sign in to comment.