diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py index 27f19a1e177a..d923c25044e1 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -468,6 +468,20 @@ def _impl(inputs, attr, params): ignores=['index_type', 'T'])(new_inputs, attr) return _impl +def _lrn(): + def _impl(inputs, attr, params): + new_inputs = [] + attr_new = {} + depth_radius = attr.get('depth_radius', 5) + size = (depth_radius * 2) + 1 + attr_new['axis'] = 3 # Fix axis, NHWC format + attr_new['size'] = size + attr_new['bias'] = attr.get('bias', 1) + attr_new['alpha'] = attr.get('alpha', 1) * size + attr_new['beta'] = attr.get('beta', 0.5) + return AttrCvt(op_name='lrn')(new_inputs, attr_new) + return _impl + def _gather_v2(): "Tensorflow now support only gatherv2" def _impl(inputs, attr, params): @@ -680,6 +694,7 @@ def _impl(inputs, in_state_c, in_state_h, attr, params): 'Fill' : _fill(), 'GatherV2' : _gather_v2(), 'StridedSlice' : _stridedSlice(), + 'LRN' : _lrn(), } # _convert_map_rnn defines maps of rnn operator name to diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py index 045d154d9d8b..a5c5fdcfed2f 100644 --- a/nnvm/tests/python/frontend/tensorflow/test_forward.py +++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py @@ -855,6 +855,40 @@ def _get_sample(data, state): assert(tvm_sample_str == tf_sample_str) ####################################################################### +# LRN (Local Response Normalization) +# ---------------------------------- + +def _test_lrn(ishape, size, axis, bias, alpha, beta): + """ testing local response normalization """ + lrn_depth_radius = size / 2 + + inp_array = np.random.uniform(size=ishape).astype(np.float32) + + with tf.Graph().as_default(): + in1 = tf.placeholder(shape=inp_array.shape, dtype=inp_array.dtype, name="lrn0_data") + nn_ops.local_response_normalization(in1, + name="lrn", + depth_radius=lrn_depth_radius, + bias=bias, + alpha=alpha, + beta=beta) + + with tf.Session() as sess: + graph_def = tf.graph_util.convert_variables_to_constants( + sess, + sess.graph.as_graph_def(add_shapes=True), + ['lrn'],) + + tf_output = run_tf_graph(sess, inp_array, 'lrn0_data:0', 'lrn:0') + tvm_output = run_tvm_graph(graph_def, + inp_array, + "lrn0_data", tf_output.shape, tf_output.dtype) + np.testing.assert_allclose(tf_output, tvm_output, atol=1e-3, rtol=1e-3) + sess.close() + +def test_forward_lrn(): + _test_lrn((1, 3, 20, 20), 3, 1, 1.0, 1.0, 0.5) + # Main # ---- if __name__ == '__main__': @@ -875,3 +909,4 @@ def _get_sample(data, state): test_forward_stridedslice() test_forward_gather() test_forward_ptb() + test_forward_lrn()