# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Converts a trained checkpoint into a frozen model for mobile inference.

Once you've trained a model using the `train.py` script, you can use this tool
to convert it into a binary GraphDef file that can be loaded into the Android,
iOS, or Raspberry Pi example code. Here's an example of how to run it:

bazel run tensorflow/examples/speech_commands/freeze -- \
--sample_rate=16000 --dct_coefficient_count=40 --window_size_ms=20 \
--window_stride_ms=10 --clip_duration_ms=1000 \
--model_architecture=conv \
--start_checkpoint=/tmp/speech_commands_train/conv.ckpt-1300 \
--output_file=/tmp/my_frozen_graph.pb

One thing to watch out for is that you need to pass in the same arguments for
`sample_rate` and other command line variables here as you did for the training
script.

The resulting graph has an input for WAV-encoded data named 'wav_data', one for
raw PCM data (as floats in the range -1.0 to 1.0) called 'decoded_sample_data',
and the output is called 'labels_softmax'.
"""
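# If you aren't using bazel, a plain-python invocation also works (a sketch,
# assuming the script is run from inside a TensorFlow source checkout):
#
#   python tensorflow/examples/speech_commands/freeze.py \
#     --start_checkpoint=/tmp/speech_commands_train/conv.ckpt-1300 \
#     --output_file=/tmp/my_frozen_graph.pb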
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os.path
import sys

import tensorflow as tf

from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
import input_data
import models
from tensorflow.python.framework import graph_util

FLAGS = None


def create_inference_graph(wanted_words, sample_rate, clip_duration_ms,
                           clip_stride_ms, window_size_ms, window_stride_ms,
                           dct_coefficient_count, model_architecture):
  """Creates an audio model with the nodes needed for inference.

  Uses the supplied arguments to create a model, and inserts the input and
  output nodes that are needed to use the graph for inference.

  Args:
    wanted_words: Comma-separated list of the words we're trying to recognize.
    sample_rate: How many samples per second are in the input audio files.
    clip_duration_ms: How many milliseconds of audio to analyze for the pattern.
    clip_stride_ms: How often to run recognition. Useful for models with cache.
    window_size_ms: Time slice duration to estimate frequencies from.
    window_stride_ms: How far apart time slices should be.
    dct_coefficient_count: Number of frequency bands to analyze.
    model_architecture: Name of the kind of model to generate.
  """

  words_list = input_data.prepare_words_list(wanted_words.split(','))
  model_settings = models.prepare_model_settings(
      len(words_list), sample_rate, clip_duration_ms, window_size_ms,
      window_stride_ms, dct_coefficient_count)
  runtime_settings = {'clip_stride_ms': clip_stride_ms}

  wav_data_placeholder = tf.placeholder(tf.string, [], name='wav_data')
  decoded_sample_data = contrib_audio.decode_wav(
      wav_data_placeholder,
      desired_channels=1,
      desired_samples=model_settings['desired_samples'],
      name='decoded_sample_data')
  spectrogram = contrib_audio.audio_spectrogram(
      decoded_sample_data.audio,
      window_size=model_settings['window_size_samples'],
      stride=model_settings['window_stride_samples'],
      magnitude_squared=True)
  fingerprint_input = contrib_audio.mfcc(
      spectrogram,
      decoded_sample_data.sample_rate,
      dct_coefficient_count=dct_coefficient_count)
  fingerprint_frequency_size = model_settings['dct_coefficient_count']
  fingerprint_time_size = model_settings['spectrogram_length']
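  # With this script's flag defaults (16 kHz sample rate, 1000 ms clips, 30 ms
  # windows, 10 ms stride, 40 DCT coefficients), the settings computed by
  # models.prepare_model_settings work out to 1 + (16000 - 480) / 160 = 98
  # time slices of 40 coefficients each, so the flattened fingerprint below
  # holds 98 * 40 = 3920 values per clip.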
  reshaped_input = tf.reshape(fingerprint_input, [
      -1, fingerprint_time_size * fingerprint_frequency_size
  ])

  logits = models.create_model(
      reshaped_input, model_settings, model_architecture, is_training=False,
      runtime_settings=runtime_settings)

  # Create an output to use for inference.
  tf.nn.softmax(logits, name='labels_softmax')


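# A minimal sketch (not part of the original script) of how the node names
# mentioned in the module docstring can be used once the graph is frozen.
# Both path arguments are placeholders supplied by the caller.
def _example_label_wav(frozen_graph_path, wav_path):
  """Loads a frozen graph and runs one WAV clip through it."""
  with tf.gfile.GFile(frozen_graph_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
  with tf.Session() as sess:
    tf.import_graph_def(graph_def, name='')
    with open(wav_path, 'rb') as wav_file:
      wav_data = wav_file.read()
    # Feed the WAV bytes to the 'wav_data' input; fetch the softmax scores.
    return sess.run('labels_softmax:0', feed_dict={'wav_data:0': wav_data})

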
def main(_):

  # Create the model and load its weights.
  sess = tf.InteractiveSession()
  create_inference_graph(FLAGS.wanted_words, FLAGS.sample_rate,
                         FLAGS.clip_duration_ms, FLAGS.clip_stride_ms,
                         FLAGS.window_size_ms, FLAGS.window_stride_ms,
                         FLAGS.dct_coefficient_count, FLAGS.model_architecture)
  models.load_variables_from_checkpoint(sess, FLAGS.start_checkpoint)
  # Turn all the variables into inline constants inside the graph and save it.
  frozen_graph_def = graph_util.convert_variables_to_constants(
      sess, sess.graph_def, ['labels_softmax'])
  tf.train.write_graph(
      frozen_graph_def,
      os.path.dirname(FLAGS.output_file),
      os.path.basename(FLAGS.output_file),
      as_text=False)
  tf.logging.info('Saved frozen graph to %s', FLAGS.output_file)

if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--sample_rate',
      type=int,
      default=16000,
      help='Expected sample rate of the wavs',)
  parser.add_argument(
      '--clip_duration_ms',
      type=int,
      default=1000,
      help='Expected duration in milliseconds of the wavs',)
  parser.add_argument(
      '--clip_stride_ms',
      type=int,
      default=30,
      help='How often to run recognition. Useful for models with cache.',)
  parser.add_argument(
      '--window_size_ms',
      type=float,
      default=30.0,
      help='How long each spectrogram timeslice is',)
  parser.add_argument(
      '--window_stride_ms',
      type=float,
      default=10.0,
      help='How long the stride is between spectrogram timeslices',)
  parser.add_argument(
      '--dct_coefficient_count',
      type=int,
      default=40,
      help='How many bins to use for the MFCC fingerprint',)
  parser.add_argument(
      '--start_checkpoint',
      type=str,
      default='',
      help='If specified, restore this pretrained model before any training.')
  parser.add_argument(
      '--model_architecture',
      type=str,
      default='conv',
      help='What model architecture to use')
  parser.add_argument(
      '--wanted_words',
      type=str,
      default='zero,one,two,three,four,five,six,seven,eight,nine',
      help='Words to use (others will be added to an unknown label)',)
  parser.add_argument(
      '--output_file',
      type=str,
      default='',
      help='Where to save the frozen graph.')

  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
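# Typical follow-up (a sketch; paths are placeholders): the frozen graph can
# be spot-checked with the companion label_wav.py script from the same
# example directory:
#
#   python tensorflow/examples/speech_commands/label_wav.py \
#     --graph=/tmp/my_frozen_graph.pb \
#     --labels=/tmp/speech_commands_train/conv_labels.txt \
#     --wav=/path/to/some_clip.wav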