-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathoptimistic_restore.py
More file actions
25 lines (21 loc) · 989 Bytes
/
optimistic_restore.py
File metadata and controls
25 lines (21 loc) · 989 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import tensorflow as tf
def optimistic_restore(session, save_file):
    """Restore every graph variable that also exists, with the same shape, in a checkpoint.

    Global variables of the current graph are matched to checkpoint entries
    by name (the variable's tensor name with its ``:N`` output suffix
    stripped). A variable is restored only when its static shape equals the
    shape recorded in the checkpoint. Every other variable is left untouched
    in *session*, so a partially compatible checkpoint can be loaded into a
    modified model.

    Args:
        session: the session passed to ``tf.train.Saver.restore``.
        save_file: the checkpoint path/prefix, read both by
            ``tf.train.NewCheckpointReader`` and by ``Saver.restore``.

    Returns:
        The list of variables that were restored, in sorted-name order. If
        nothing matched, the list is empty and no restore is attempted.
    """
    reader = tf.train.NewCheckpointReader(save_file)
    # Maps checkpoint variable name -> list of dimension sizes.
    saved_shapes = reader.get_variable_to_shape_map()

    restore_vars = []
    # A single pass over the global variables, sorted by full tensor name so
    # the restore set is built in a deterministic order.
    for var in sorted(tf.global_variables(), key=lambda v: v.name):
        # "scope/w:0" -> "scope/w", the key used in the checkpoint.
        saved_name = var.name.split(':')[0]
        if saved_name not in saved_shapes:
            continue
        if var.get_shape().as_list() == saved_shapes[saved_name]:
            restore_vars.append(var)

    # tf.train.Saver raises ValueError("No variables to save") for an empty
    # var_list; an optimistic restore with nothing in common is a no-op.
    if not restore_vars:
        return restore_vars

    saver = tf.train.Saver(restore_vars)
    saver.restore(session, save_file)
    return restore_vars
#optimistic_restore(U.get_session(), arglist.load_dir)
# tf.train.get_or_create_global_step()