Skip to content

Commit

Permalink
规范代码
Browse files Browse the repository at this point in the history
  • Loading branch information
nl8590687 committed Sep 18, 2022
1 parent 819ce7f commit 52c816d
Show file tree
Hide file tree
Showing 9 changed files with 471 additions and 470 deletions.
537 changes: 263 additions & 274 deletions assets/asrt_pb2.py

Large diffs are not rendered by default.

163 changes: 82 additions & 81 deletions assets/asrt_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,25 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.Speech = channel.unary_unary(
'/asrt.AsrtGrpcService/Speech',
request_serializer=asrt__pb2.SpeechRequest.SerializeToString,
response_deserializer=asrt__pb2.SpeechResponse.FromString,
)
'/asrt.AsrtGrpcService/Speech',
request_serializer=asrt__pb2.SpeechRequest.SerializeToString,
response_deserializer=asrt__pb2.SpeechResponse.FromString,
)
self.Language = channel.unary_unary(
'/asrt.AsrtGrpcService/Language',
request_serializer=asrt__pb2.LanguageRequest.SerializeToString,
response_deserializer=asrt__pb2.TextResponse.FromString,
)
'/asrt.AsrtGrpcService/Language',
request_serializer=asrt__pb2.LanguageRequest.SerializeToString,
response_deserializer=asrt__pb2.TextResponse.FromString,
)
self.All = channel.unary_unary(
'/asrt.AsrtGrpcService/All',
request_serializer=asrt__pb2.SpeechRequest.SerializeToString,
response_deserializer=asrt__pb2.TextResponse.FromString,
)
'/asrt.AsrtGrpcService/All',
request_serializer=asrt__pb2.SpeechRequest.SerializeToString,
response_deserializer=asrt__pb2.TextResponse.FromString,
)
self.Stream = channel.stream_stream(
'/asrt.AsrtGrpcService/Stream',
request_serializer=asrt__pb2.SpeechRequest.SerializeToString,
response_deserializer=asrt__pb2.TextResponse.FromString,
)
'/asrt.AsrtGrpcService/Stream',
request_serializer=asrt__pb2.SpeechRequest.SerializeToString,
response_deserializer=asrt__pb2.TextResponse.FromString,
)


class AsrtGrpcServiceServicer(object):
Expand Down Expand Up @@ -68,70 +68,70 @@ def Stream(self, request_iterator, context):

def add_AsrtGrpcServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'Speech': grpc.unary_unary_rpc_method_handler(
servicer.Speech,
request_deserializer=asrt__pb2.SpeechRequest.FromString,
response_serializer=asrt__pb2.SpeechResponse.SerializeToString,
),
'Language': grpc.unary_unary_rpc_method_handler(
servicer.Language,
request_deserializer=asrt__pb2.LanguageRequest.FromString,
response_serializer=asrt__pb2.TextResponse.SerializeToString,
),
'All': grpc.unary_unary_rpc_method_handler(
servicer.All,
request_deserializer=asrt__pb2.SpeechRequest.FromString,
response_serializer=asrt__pb2.TextResponse.SerializeToString,
),
'Stream': grpc.stream_stream_rpc_method_handler(
servicer.Stream,
request_deserializer=asrt__pb2.SpeechRequest.FromString,
response_serializer=asrt__pb2.TextResponse.SerializeToString,
),
'Speech': grpc.unary_unary_rpc_method_handler(
servicer.Speech,
request_deserializer=asrt__pb2.SpeechRequest.FromString,
response_serializer=asrt__pb2.SpeechResponse.SerializeToString,
),
'Language': grpc.unary_unary_rpc_method_handler(
servicer.Language,
request_deserializer=asrt__pb2.LanguageRequest.FromString,
response_serializer=asrt__pb2.TextResponse.SerializeToString,
),
'All': grpc.unary_unary_rpc_method_handler(
servicer.All,
request_deserializer=asrt__pb2.SpeechRequest.FromString,
response_serializer=asrt__pb2.TextResponse.SerializeToString,
),
'Stream': grpc.stream_stream_rpc_method_handler(
servicer.Stream,
request_deserializer=asrt__pb2.SpeechRequest.FromString,
response_serializer=asrt__pb2.TextResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'asrt.AsrtGrpcService', rpc_method_handlers)
'asrt.AsrtGrpcService', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))


# This class is part of an EXPERIMENTAL API.
# This class is part of an EXPERIMENTAL API.
class AsrtGrpcService(object):
"""定义服务接口
"""

@staticmethod
def Speech(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/asrt.AsrtGrpcService/Speech',
asrt__pb2.SpeechRequest.SerializeToString,
asrt__pb2.SpeechResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
asrt__pb2.SpeechRequest.SerializeToString,
asrt__pb2.SpeechResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def Language(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/asrt.AsrtGrpcService/Language',
asrt__pb2.LanguageRequest.SerializeToString,
asrt__pb2.TextResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
asrt__pb2.LanguageRequest.SerializeToString,
asrt__pb2.TextResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def All(request,
Expand All @@ -145,24 +145,25 @@ def All(request,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/asrt.AsrtGrpcService/All',
asrt__pb2.SpeechRequest.SerializeToString,
asrt__pb2.TextResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
asrt__pb2.SpeechRequest.SerializeToString,
asrt__pb2.TextResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def Stream(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_stream(request_iterator, target, '/asrt.AsrtGrpcService/Stream',
asrt__pb2.SpeechRequest.SerializeToString,
asrt__pb2.TextResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
asrt__pb2.SpeechRequest.SerializeToString,
asrt__pb2.TextResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout,
metadata)
20 changes: 11 additions & 9 deletions download_default_datalist.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/usr/bin/env python3
# !/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright 2016-2099 Ailemon.net
Expand Down Expand Up @@ -38,12 +38,12 @@

URL_DATALIST_INDEX = "https://d.ailemon.net/asrt_assets/datalist/index.json"
rsp_index = requests.get(URL_DATALIST_INDEX)
rsp_index.encoding='utf-8'
rsp_index.encoding = 'utf-8'
if rsp_index.ok:
logging.info('Has connected to ailemon\'s download server...')
else:
logging.error('%s%s', 'Can not connected to ailemon\'s download server.',
'please check your network connection.')
'please check your network connection.')

index_json = json.loads(rsp_index.text)
if index_json['status_code'] != 200:
Expand All @@ -62,10 +62,11 @@
else:
num = int(num)


def deal_download(datalist_item, url_prefix_str, datalist_path):
'''
"""
to deal datalist file download
'''
"""
logging.info('%s%s', 'start to download datalist ', datalist_item['name'])
save_path = os.path.join(datalist_path, datalist_item['name'])
if not os.path.exists(save_path):
Expand All @@ -83,8 +84,9 @@ def deal_download(datalist_item, url_prefix_str, datalist_path):
logging.info('%s `%s` %s', 'Download', filename, 'complete')
else:
logging.error('%s%s%s%s%s', 'Can not download ', filename,
' from ailemon\'s download server. ',
'http status ok is ', str(rsp_listfile.ok))
' from ailemon\'s download server. ',
'http status ok is ', str(rsp_listfile.ok))


if num == len(body['datalist']):
for i in range(len(body['datalist'])):
Expand All @@ -93,5 +95,5 @@ def deal_download(datalist_item, url_prefix_str, datalist_path):
deal_download(body['datalist'][num], body['url_prefix'], DEFAULT_DATALIST_PATH)

logging.info('%s%s%s', 'Datalist files download complete. ',
'Please remember to download these datasets from ',
body['dataset_download_page_url'])
'Please remember to download these datasets from ',
body['dataset_download_page_url'])
4 changes: 2 additions & 2 deletions evaluate_speech_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@
sm251bn = SpeechModel251BN(
input_shape=(AUDIO_LENGTH, AUDIO_FEATURE_LENGTH, CHANNELS),
output_size=OUTPUT_SIZE
)
)
feat = Spectrogram()
evalue_data = DataLoader('dev')
ms = ModelSpeech(sm251bn, feat, max_label_length=64)

ms.load_model('save_models/' + sm251bn.get_model_name() + '.model.h5')
ms.evaluate_model(data_loader=evalue_data, data_count=-1,
out_report=True, show_ratio=True, show_per_step=100)
out_report=True, show_ratio=True, show_per_step=100)
40 changes: 21 additions & 19 deletions language_model3.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/usr/bin/env python3
# !/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright 2016-2099 Ailemon.net
Expand Down Expand Up @@ -29,30 +29,32 @@

from utils.ops import get_symbol_dict, get_language_model


class ModelLanguage:
'''
"""
ASRT专用N-Gram语言模型
'''
"""

def __init__(self, model_path: str):
self.model_path = model_path
self.dict_pinyin = dict()
self.model1 = dict()
self.model2 = dict()

def load_model(self):
'''
"""
加载N-Gram语言模型到内存
'''
"""
self.dict_pinyin = get_symbol_dict('dict.txt')
self.model1 = get_language_model(os.path.join(self.model_path, 'language_model1.txt'))
self.model2 = get_language_model(os.path.join(self.model_path, 'language_model2.txt'))
model = (self.dict_pinyin, self.model1, self.model2 )
model = (self.dict_pinyin, self.model1, self.model2)
return model

def pinyin_to_text(self, list_pinyin: list, beam_size: int=100) -> str:
'''
def pinyin_to_text(self, list_pinyin: list, beam_size: int = 100) -> str:
"""
拼音转文本,一次性取得全部结果
'''
"""
result = list()
tmp_result_last = list()
for item_pinyin in list_pinyin:
Expand All @@ -71,11 +73,11 @@ def pinyin_to_text(self, list_pinyin: list, beam_size: int=100) -> str:
return ''.join(result)

def pinyin_stream_decode(self, temple_result: list,
item_pinyin: str,
beam_size: int = 100) -> list:
'''
item_pinyin: str,
beam_size: int = 100) -> list:
"""
拼音流式解码,逐字转换,每次返回中间结果
'''
"""
# 如果这个拼音不在汉语拼音字典里的话,直接返回空列表,不做decode
if item_pinyin not in self.dict_pinyin:
return []
Expand All @@ -100,19 +102,19 @@ def pinyin_stream_decode(self, temple_result: list,
# 如果2-gram子序列不存在
continue
# 计算状态转移概率
prob_origin = sequence[1] # 原始概率
count_two_word = float(self.model2[tuple2_word]) # 二字频数
count_one_word = float(self.model1[tuple2_word[-2]]) # 单字频数
prob_origin = sequence[1] # 原始概率
count_two_word = float(self.model2[tuple2_word]) # 二字频数
count_one_word = float(self.model1[tuple2_word[-2]]) # 单字频数
cur_probility = prob_origin * count_two_word / count_one_word
new_result.append([sequence[0]+cur_word, cur_probility])
new_result.append([sequence[0] + cur_word, cur_probility])

new_result = sorted(new_result, key=lambda x:x[1], reverse=True)
new_result = sorted(new_result, key=lambda x: x[1], reverse=True)
if len(new_result) > beam_size:
return new_result[0:beam_size]
return new_result


if __name__=='__main__':
if __name__ == '__main__':
ml = ModelLanguage('model_language')
ml.load_model()

Expand Down
6 changes: 3 additions & 3 deletions predict_speech_file.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/usr/bin/env python3
# !/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright 2016-2099 Ailemon.net
Expand Down Expand Up @@ -40,7 +40,7 @@
sm251bn = SpeechModel251BN(
input_shape=(AUDIO_LENGTH, AUDIO_FEATURE_LENGTH, CHANNELS),
output_size=OUTPUT_SIZE
)
)
feat = Spectrogram()
ms = ModelSpeech(sm251bn, feat, max_label_length=64)

Expand All @@ -52,4 +52,4 @@
ml.load_model()
str_pinyin = res
res = ml.pinyin_to_text(str_pinyin)
print('语音识别最终结果:\n',res)
print('语音识别最终结果:\n', res)
Loading

0 comments on commit 52c816d

Please sign in to comment.