
Commit 63a400c

fix: replace import tf.keras to keras, update tiny rnnt model result
1 parent a4d411d


54 files changed, +1013 -603 lines changed

.pre-commit-config.yaml

Lines changed: 9 additions & 0 deletions
@@ -9,3 +9,12 @@ repos:
       stages: [pre-commit]
       fail_fast: true
       verbose: true
+    - id: pylint-check
+      name: pylint-check
+      entry: pylint --rcfile=.pylintrc -rn -sn
+      language: system
+      types: [python]
+      stages: [pre-commit]
+      fail_fast: true
+      require_serial: true
+      verbose: true
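
The new hook simply runs pylint with the repository's .pylintrc over the staged Python files. If you want to reproduce the same check outside pre-commit, here is a minimal sketch using pylint's programmatic entry point, assuming pylint is installed; the target package name is only an example:

    # Run the same check as the pylint-check hook, programmatically.
    # "tensorflow_asr" below is an illustrative target, not mandated by the hook.
    from pylint.lint import Run

    Run(["--rcfile=.pylintrc", "-rn", "-sn", "tensorflow_asr"], exit=False)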

.pylintrc

Lines changed: 5 additions & 1 deletion
@@ -114,7 +114,11 @@ disable=too-few-public-methods,
         consider-using-enumerate,
         too-many-statements,
         assignment-from-none,
-        eval-used
+        eval-used,
+        duplicate-code,
+        redefined-outer-name,
+        consider-using-f-string,
+        fixme,
 
 # Enable the message, report, category or checker with the given id(s). You can
 # either give multiple identifier separated by comma (,) or put this option
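
The newly disabled messages silence warnings that would otherwise fail the new pylint hook. An illustrative snippet (not from the repository) showing the kind of code each message name refers to; duplicate-code is omitted because it only triggers across multiple similar files:

    greeting = "hello"

    # TODO: tune this later  (fixme would normally flag any TODO/FIXME comment)

    def shout(greeting):  # redefined-outer-name: the argument shadows the module-level name
        # consider-using-f-string: old-style "%" formatting instead of an f-string
        return "%s!" % greeting.upper()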

examples/inferences/main.py

Lines changed: 12 additions & 2 deletions
@@ -15,8 +15,10 @@
 import os
 
 import tensorflow as tf
+import keras
 
 from tensorflow_asr import schemas, tokenizers
+from tensorflow_asr.models import base_model
 from tensorflow_asr.configs import Config
 from tensorflow_asr.utils import cli_util, data_util, env_util, file_util
 
@@ -35,7 +37,7 @@ def main(
     config = Config(config_path, training=False, repodir=repodir)
     tokenizer = tokenizers.get(config)
 
-    model: tf.keras.Model = tf.keras.models.model_from_config(config.model_config)
+    model: base_model.BaseModel = keras.models.model_from_config(config.model_config)
     model.make(batch_size=1)
     model.load_weights(h5, by_name=file_util.is_hdf5_filepath(h5), skip_mismatch=False)
     model.summary()
@@ -44,7 +46,15 @@
     signal = tf.reshape(signal, [1, -1])
     signal_length = tf.reshape(tf.shape(signal)[1], [1])
 
-    outputs = model.recognize(schemas.PredictInput(signal, signal_length))
+    outputs = model.recognize(
+        schemas.PredictInput(
+            inputs=signal,
+            inputs_length=signal_length,
+            previous_tokens=model.get_initial_tokens(),
+            previous_encoder_states=model.get_initial_encoder_states(),
+            previous_decoder_states=model.get_initial_decoder_states(),
+        )
+    )
     print(outputs.tokens)
     transcript = tokenizer.detokenize(outputs.tokens)[0].numpy().decode("utf-8")
 
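
Two things change in main.py: the model is now rebuilt through the standalone keras package rather than the tf.keras alias, and recognize receives a fully populated PredictInput with explicit initial tokens and encoder/decoder states. The config-to-model round-trip itself is ordinary Keras behavior; here is a minimal, self-contained sketch on a toy model (the toy model is illustrative and unrelated to the repository):

    import keras

    # Build a toy model and capture its serialized config.
    toy = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(2)])
    config = toy.get_config()

    # Rebuild an equivalent model from the config alone; main.py does the same
    # via keras.models.model_from_config on the model config loaded from the YAML file.
    rebuilt = keras.Sequential.from_config(config)
    rebuilt.summary()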

examples/inferences/rnn_transducer.py

Lines changed: 89 additions & 89 deletions
@@ -1,89 +1,89 @@
-# Copyright 2020 Huy Le Nguyen (@nglehuy)
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import argparse
-
-from tensorflow_asr.utils import data_util, env_util, math_util
-
-logger = env_util.setup_environment()
-import tensorflow as tf
-
-parser = argparse.ArgumentParser(prog="Rnn Transducer non streaming")
-
-parser.add_argument("filename", metavar="FILENAME", help="audio file to be played back")
-
-parser.add_argument("--config", type=str, default=None, help="Path to rnnt config yaml")
-
-parser.add_argument("--saved", type=str, default=None, help="Path to rnnt saved h5 weights")
-
-parser.add_argument("--beam_width", type=int, default=0, help="Beam width")
-
-parser.add_argument("--timestamp", default=False, action="store_true", help="Return with timestamp")
-
-parser.add_argument("--device", type=int, default=0, help="Device's id to run test on")
-
-parser.add_argument("--cpu", default=False, action="store_true", help="Whether to only use cpu")
-
-parser.add_argument("--subwords", default=False, action="store_true", help="Path to file that stores generated subwords")
-
-parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model")
-
-args = parser.parse_args()
-
-env_util.setup_devices([args.device], cpu=args.cpu)
-
-from tensorflow_asr.configs import Config
-from tensorflow_asr.features.speech_featurizers import SpeechFeaturizer, read_raw_audio
-from tensorflow_asr.models.transducer.rnnt import RnnTransducer
-from tensorflow_asr.tokenizers import CharTokenizer, SentencePieceTokenizer, SubwordFeaturizer
-
-config = Config(args.config)
-speech_featurizer = SpeechFeaturizer(config.speech_config)
-if args.sentence_piece:
-    logger.info("Loading SentencePiece model ...")
-    text_featurizer = SentencePieceTokenizer(config.decoder_config)
-elif args.subwords:
-    logger.info("Loading subwords ...")
-    text_featurizer = SubwordFeaturizer(config.decoder_config)
-else:
-    text_featurizer = CharTokenizer(config.decoder_config)
-text_featurizer.decoder_config.beam_width = args.beam_width
-
-# build model
-rnnt = RnnTransducer(**config.model_config, vocab_size=text_featurizer.num_classes)
-rnnt.make(speech_featurizer.shape)
-rnnt.load_weights(args.saved, by_name=True, skip_mismatch=True)
-rnnt.summary()
-rnnt.add_featurizers(speech_featurizer, text_featurizer)
-
-signal = read_raw_audio(args.filename)
-features = speech_featurizer.tf_extract(signal)
-input_length = math_util.get_reduced_length(tf.shape(features)[0], rnnt.time_reduction_factor)
-
-if args.beam_width:
-    transcript = rnnt.recognize_beam(data_util.create_inputs(inputs=features[None, ...], inputs_length=input_length[None, ...]))
-    logger.info("Transcript:", transcript[0].numpy().decode("UTF-8"))
-elif args.timestamp:
-    transcript, stime, etime, _, _, _ = rnnt.recognize_tflite_with_timestamp(
-        signal=signal,
-        predicted=tf.constant(text_featurizer.blank, dtype=tf.int32),
-        encoder_states=rnnt.encoder.get_initial_state(),
-        prediction_states=rnnt.predict_net.get_initial_state(),
-    )
-    logger.info("Transcript:", transcript)
-    logger.info("Start time:", stime)
-    logger.info("End time:", etime)
-else:
-    transcript = rnnt.recognize(data_util.create_inputs(inputs=features[None, ...], inputs_length=input_length[None, ...]))
-    logger.info("Transcript:", transcript[0].numpy().decode("UTF-8"))
+# # Copyright 2020 Huy Le Nguyen (@nglehuy)
+# #
+# # Licensed under the Apache License, Version 2.0 (the "License");
+# # you may not use this file except in compliance with the License.
+# # You may obtain a copy of the License at
+# #
+# #     http://www.apache.org/licenses/LICENSE-2.0
+# #
+# # Unless required by applicable law or agreed to in writing, software
+# # distributed under the License is distributed on an "AS IS" BASIS,
+# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# # See the License for the specific language governing permissions and
+# # limitations under the License.
+
+# import argparse
+
+# from tensorflow_asr.utils import data_util, env_util, math_util
+
+# logger = env_util.setup_environment()
+# import tensorflow as tf
+
+# parser = argparse.ArgumentParser(prog="Rnn Transducer non streaming")
+
+# parser.add_argument("filename", metavar="FILENAME", help="audio file to be played back")
+
+# parser.add_argument("--config", type=str, default=None, help="Path to rnnt config yaml")
+
+# parser.add_argument("--saved", type=str, default=None, help="Path to rnnt saved h5 weights")
+
+# parser.add_argument("--beam_width", type=int, default=0, help="Beam width")
+
+# parser.add_argument("--timestamp", default=False, action="store_true", help="Return with timestamp")
+
+# parser.add_argument("--device", type=int, default=0, help="Device's id to run test on")
+
+# parser.add_argument("--cpu", default=False, action="store_true", help="Whether to only use cpu")
+
+# parser.add_argument("--subwords", default=False, action="store_true", help="Path to file that stores generated subwords")
+
+# parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model")
+
+# args = parser.parse_args()
+
+# env_util.setup_devices([args.device], cpu=args.cpu)
+
+# from tensorflow_asr.configs import Config
+# from tensorflow_asr.features.speech_featurizers import SpeechFeaturizer, read_raw_audio
+# from tensorflow_asr.models.transducer.rnnt import RnnTransducer
+# from tensorflow_asr.tokenizers import CharTokenizer, SentencePieceTokenizer, SubwordFeaturizer
+
+# config = Config(args.config)
+# speech_featurizer = SpeechFeaturizer(config.speech_config)
+# if args.sentence_piece:
+#     logger.info("Loading SentencePiece model ...")
+#     text_featurizer = SentencePieceTokenizer(config.decoder_config)
+# elif args.subwords:
+#     logger.info("Loading subwords ...")
+#     text_featurizer = SubwordFeaturizer(config.decoder_config)
+# else:
+#     text_featurizer = CharTokenizer(config.decoder_config)
+# text_featurizer.decoder_config.beam_width = args.beam_width
+
+# # build model
+# rnnt = RnnTransducer(**config.model_config, vocab_size=text_featurizer.num_classes)
+# rnnt.make(speech_featurizer.shape)
+# rnnt.load_weights(args.saved, by_name=True, skip_mismatch=True)
+# rnnt.summary()
+# rnnt.add_featurizers(speech_featurizer, text_featurizer)
+
+# signal = read_raw_audio(args.filename)
+# features = speech_featurizer.tf_extract(signal)
+# input_length = math_util.get_reduced_length(tf.shape(features)[0], rnnt.time_reduction_factor)
+
+# if args.beam_width:
+#     transcript = rnnt.recognize_beam(data_util.create_inputs(inputs=features[None, ...], inputs_length=input_length[None, ...]))
+#     logger.info("Transcript:", transcript[0].numpy().decode("UTF-8"))
+# elif args.timestamp:
+#     transcript, stime, etime, _, _, _ = rnnt.recognize_tflite_with_timestamp(
+#         signal=signal,
+#         predicted=tf.constant(text_featurizer.blank, dtype=tf.int32),
+#         encoder_states=rnnt.encoder.get_initial_state(),
+#         prediction_states=rnnt.predict_net.get_initial_state(),
+#     )
+#     logger.info("Transcript:", transcript)
+#     logger.info("Start time:", stime)
+#     logger.info("End time:", etime)
+# else:
+#     transcript = rnnt.recognize(data_util.create_inputs(inputs=features[None, ...], inputs_length=input_length[None, ...]))
+#     logger.info("Transcript:", transcript[0].numpy().decode("UTF-8"))
