diff --git a/com.unity.ml-agents/Runtime/Academy.cs b/com.unity.ml-agents/Runtime/Academy.cs
index 92d67b12d08..745ed3cfc80 100644
--- a/com.unity.ml-agents/Runtime/Academy.cs
+++ b/com.unity.ml-agents/Runtime/Academy.cs
@@ -97,7 +97,7 @@ public class Academy : IDisposable
///
/// -
/// 1.5.0
- /// Support variable length observation training.
+ /// Support variable length observation training and multi-agent groups.
///
///
///
diff --git a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
index 1dc768e62e1..b5d8a856329 100644
--- a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
+++ b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs
@@ -20,6 +20,11 @@ namespace Unity.MLAgents
internal static class GrpcExtensions
{
#region AgentInfo
+ /// <summary>
+ /// Static flag to make sure that we only fire the "trainer doesn't support Multi Agent Groups" warning once.
+ /// </summary>
+ private static bool s_HaveWarnedTrainerCapabilitiesAgentGroup = false;
+
///
/// Converts a AgentInfo to a protobuf generated AgentInfoActionPairProto
///
@@ -55,6 +60,22 @@ public static AgentInfoActionPairProto ToInfoActionPairProto(this AgentInfo ai)
/// The protobuf version of the AgentInfo.
public static AgentInfoProto ToAgentInfoProto(this AgentInfo ai)
{
+ // Only agents with a positive groupId can trigger the Multi Agent Groups warning.
+ if (ai.groupId > 0)
+ {
+ // A null TrainerCapabilities is treated as "can handle", so no warning is raised in that case.
+ var trainerCanHandle = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.MultiAgentGroups;
+ // Log at most once: the static flag suppresses the warning on every later call.
+ if (!trainerCanHandle && !s_HaveWarnedTrainerCapabilitiesAgentGroup)
+ {
+ Debug.LogWarning(
+ "Attached trainer doesn't support Multi Agent Groups; group rewards will be ignored. " +
+ "Please find the versions that work best together from our release page: " +
+ "https://github.com/Unity-Technologies/ml-agents/releases"
+ );
+ s_HaveWarnedTrainerCapabilitiesAgentGroup = true;
+ }
+ }
var agentInfoProto = new AgentInfoProto
{
Reward = ai.reward,
@@ -457,6 +478,7 @@ public static UnityRLCapabilities ToRLCapabilities(this UnityRLCapabilitiesProto
HybridActions = proto.HybridActions,
TrainingAnalytics = proto.TrainingAnalytics,
VariableLengthObservation = proto.VariableLengthObservation,
+ MultiAgentGroups = proto.MultiAgentGroups,
};
}
@@ -470,6 +492,7 @@ public static UnityRLCapabilitiesProto ToProto(this UnityRLCapabilities rlCaps)
HybridActions = rlCaps.HybridActions,
TrainingAnalytics = rlCaps.TrainingAnalytics,
VariableLengthObservation = rlCaps.VariableLengthObservation,
+ MultiAgentGroups = rlCaps.MultiAgentGroups,
};
}
diff --git a/com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs b/com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs
index bd495358d53..1914ecba18c 100644
--- a/com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs
+++ b/com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs
@@ -10,6 +10,7 @@ internal class UnityRLCapabilities
public bool HybridActions;
public bool TrainingAnalytics;
public bool VariableLengthObservation;
+ public bool MultiAgentGroups;
///
/// A class holding the capabilities flags for Reinforcement Learning across C# and the Trainer codebase. This
@@ -21,7 +22,8 @@ public UnityRLCapabilities(
bool compressedChannelMapping = true,
bool hybridActions = true,
bool trainingAnalytics = true,
- bool variableLengthObservation = true)
+ bool variableLengthObservation = true,
+ bool multiAgentGroups = true)
{
BaseRLCapabilities = baseRlCapabilities;
ConcatenatedPngObservations = concatenatedPngObservations;
@@ -29,6 +31,7 @@ public UnityRLCapabilities(
HybridActions = hybridActions;
TrainingAnalytics = trainingAnalytics;
VariableLengthObservation = variableLengthObservation;
+ MultiAgentGroups = multiAgentGroups;
}
///
diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs
index 2e370e596af..ac267f4c2f0 100644
--- a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs
+++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs
@@ -25,17 +25,18 @@ static CapabilitiesReflection() {
byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"CjVtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2NhcGFiaWxp",
- "dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMi0gEKGFVuaXR5UkxD",
+ "dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMi7AEKGFVuaXR5UkxD",
"YXBhYmlsaXRpZXNQcm90bxIaChJiYXNlUkxDYXBhYmlsaXRpZXMYASABKAgS",
"IwobY29uY2F0ZW5hdGVkUG5nT2JzZXJ2YXRpb25zGAIgASgIEiAKGGNvbXBy",
"ZXNzZWRDaGFubmVsTWFwcGluZxgDIAEoCBIVCg1oeWJyaWRBY3Rpb25zGAQg",
"ASgIEhkKEXRyYWluaW5nQW5hbHl0aWNzGAUgASgIEiEKGXZhcmlhYmxlTGVu",
- "Z3RoT2JzZXJ2YXRpb24YBiABKAhCJaoCIlVuaXR5Lk1MQWdlbnRzLkNvbW11",
- "bmljYXRvck9iamVjdHNiBnByb3RvMw=="));
+ "Z3RoT2JzZXJ2YXRpb24YBiABKAgSGAoQbXVsdGlBZ2VudEdyb3VwcxgHIAEo",
+ "CEIlqgIiVW5pdHkuTUxBZ2VudHMuQ29tbXVuaWNhdG9yT2JqZWN0c2IGcHJv",
+ "dG8z"));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { },
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
- new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations", "CompressedChannelMapping", "HybridActions", "TrainingAnalytics", "VariableLengthObservation" }, null, null, null)
+ new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations", "CompressedChannelMapping", "HybridActions", "TrainingAnalytics", "VariableLengthObservation", "MultiAgentGroups" }, null, null, null)
}));
}
#endregion
@@ -78,6 +79,7 @@ public UnityRLCapabilitiesProto(UnityRLCapabilitiesProto other) : this() {
hybridActions_ = other.hybridActions_;
trainingAnalytics_ = other.trainingAnalytics_;
variableLengthObservation_ = other.variableLengthObservation_;
+ multiAgentGroups_ = other.multiAgentGroups_;
_unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields);
}
@@ -170,6 +172,20 @@ public bool VariableLengthObservation {
}
}
+ /// <summary>Field number for the "multiAgentGroups" field.</summary>
+ public const int MultiAgentGroupsFieldNumber = 7;
+ private bool multiAgentGroups_;
+ /// <summary>
+ /// Support for multi agent groups and group rewards
+ /// </summary>
+ [global::System.Diagnostics.DebuggerNonUserCodeAttribute]
+ public bool MultiAgentGroups {
+ get { return multiAgentGroups_; }
+ set {
+ multiAgentGroups_ = value;
+ }
+ }
+
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as UnityRLCapabilitiesProto);
@@ -189,6 +205,7 @@ public bool Equals(UnityRLCapabilitiesProto other) {
if (HybridActions != other.HybridActions) return false;
if (TrainingAnalytics != other.TrainingAnalytics) return false;
if (VariableLengthObservation != other.VariableLengthObservation) return false;
+ if (MultiAgentGroups != other.MultiAgentGroups) return false;
return Equals(_unknownFields, other._unknownFields);
}
@@ -201,6 +218,7 @@ public override int GetHashCode() {
if (HybridActions != false) hash ^= HybridActions.GetHashCode();
if (TrainingAnalytics != false) hash ^= TrainingAnalytics.GetHashCode();
if (VariableLengthObservation != false) hash ^= VariableLengthObservation.GetHashCode();
+ if (MultiAgentGroups != false) hash ^= MultiAgentGroups.GetHashCode();
if (_unknownFields != null) {
hash ^= _unknownFields.GetHashCode();
}
@@ -238,6 +256,10 @@ public void WriteTo(pb::CodedOutputStream output) {
output.WriteRawTag(48);
output.WriteBool(VariableLengthObservation);
}
+ if (MultiAgentGroups != false) {
+ output.WriteRawTag(56);
+ output.WriteBool(MultiAgentGroups);
+ }
if (_unknownFields != null) {
_unknownFields.WriteTo(output);
}
@@ -264,6 +286,9 @@ public int CalculateSize() {
if (VariableLengthObservation != false) {
size += 1 + 1;
}
+ if (MultiAgentGroups != false) {
+ size += 1 + 1;
+ }
if (_unknownFields != null) {
size += _unknownFields.CalculateSize();
}
@@ -293,6 +318,9 @@ public void MergeFrom(UnityRLCapabilitiesProto other) {
if (other.VariableLengthObservation != false) {
VariableLengthObservation = other.VariableLengthObservation;
}
+ if (other.MultiAgentGroups != false) {
+ MultiAgentGroups = other.MultiAgentGroups;
+ }
_unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields);
}
@@ -328,6 +356,10 @@ public void MergeFrom(pb::CodedInputStream input) {
VariableLengthObservation = input.ReadBool();
break;
}
+ case 56: {
+ MultiAgentGroups = input.ReadBool();
+ break;
+ }
}
}
}
diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py b/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py
index 8f765fa5f8f..35b8fbdef15 100644
--- a/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py
+++ b/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py
@@ -19,7 +19,7 @@
name='mlagents_envs/communicator_objects/capabilities.proto',
package='communicator_objects',
syntax='proto3',
- serialized_pb=_b('\n5mlagents_envs/communicator_objects/capabilities.proto\x12\x14\x63ommunicator_objects\"\xd2\x01\n\x18UnityRLCapabilitiesProto\x12\x1a\n\x12\x62\x61seRLCapabilities\x18\x01 \x01(\x08\x12#\n\x1b\x63oncatenatedPngObservations\x18\x02 \x01(\x08\x12 \n\x18\x63ompressedChannelMapping\x18\x03 \x01(\x08\x12\x15\n\rhybridActions\x18\x04 \x01(\x08\x12\x19\n\x11trainingAnalytics\x18\x05 \x01(\x08\x12!\n\x19variableLengthObservation\x18\x06 \x01(\x08\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
+ serialized_pb=_b('\n5mlagents_envs/communicator_objects/capabilities.proto\x12\x14\x63ommunicator_objects\"\xec\x01\n\x18UnityRLCapabilitiesProto\x12\x1a\n\x12\x62\x61seRLCapabilities\x18\x01 \x01(\x08\x12#\n\x1b\x63oncatenatedPngObservations\x18\x02 \x01(\x08\x12 \n\x18\x63ompressedChannelMapping\x18\x03 \x01(\x08\x12\x15\n\rhybridActions\x18\x04 \x01(\x08\x12\x19\n\x11trainingAnalytics\x18\x05 \x01(\x08\x12!\n\x19variableLengthObservation\x18\x06 \x01(\x08\x12\x18\n\x10multiAgentGroups\x18\x07 \x01(\x08\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
)
@@ -74,6 +74,13 @@
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='multiAgentGroups', full_name='communicator_objects.UnityRLCapabilitiesProto.multiAgentGroups', index=6,
+ number=7, type=8, cpp_type=7, label=1,
+ has_default_value=False, default_value=False,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ options=None, file=DESCRIPTOR),
],
extensions=[
],
@@ -87,7 +94,7 @@
oneofs=[
],
serialized_start=80,
- serialized_end=290,
+ serialized_end=316,
)
DESCRIPTOR.message_types_by_name['UnityRLCapabilitiesProto'] = _UNITYRLCAPABILITIESPROTO
diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi b/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi
index ae49d4d4386..1c6a1f7030b 100644
--- a/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi
+++ b/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi
@@ -31,6 +31,7 @@ class UnityRLCapabilitiesProto(google___protobuf___message___Message):
hybridActions = ... # type: builtin___bool
trainingAnalytics = ... # type: builtin___bool
variableLengthObservation = ... # type: builtin___bool
+ multiAgentGroups = ... # type: builtin___bool
def __init__(self,
*,
@@ -40,12 +41,13 @@ class UnityRLCapabilitiesProto(google___protobuf___message___Message):
hybridActions : typing___Optional[builtin___bool] = None,
trainingAnalytics : typing___Optional[builtin___bool] = None,
variableLengthObservation : typing___Optional[builtin___bool] = None,
+ multiAgentGroups : typing___Optional[builtin___bool] = None,
) -> None: ...
@classmethod
def FromString(cls, s: builtin___bytes) -> UnityRLCapabilitiesProto: ...
def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ...
if sys.version_info >= (3,):
- def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",u"compressedChannelMapping",u"concatenatedPngObservations",u"hybridActions",u"trainingAnalytics",u"variableLengthObservation"]) -> None: ...
+ def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",u"compressedChannelMapping",u"concatenatedPngObservations",u"hybridActions",u"multiAgentGroups",u"trainingAnalytics",u"variableLengthObservation"]) -> None: ...
else:
- def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",b"baseRLCapabilities",u"compressedChannelMapping",b"compressedChannelMapping",u"concatenatedPngObservations",b"concatenatedPngObservations",u"hybridActions",b"hybridActions",u"trainingAnalytics",b"trainingAnalytics",u"variableLengthObservation",b"variableLengthObservation"]) -> None: ...
+ def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",b"baseRLCapabilities",u"compressedChannelMapping",b"compressedChannelMapping",u"concatenatedPngObservations",b"concatenatedPngObservations",u"hybridActions",b"hybridActions",u"multiAgentGroups",b"multiAgentGroups",u"trainingAnalytics",b"trainingAnalytics",u"variableLengthObservation",b"variableLengthObservation"]) -> None: ...
diff --git a/ml-agents-envs/mlagents_envs/environment.py b/ml-agents-envs/mlagents_envs/environment.py
index 1411d9180f4..633f298bb45 100644
--- a/ml-agents-envs/mlagents_envs/environment.py
+++ b/ml-agents-envs/mlagents_envs/environment.py
@@ -63,7 +63,7 @@ class UnityEnvironment(BaseEnv):
# * 1.2.0 - support compression mapping for stacked compressed observations.
# * 1.3.0 - support action spaces with both continuous and discrete actions.
# * 1.4.0 - support training analytics sent from python trainer to the editor.
- # * 1.5.0 - support variable length observation training.
+ # * 1.5.0 - support variable length observation training and multi-agent groups.
API_VERSION = "1.5.0"
# Default port that the editor listens on. If an environment executable
@@ -124,6 +124,7 @@ def _get_capabilities_proto() -> UnityRLCapabilitiesProto:
capabilities.hybridActions = True
capabilities.trainingAnalytics = True
capabilities.variableLengthObservation = True
+ capabilities.multiAgentGroups = True
return capabilities
@staticmethod
diff --git a/protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto b/protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto
index 7459e0cae3d..4a7690cd7f4 100644
--- a/protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto
+++ b/protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto
@@ -25,4 +25,7 @@ message UnityRLCapabilitiesProto {
// Support for variable length observations of rank 2
bool variableLengthObservation = 6;
+
+ // Support for multi agent groups and group rewards
+ bool multiAgentGroups = 7;
}