from __future__ import absolute_import, division, print_function

import logging
import os

import tensorflow as tf
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import (signature_constants, tag_constants)


def get_optimizer_by_name(optimizer_name, learning_rate):
  """
  Get optimizer object by the optimizer name.

  Args:
    optimizer_name: Name of the optimizer, such as "sgd", "adam" or "rmsprop".
    learning_rate: The learning rate for the optimizer.

  Return:
    The optimizer object.
  """

  logging.info("Use the optimizer: {}".format(optimizer_name))
  if optimizer_name == "sgd":
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  elif optimizer_name == "adadelta":
    optimizer = tf.train.AdadeltaOptimizer(learning_rate)
  elif optimizer_name == "adagrad":
    optimizer = tf.train.AdagradOptimizer(learning_rate)
  elif optimizer_name == "adam":
    optimizer = tf.train.AdamOptimizer(learning_rate)
  elif optimizer_name == "ftrl":
    optimizer = tf.train.FtrlOptimizer(learning_rate)
  elif optimizer_name == "rmsprop":
    optimizer = tf.train.RMSPropOptimizer(learning_rate)
  else:
    # Fall back to plain SGD for unknown optimizer names
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  return optimizer

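# Usage sketch (illustrative only, not part of the original module): "loss" is
# an assumed scalar tensor from the training graph.
#
#   optimizer = get_optimizer_by_name("adam", 0.001)
#   train_op = optimizer.minimize(loss)

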
def save_model(model_path,
               model_version,
               sess,
               signature_def_map,
               is_save_graph=False):
  """
  Save the model in standard SavedModel format.

  Args:
    model_path: The directory to export the model to.
    model_version: The version of the model, used as the sub-directory name.
    sess: The TensorFlow session which contains the variables to export.
    signature_def_map: The map of signature names to SignatureDef protos.
    is_save_graph: Whether to export the GraphDef file as well.

  Return:
    None
  """

  export_path = os.path.join(model_path, str(model_version))
  # Do not overwrite an existing exported model version
  if os.path.isdir(export_path):
    logging.error("The model exists in path: {}".format(export_path))
    return

  try:
    # Save the SavedModel
    legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
    builder = saved_model_builder.SavedModelBuilder(export_path)
    builder.add_meta_graph_and_variables(
        sess, [tag_constants.SERVING],
        clear_devices=True,
        signature_def_map=signature_def_map,
        legacy_init_op=legacy_init_op)
    logging.info("Save the model in: {}".format(export_path))
    builder.save()

    # Save the GraphDef
    if is_save_graph:
      graph_file_name = "graph.pb"
      logging.info("Save the graph file in: {}".format(model_path))
      tf.train.write_graph(
          sess.graph_def, model_path, graph_file_name, as_text=False)

  except Exception as e:
    logging.error("Fail to export saved model, exception: {}".format(e))

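# Usage sketch (illustrative only, not part of the original module):
# "inputs_tensor" and "outputs_tensor" are assumed tensors from the graph.
#
#   signature = tf.saved_model.signature_def_utils.predict_signature_def(
#       inputs={"features": inputs_tensor},
#       outputs={"prediction": outputs_tensor})
#   signature_def_map = {
#       signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature
#   }
#   save_model("./model", 1, sess, signature_def_map, is_save_graph=True)

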
def restore_from_checkpoint(sess, saver, checkpoint_file_path):
  """
  Restore session from checkpoint files.

  Args:
    sess: The TensorFlow session to restore into.
    saver: The tf.train.Saver object.
    checkpoint_file_path: The path of the checkpoint file.

  Return:
    True if the session is restored successfully and False if it fails.
  """
  if checkpoint_file_path:
    logging.info(
        "Restore session from checkpoint: {}".format(checkpoint_file_path))
    saver.restore(sess, checkpoint_file_path)
    return True
  else:
    logging.error("Checkpoint not found: {}".format(checkpoint_file_path))
    return False
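

# Usage sketch (illustrative only, not part of the original module): the
# checkpoint directory "./checkpoint" is an assumed example path.
#
#   saver = tf.train.Saver()
#   with tf.Session() as sess:
#     checkpoint_path = tf.train.latest_checkpoint("./checkpoint")
#     if not restore_from_checkpoint(sess, saver, checkpoint_path):
#       sess.run(tf.global_variables_initializer())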