Add CTC stream decode functionality and standardize the code
nl8590687 committed Sep 16, 2022
1 parent 89c757b commit 819ce7f
Showing 13 changed files with 469 additions and 380 deletions.
66 changes: 32 additions & 34 deletions asrserver_grpc.py
@@ -1,5 +1,4 @@

#!/usr/bin/env python3
# !/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright 2016-2099 Ailemon.net
@@ -37,13 +36,13 @@
from language_model3 import ModelLanguage
from utils.ops import decode_wav_bytes

API_STATUS_CODE_OK = 200000 # OK
API_STATUS_CODE_OK_PART = 206000 # partial result OK, used for streaming
API_STATUS_CODE_OK = 200000  # OK
API_STATUS_CODE_OK_PART = 206000  # partial result OK, used for streaming
API_STATUS_CODE_CLIENT_ERROR = 400000
API_STATUS_CODE_CLIENT_ERROR_FORMAT = 400001 # request data format error
API_STATUS_CODE_CLIENT_ERROR_CONFIG = 400002 # request data configuration not supported
API_STATUS_CODE_CLIENT_ERROR_FORMAT = 400001  # request data format error
API_STATUS_CODE_CLIENT_ERROR_CONFIG = 400002  # request data configuration not supported
API_STATUS_CODE_SERVER_ERROR = 500000
API_STATUS_CODE_SERVER_ERROR_RUNNING = 500001 # server error while running
API_STATUS_CODE_SERVER_ERROR_RUNNING = 500001  # server error while running

parser = argparse.ArgumentParser(description='ASRT gRPC Protocol API Service')
parser.add_argument('--listen', default='0.0.0.0', type=str, help='the network to listen')
@@ -58,86 +57,85 @@
sm251bn = SpeechModel251BN(
input_shape=(AUDIO_LENGTH, AUDIO_FEATURE_LENGTH, CHANNELS),
output_size=OUTPUT_SIZE
)
)
feat = Spectrogram()
ms = ModelSpeech(sm251bn, feat, max_label_length=64)
ms.load_model('save_models/' + sm251bn.get_model_name() + '.model.h5')

ml = ModelLanguage('model_language')
ml.load_model()



_ONE_DAY_IN_SECONDS = 60 * 60 * 24


class ApiService(AsrtGrpcServiceServicer):
'''
"""
Inherits AsrtGrpcServiceServicer and implements the service methods
'''
"""

def __init__(self):
pass

def Speech(self, request, context):
'''
"""
Concrete implementation of the Speech method; constructs and returns a SpeechResponse according to the pb return object
:param request:
:param context:
:return:
'''
"""
wav_data = request.wav_data
wav_samples = decode_wav_bytes(samples_data=wav_data.samples,
channels=wav_data.channels, byte_width=wav_data.byte_width)
channels=wav_data.channels, byte_width=wav_data.byte_width)
result = ms.recognize_speech(wav_samples, wav_data.sample_rate)
print("语音识别声学模型结果:", result)
return SpeechResponse(status_code=API_STATUS_CODE_OK, status_message='',
result_data=result)
result_data=result)

def Language(self, request, context):
'''
"""
Concrete implementation of the Language method; constructs and returns a TextResponse according to the pb return object
:param request:
:param context:
:return:
'''
"""
print('Language received a request:', request)
result = ml.pinyin_to_text(list(request.pinyins))
print('Language result:', result)
return TextResponse(status_code=API_STATUS_CODE_OK, status_message='',
text_result=result)
text_result=result)

def All(self, request, context):
'''
"""
Concrete implementation of the All method; constructs and returns a TextResponse according to the pb return object
:param request:
:param context:
:return:
'''
"""
wav_data = request.wav_data
wav_samples = decode_wav_bytes(samples_data=wav_data.samples,
channels=wav_data.channels, byte_width=wav_data.byte_width)
channels=wav_data.channels, byte_width=wav_data.byte_width)
result_speech = ms.recognize_speech(wav_samples, wav_data.sample_rate)
result = ml.pinyin_to_text(result_speech)
print("语音识别结果:", result)
return TextResponse(status_code=API_STATUS_CODE_OK, status_message='',
text_result=result)
text_result=result)

def Stream(self, request_iterator, context):
'''
"""
Concrete implementation of the Stream method; constructs and returns TextResponse objects according to the pb return object
:param request:
:param request_iterator:
:param context:
:return:
'''
"""
result = list()
tmp_result_last = list()
beam_size = 100

for request in request_iterator:
wav_data = request.wav_data
wav_samples = decode_wav_bytes(samples_data=wav_data.samples,
channels=wav_data.channels,
byte_width=wav_data.byte_width)
channels=wav_data.channels,
byte_width=wav_data.byte_width)
result_speech = ms.recognize_speech(wav_samples, wav_data.sample_rate)

for item_pinyin in result_speech:
@@ -146,28 +144,28 @@ def Stream(self, request_iterator, context):
result.append(tmp_result_last[0][0])
print("流式语音识别结果:", ''.join(result))
yield TextResponse(status_code=API_STATUS_CODE_OK, status_message='',
text_result=''.join(result))
text_result=''.join(result))
result = list()

tmp_result = ml.pinyin_stream_decode([], item_pinyin, beam_size)
tmp_result_last = tmp_result
yield TextResponse(status_code=API_STATUS_CODE_OK_PART, status_message='',
text_result=''.join(tmp_result[0][0]))
text_result=''.join(tmp_result[0][0]))

if len(tmp_result_last) > 0:
result.append(tmp_result_last[0][0])
print("流式语音识别结果:", ''.join(result))
yield TextResponse(status_code=API_STATUS_CODE_OK, status_message='',
text_result=''.join(result))
text_result=''.join(result))


def run(host, port):
'''
"""
Start the gRPC API service
:return:
'''
"""
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
add_AsrtGrpcServiceServicer_to_server(ApiService(),server)
add_AsrtGrpcServiceServicer_to_server(ApiService(), server)
server.add_insecure_port(''.join([host, ':', port]))
server.start()
print("start service...")
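For context, the sketch below shows how a client could call the new Stream RPC: it cuts a wav file into fixed-length chunks, sends each chunk on the request stream, and prints the partial (206000) and sentence-final (200000) results as they arrive. The stub and message names (asrt_pb2, asrt_pb2_grpc, AsrtGrpcServiceStub, SpeechRequest, WavData) and the server address are assumptions modeled on typical protoc output rather than taken from this diff; only the wav_data fields (samples, sample_rate, channels, byte_width) come from the server code above, so check the repository's generated gRPC code before reusing it.

# Minimal client sketch for the new Stream RPC (names marked above are assumptions).
import wave

import grpc

import asrt_pb2
import asrt_pb2_grpc


def chunk_requests(wav_path, chunk_seconds=6):
    """Yield one SpeechRequest per fixed-length chunk of the wav file."""
    with wave.open(wav_path, 'rb') as reader:
        sample_rate = reader.getframerate()
        frames_per_chunk = sample_rate * chunk_seconds
        while True:
            frames = reader.readframes(frames_per_chunk)
            if not frames:
                break
            wav_data = asrt_pb2.WavData(samples=frames,
                                        sample_rate=sample_rate,
                                        channels=reader.getnchannels(),
                                        byte_width=reader.getsampwidth())
            yield asrt_pb2.SpeechRequest(wav_data=wav_data)


def main():
    # Placeholder address; use whatever --listen/--port the service was started with.
    with grpc.insecure_channel('127.0.0.1:20001') as channel:
        stub = asrt_pb2_grpc.AsrtGrpcServiceStub(channel)
        for response in stub.Stream(chunk_requests('speech.wav')):
            # 206000 marks a partial result, 200000 a sentence-final result.
            print(response.status_code, response.text_result)


if __name__ == '__main__':
    main()
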
55 changes: 30 additions & 25 deletions asrserver_http.py
@@ -34,12 +34,12 @@
from language_model3 import ModelLanguage
from utils.ops import decode_wav_bytes

API_STATUS_CODE_OK = 200000 # OK
API_STATUS_CODE_OK = 200000  # OK
API_STATUS_CODE_CLIENT_ERROR = 400000
API_STATUS_CODE_CLIENT_ERROR_FORMAT = 400001 # request data format error
API_STATUS_CODE_CLIENT_ERROR_CONFIG = 400002 # request data configuration not supported
API_STATUS_CODE_CLIENT_ERROR_FORMAT = 400001  # request data format error
API_STATUS_CODE_CLIENT_ERROR_CONFIG = 400002  # request data configuration not supported
API_STATUS_CODE_SERVER_ERROR = 500000
API_STATUS_CODE_SERVER_ERROR_RUNNING = 500001 # server error while running
API_STATUS_CODE_SERVER_ERROR_RUNNING = 500001  # server error while running

parser = argparse.ArgumentParser(description='ASRT HTTP+Json RESTful API Service')
parser.add_argument('--listen', default='0.0.0.0', type=str, help='the network to listen')
@@ -56,7 +56,7 @@
sm251bn = SpeechModel251BN(
input_shape=(AUDIO_LENGTH, AUDIO_FEATURE_LENGTH, CHANNELS),
output_size=OUTPUT_SIZE
)
)
feat = Spectrogram()
ms = ModelSpeech(sm251bn, feat, max_label_length=64)
ms.load_model('save_models/' + sm251bn.get_model_name() + '.model.h5')
@@ -66,59 +66,63 @@


class AsrtApiResponse:
'''
"""
Response class for the ASRT speech recognition HTTP API
'''
"""

def __init__(self, status_code, status_message='', result=''):
self.status_code = status_code
self.status_message = status_message
self.result = result

def to_json(self):
'''
"""
Convert the object to JSON
'''
return json.dumps(self, default=lambda o: o.__dict__,
sort_keys=True)
"""
return json.dumps(self, default=lambda o: o.__dict__, sort_keys=True)


# API root URL: GET
@app.route('/', methods=["GET"])
def index_get():
'''
"""
Handle GET requests on the root path
'''
"""
buffer = ''
with open('assets/default.html', 'r', encoding='utf-8') as file_handle:
buffer = file_handle.read()
return Response(buffer, mimetype='text/html; charset=utf-8')


# API root URL: POST
@app.route('/', methods=["POST"])
def index_post():
'''
"""
Handle POST requests on the root path
'''
"""
json_data = AsrtApiResponse(API_STATUS_CODE_OK, 'ok')
buffer = json_data.to_json()
return Response(buffer, mimetype='application/json')


# recognition endpoints: POST
@app.route('/<level>', methods=["POST"])
def recognition_post(level):
'''
"""
Handle POST requests on the other paths
'''
# read the JSON request content
"""
# read the JSON request content
try:
if level == 'speech':
request_data = request.get_json()
samples = request_data['samples']
wavdata_bytes = base64.urlsafe_b64decode(bytes(samples,encoding='utf-8'))
wavdata_bytes = base64.urlsafe_b64decode(bytes(samples, encoding='utf-8'))
sample_rate = request_data['sample_rate']
channels = request_data['channels']
byte_width = request_data['byte_width']

wavdata = decode_wav_bytes(samples_data=wavdata_bytes,
channels=channels, byte_width=byte_width)
channels=channels, byte_width=byte_width)
result = ms.recognize_speech(wavdata, sample_rate)

json_data = AsrtApiResponse(API_STATUS_CODE_OK, 'speech level')
@@ -148,14 +152,14 @@ def recognition_post(level):
byte_width = request_data['byte_width']

wavdata = decode_wav_bytes(samples_data=wavdata_bytes,
channels=channels, byte_width=byte_width)
channels=channels, byte_width=byte_width)
result_speech = ms.recognize_speech(wavdata, sample_rate)
result = ml.pinyin_to_text(result_speech)

json_data = AsrtApiResponse(API_STATUS_CODE_OK, 'all level')
json_data.result = result
buffer = json_data.to_json()
print('ASRT Result:', result,'output:', buffer)
print('ASRT Result:', result, 'output:', buffer)
return Response(buffer, mimetype='application/json')
else:
request_data = request.get_json()
@@ -166,19 +170,20 @@ def recognition_post(level):
return Response(buffer, mimetype='application/json')
except Exception as except_general:
request_data = request.get_json()
#print(request_data['sample_rate'], request_data['channels'],
# print(request_data['sample_rate'], request_data['channels'],
# request_data['byte_width'], len(request_data['samples']),
# request_data['samples'][-100:])
json_data = AsrtApiResponse(API_STATUS_CODE_SERVER_ERROR, str(except_general))
buffer = json_data.to_json()
#print("input:", request_data, "\n", "output:", buffer)
# print("input:", request_data, "\n", "output:", buffer)
print("output:", buffer, "error:", except_general)
return Response(buffer, mimetype='application/json')


if __name__ == '__main__':
# for development env
#app.run(host='0.0.0.0', port=20001)
# app.run(host='0.0.0.0', port=20001)
# for production env
import waitress

waitress.serve(app, host=args.listen, port=args.port)
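
For comparison, a minimal sketch of an HTTP client for the JSON API above, posting a wav file to the 'all' level endpoint. The request fields (samples as URL-safe base64, sample_rate, channels, byte_width) mirror what recognition_post reads from the request body; the host, port, and file name are placeholders for a locally running instance of this service.

# Minimal HTTP client sketch for the /all endpoint (host, port and file name are placeholders).
import base64
import json
import wave
from urllib import request as urlrequest


def recognize(wav_path, url='http://127.0.0.1:20001/all'):
    with wave.open(wav_path, 'rb') as reader:
        frames = reader.readframes(reader.getnframes())
        body = {
            # URL-safe base64, matching base64.urlsafe_b64decode on the server
            'samples': base64.urlsafe_b64encode(frames).decode('utf-8'),
            'sample_rate': reader.getframerate(),
            'channels': reader.getnchannels(),
            'byte_width': reader.getsampwidth(),
        }
    req = urlrequest.Request(url,
                             data=json.dumps(body).encode('utf-8'),
                             headers={'Content-Type': 'application/json'})
    with urlrequest.urlopen(req) as resp:
        # The response carries status_code, status_message and result, as in AsrtApiResponse.
        return json.loads(resp.read().decode('utf-8'))


if __name__ == '__main__':
    print(recognize('speech.wav'))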