Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ and this project adheres to

#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
- Added the capacity to initialize behaviors from any checkpoint and not just the latest one (#5525)

- Added the ability to get a read-only view of the stacked observations (#5523)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- Set gym version in gym-unity to gym release 0.20.0
- Added support for having `beta`, `epsilon`, and `learning rate` on separate schedules (affects only PPO and POCA). (#5538)
Expand Down
20 changes: 18 additions & 2 deletions com.unity.ml-agents/Runtime/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,11 @@ internal struct AgentParameters
/// </summary>
internal VectorSensor collectObservationsSensor;

/// <summary>
/// StackingSensor which is written to by AddVectorObs
/// </summary>
internal StackingSensor stackedCollectObservationsSensor;

private RecursionChecker m_CollectObservationsChecker = new RecursionChecker("CollectObservations");
private RecursionChecker m_OnEpisodeBeginChecker = new RecursionChecker("OnEpisodeBegin");

Expand Down Expand Up @@ -981,9 +986,9 @@ internal void InitializeSensors()
collectObservationsSensor = new VectorSensor(param.VectorObservationSize);
if (param.NumStackedVectorObservations > 1)
{
var stackingSensor = new StackingSensor(
stackedCollectObservationsSensor = new StackingSensor(
collectObservationsSensor, param.NumStackedVectorObservations);
sensors.Add(stackingSensor);
sensors.Add(stackedCollectObservationsSensor);
}
else
{
Expand Down Expand Up @@ -1179,6 +1184,17 @@ public ReadOnlyCollection<float> GetObservations()
return collectObservationsSensor.GetObservations();
}

/// <summary>
/// Returns a read-only view of the stacked observations that were generated in
/// <see cref="CollectObservations(VectorSensor)"/>. This is mainly useful inside of a
/// <see cref="Heuristic(in ActionBuffers)"/> method to avoid recomputing the observations.
/// </summary>
/// <returns>A read-only view of the stacked observations list.</returns>
public ReadOnlyCollection<float> GetStackedObservations()
{
return stackedCollectObservationsSensor.GetStackedObservations();
}

/// <summary>
/// Implement `WriteDiscreteActionMask()` to collects the masks for discrete
/// actions. When using discrete actions, the agent will not perform the masked
Expand Down
17 changes: 17 additions & 0 deletions com.unity.ml-agents/Runtime/Sensors/StackingSensor.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Linq;
using UnityEngine;
using Unity.Barracuda;
Expand Down Expand Up @@ -279,5 +281,20 @@ public BuiltInSensorType GetBuiltInSensorType()
IBuiltInSensor wrappedBuiltInSensor = m_WrappedSensor as IBuiltInSensor;
return wrappedBuiltInSensor?.GetBuiltInSensorType() ?? BuiltInSensorType.Unknown;
}

/// <summary>
/// Returns the stacked observations as a read-only collection.
/// </summary>
/// <returns>The stacked observations as a read-only collection.</returns>
internal ReadOnlyCollection<float> GetStackedObservations()
{
List<float> observations = new List<float>();
for (var i = 0; i < m_NumStackedObservations; i++)
{
var obsIndex = (m_CurrentIndex + 1 + i) % m_NumStackedObservations;
observations.AddRange(m_StackedObservations[obsIndex].ToList());
}
return observations.AsReadOnly();
}
}
}
14 changes: 13 additions & 1 deletion com.unity.ml-agents/Tests/Runtime/Sensor/StackingSensorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,29 +76,41 @@ public void AssertStackingReset()
public void TestVectorStacking()
{
VectorSensor wrapped = new VectorSensor(2);
ISensor sensor = new StackingSensor(wrapped, 3);
StackingSensor sensor = new StackingSensor(wrapped, 3);

wrapped.AddObservation(new[] { 1f, 2f });
SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 0f, 0f, 1f, 2f });
var data = sensor.GetStackedObservations();
Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 0f, 0f, 0f, 0f, 1f, 2f }));

sensor.Update();
wrapped.AddObservation(new[] { 3f, 4f });
SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 1f, 2f, 3f, 4f });
data = sensor.GetStackedObservations();
Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 0f, 0f, 1f, 2f, 3f, 4f }));

sensor.Update();
wrapped.AddObservation(new[] { 5f, 6f });
SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 3f, 4f, 5f, 6f });
data = sensor.GetStackedObservations();
Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 1f, 2f, 3f, 4f, 5f, 6f }));

sensor.Update();
wrapped.AddObservation(new[] { 7f, 8f });
SensorTestHelper.CompareObservation(sensor, new[] { 3f, 4f, 5f, 6f, 7f, 8f });
data = sensor.GetStackedObservations();
Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 3f, 4f, 5f, 6f, 7f, 8f }));

sensor.Update();
wrapped.AddObservation(new[] { 9f, 10f });
SensorTestHelper.CompareObservation(sensor, new[] { 5f, 6f, 7f, 8f, 9f, 10f });
data = sensor.GetStackedObservations();
Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 5f, 6f, 7f, 8f, 9f, 10f }));

// Check that if we don't call Update(), the same observations are produced
SensorTestHelper.CompareObservation(sensor, new[] { 5f, 6f, 7f, 8f, 9f, 10f });
data = sensor.GetStackedObservations();
Assert.IsTrue(data.ToArray().SequenceEqual(new[] { 5f, 6f, 7f, 8f, 9f, 10f }));
}

[Test]
Expand Down