@@ -438,6 +438,7 @@ template <typename TKey, typename TValue>
438438class KvResourceGatherOp : public OpKernel {
439439 public:
440440 explicit KvResourceGatherOp (OpKernelConstruction* c) : OpKernel(c) {
441+ OP_REQUIRES_OK (c, c->GetAttr (" is_inference" , &is_inference_));
441442 OP_REQUIRES_OK (c,
442443 c->GetAttr (" is_use_default_value_tensor" ,
443444 &is_use_default_value_tensor_));
@@ -461,6 +462,17 @@ class KvResourceGatherOp : public OpKernel {
461462 return 1 ;
462463 };
463464 }
465+ if (!is_inference_) {
466+ lookup_fn_ = [](EmbeddingVar<TKey, TValue>* ev, TKey key,
467+ TValue* val, TValue* default_v, int count) {
468+ ev->LookupOrCreate (key, val, default_v, count);
469+ };
470+ } else {
471+ lookup_fn_ = [](EmbeddingVar<TKey, TValue>* ev, TKey key,
472+ TValue* val, TValue* default_v, int count) {
473+ ev->Lookup (key, val, default_v);
474+ };
475+ }
464476 }
465477
466478 void Compute (OpKernelContext* c) override {
@@ -511,7 +523,7 @@ class KvResourceGatherOp : public OpKernel {
511523 default_v, indices_flat (i), i, ev->GetDefaultValueDim (),
512524 ev->ValueLen ());
513525 int32 count = get_count_fn_ (counts, i);
514- ev-> LookupOrCreate ( indices_flat (i),
526+ lookup_fn_ (ev, indices_flat (i),
515527 out_base + i * slice_elems, default_v_ptr, count);
516528 }
517529 };
@@ -530,9 +542,12 @@ class KvResourceGatherOp : public OpKernel {
530542
531543 private:
532544 bool is_use_default_value_tensor_;
545+ bool is_inference_;
533546 std::function<
534547 TValue*(TValue*, TKey, int64, int64, int64)> get_default_v_fn_;
535548 std::function<int32(int32*, int64)> get_count_fn_;
549+ std::function<void (EmbeddingVar<TKey, TValue>* ev,
550+ TKey key, TValue* val, TValue* default_v, int count)> lookup_fn_;
536551};
537552
538553#define REGISTER_GATHER_FULL (dev, ktype, vtype ) \
0 commit comments