diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md index dab7cb200f5..bde20e3d597 100755 --- a/com.unity.ml-agents/CHANGELOG.md +++ b/com.unity.ml-agents/CHANGELOG.md @@ -18,7 +18,7 @@ and this project adheres to #### com.unity.ml-agents / com.unity.ml-agents.extensions (C#) - Added the capacity to initialize behaviors from any checkpoint and not just the latest one (#5525) - +- Added the ability to get a read-only view of the stacked observations (#5523) #### ml-agents / ml-agents-envs / gym-unity (Python) - Set gym version in gym-unity to gym release 0.20.0 - Added support for having `beta`, `epsilon`, and `learning rate` on separate schedules (affects only PPO and POCA). (#5538) diff --git a/com.unity.ml-agents/Runtime/Agent.cs b/com.unity.ml-agents/Runtime/Agent.cs index 4434125fb04..524378535a4 100644 --- a/com.unity.ml-agents/Runtime/Agent.cs +++ b/com.unity.ml-agents/Runtime/Agent.cs @@ -320,6 +320,11 @@ internal struct AgentParameters /// internal VectorSensor collectObservationsSensor; + /// <summary> + /// StackingSensor wrapping collectObservationsSensor. Only created when NumStackedVectorObservations is greater than 1; null otherwise. + /// </summary> + internal StackingSensor stackedCollectObservationsSensor; + private RecursionChecker m_CollectObservationsChecker = new RecursionChecker("CollectObservations"); private RecursionChecker m_OnEpisodeBeginChecker = new RecursionChecker("OnEpisodeBegin"); @@ -981,9 +986,9 @@ internal void InitializeSensors() collectObservationsSensor = new VectorSensor(param.VectorObservationSize); if (param.NumStackedVectorObservations > 1) { - var stackingSensor = new StackingSensor( + stackedCollectObservationsSensor = new StackingSensor( collectObservationsSensor, param.NumStackedVectorObservations); - sensors.Add(stackingSensor); + sensors.Add(stackedCollectObservationsSensor); } else { @@ -1179,6 +1184,17 @@ public ReadOnlyCollection GetObservations() return collectObservationsSensor.GetObservations(); } + /// <summary> + /// Returns a read-only view of the stacked observations that were generated in 
+ /// <c>CollectObservations</c>. This is mainly useful inside of a + /// method to avoid recomputing the observations. + /// </summary> + /// <returns>A read-only view of the stacked observations list, or of the unstacked observations when no stacking sensor was created.</returns> + public ReadOnlyCollection<float> GetStackedObservations() + { + return stackedCollectObservationsSensor?.GetStackedObservations() ?? collectObservationsSensor.GetObservations(); + } + /// /// Implement `WriteDiscreteActionMask()` to collects the masks for discrete /// actions. When using discrete actions, the agent will not perform the masked diff --git a/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs b/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs index 50a80c286ae..710c58a821c 100644 --- a/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs @@ -1,4 +1,6 @@ using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; using System.Linq; using UnityEngine; using Unity.Barracuda; @@ -279,5 +281,20 @@ public BuiltInSensorType GetBuiltInSensorType() IBuiltInSensor wrappedBuiltInSensor = m_WrappedSensor as IBuiltInSensor; return wrappedBuiltInSensor?.GetBuiltInSensorType() ?? BuiltInSensorType.Unknown; } + + /// <summary> + /// Returns a snapshot copy of the stacked observations as a read-only collection, with the most recent observation last. + /// </summary> + /// <returns>The stacked observations as a read-only collection.</returns> 
+ internal ReadOnlyCollection<float> GetStackedObservations() + { + List<float> observations = new List<float>(); + for (var i = 0; i < m_NumStackedObservations; i++) + { + var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations; + observations.AddRange(m_StackedObservations[obsIndex]); + } + return observations.AsReadOnly(); + } } } diff --git a/com.unity.ml-agents/Tests/Runtime/Sensor/StackingSensorTests.cs b/com.unity.ml-agents/Tests/Runtime/Sensor/StackingSensorTests.cs index 2fe44646fe4..52fadf16c8e 100644 --- a/com.unity.ml-agents/Tests/Runtime/Sensor/StackingSensorTests.cs +++ b/com.unity.ml-agents/Tests/Runtime/Sensor/StackingSensorTests.cs @@ -76,29 +76,41 @@ public void AssertStackingReset() public void TestVectorStacking() { VectorSensor wrapped = new VectorSensor(2); - ISensor sensor = new StackingSensor(wrapped, 3); + StackingSensor sensor = new StackingSensor(wrapped, 3); wrapped.AddObservation(new[] { 1f, 2f }); SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 0f, 0f, 1f, 2f }); + var data = sensor.GetStackedObservations(); + Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 0f, 0f, 0f, 0f, 1f, 2f })); sensor.Update(); wrapped.AddObservation(new[] { 3f, 4f }); SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 1f, 2f, 3f, 4f }); + data = sensor.GetStackedObservations(); + Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 0f, 0f, 1f, 2f, 3f, 4f })); sensor.Update(); wrapped.AddObservation(new[] { 5f, 6f }); SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 3f, 4f, 5f, 6f }); + data = sensor.GetStackedObservations(); + Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 1f, 2f, 3f, 4f, 5f, 6f })); sensor.Update(); wrapped.AddObservation(new[] { 7f, 8f }); SensorTestHelper.CompareObservation(sensor, new[] { 3f, 4f, 5f, 6f, 7f, 8f }); + data = sensor.GetStackedObservations(); + Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 3f, 4f, 5f, 6f, 7f, 8f })); sensor.Update(); wrapped.AddObservation(new[] { 9f, 10f }); 
SensorTestHelper.CompareObservation(sensor, new[] { 5f, 6f, 7f, 8f, 9f, 10f }); + data = sensor.GetStackedObservations(); + Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 5f, 6f, 7f, 8f, 9f, 10f })); // Check that if we don't call Update(), the same observations are produced SensorTestHelper.CompareObservation(sensor, new[] { 5f, 6f, 7f, 8f, 9f, 10f }); + data = sensor.GetStackedObservations(); + Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 5f, 6f, 7f, 8f, 9f, 10f })); } [Test]