diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md
index dab7cb200f5..bde20e3d597 100755
--- a/com.unity.ml-agents/CHANGELOG.md
+++ b/com.unity.ml-agents/CHANGELOG.md
@@ -18,7 +18,7 @@ and this project adheres to
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
- Added the capacity to initialize behaviors from any checkpoint and not just the latest one (#5525)
-
+- Added the ability to get a read-only view of the stacked observations (#5523)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- Set gym version in gym-unity to gym release 0.20.0
- Added support for having `beta`, `epsilon`, and `learning rate` on separate schedules (affects only PPO and POCA). (#5538)
diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs
index 4434125fb04..524378535a4 100644
--- a/com.unity.ml-agents/Runtime/Agent.cs
+++ b/com.unity.ml-agents/Runtime/Agent.cs
@@ -320,6 +320,11 @@ internal struct AgentParameters
///
internal VectorSensor collectObservationsSensor;
+ ///
+ /// StackingSensor which is written to by AddVectorObs
+ ///
+ internal StackingSensor stackedCollectObservationsSensor;
+
private RecursionChecker m_CollectObservationsChecker = new RecursionChecker("CollectObservations");
private RecursionChecker m_OnEpisodeBeginChecker = new RecursionChecker("OnEpisodeBegin");
@@ -981,9 +986,9 @@ internal void InitializeSensors()
collectObservationsSensor = new VectorSensor(param.VectorObservationSize);
if (param.NumStackedVectorObservations > 1)
{
- var stackingSensor = new StackingSensor(
+ stackedCollectObservationsSensor = new StackingSensor(
collectObservationsSensor, param.NumStackedVectorObservations);
- sensors.Add(stackingSensor);
+ sensors.Add(stackedCollectObservationsSensor);
}
else
{
@@ -1179,6 +1184,17 @@ public ReadOnlyCollection GetObservations()
return collectObservationsSensor.GetObservations();
}
+ ///
+ /// Returns a read-only view of the stacked observations that were generated in
+ /// . This is mainly useful inside of a
+ /// method to avoid recomputing the observations.
+ ///
+ /// A read-only view of the stacked observations list.
+ public ReadOnlyCollection GetStackedObservations()
+ {
+ return stackedCollectObservationsSensor.GetStackedObservations();
+ }
+
///
/// Implement `WriteDiscreteActionMask()` to collects the masks for discrete
/// actions. When using discrete actions, the agent will not perform the masked
diff --git a/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs b/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs
index 50a80c286ae..710c58a821c 100644
--- a/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs
+++ b/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs
@@ -1,4 +1,6 @@
using System;
+using System.Collections.Generic;
+using System.Collections.ObjectModel;
using System.Linq;
using UnityEngine;
using Unity.Barracuda;
@@ -279,5 +281,20 @@ public BuiltInSensorType GetBuiltInSensorType()
IBuiltInSensor wrappedBuiltInSensor = m_WrappedSensor as IBuiltInSensor;
return wrappedBuiltInSensor?.GetBuiltInSensorType() ?? BuiltInSensorType.Unknown;
}
+
+ ///
+ /// Returns the stacked observations as a read-only collection.
+ ///
+ /// The stacked observations as a read-only collection.
+ internal ReadOnlyCollection GetStackedObservations()
+ {
+ List observations = new List();
+ for (var i = 0; i < m_NumStackedObservations; i++)
+ {
+ var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations;
+ observations.AddRange(m_StackedObservations[obsIndex].ToList());
+ }
+ return observations.AsReadOnly();
+ }
}
}
diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/StackingSensorTests.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/StackingSensorTests.cs
index 2fe44646fe4..52fadf16c8e 100644
--- a/com.unity.ml-agents/Tests/Runtime/Sensor/StackingSensorTests.cs
+++ b/com.unity.ml-agents/Tests/Runtime/Sensor/StackingSensorTests.cs
@@ -76,29 +76,41 @@ public void AssertStackingReset()
public void TestVectorStacking()
{
VectorSensor wrapped = new VectorSensor(2);
- ISensor sensor = new StackingSensor(wrapped, 3);
+ StackingSensor sensor = new StackingSensor(wrapped, 3);
wrapped.AddObservation(new[] { 1f, 2f });
SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 0f, 0f, 1f, 2f });
+ var data = sensor.GetStackedObservations();
+ Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 0f, 0f, 0f, 0f, 1f, 2f }));
sensor.Update();
wrapped.AddObservation(new[] { 3f, 4f });
SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 1f, 2f, 3f, 4f });
+ data = sensor.GetStackedObservations();
+ Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 0f, 0f, 1f, 2f, 3f, 4f }));
sensor.Update();
wrapped.AddObservation(new[] { 5f, 6f });
SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 3f, 4f, 5f, 6f });
+ data = sensor.GetStackedObservations();
+ Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 1f, 2f, 3f, 4f, 5f, 6f }));
sensor.Update();
wrapped.AddObservation(new[] { 7f, 8f });
SensorTestHelper.CompareObservation(sensor, new[] { 3f, 4f, 5f, 6f, 7f, 8f });
+ data = sensor.GetStackedObservations();
+ Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 3f, 4f, 5f, 6f, 7f, 8f }));
sensor.Update();
wrapped.AddObservation(new[] { 9f, 10f });
SensorTestHelper.CompareObservation(sensor, new[] { 5f, 6f, 7f, 8f, 9f, 10f });
+ data = sensor.GetStackedObservations();
+ Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 5f, 6f, 7f, 8f, 9f, 10f }));
// Check that if we don't call Update(), the same observations are produced
SensorTestHelper.CompareObservation(sensor, new[] { 5f, 6f, 7f, 8f, 9f, 10f });
+ data = sensor.GetStackedObservations();
+ Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 5f, 6f, 7f, 8f, 9f, 10f }));
}
[Test]