diff --git a/com.unity.ml-agents/Runtime/Academy.cs b/com.unity.ml-agents/Runtime/Academy.cs index 92d67b12d08..745ed3cfc80 100644 --- a/com.unity.ml-agents/Runtime/Academy.cs +++ b/com.unity.ml-agents/Runtime/Academy.cs @@ -97,7 +97,7 @@ public class Academy : IDisposable /// /// /// 1.5.0 - /// Support variable length observation training. + /// Support variable length observation training and multi-agent groups. /// /// /// diff --git a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs index 1dc768e62e1..b5d8a856329 100644 --- a/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs +++ b/com.unity.ml-agents/Runtime/Communicator/GrpcExtensions.cs @@ -20,6 +20,11 @@ namespace Unity.MLAgents internal static class GrpcExtensions { #region AgentInfo + /// + /// Static flag to make sure that we only fire the warning once. + /// + private static bool s_HaveWarnedTrainerCapabilitiesAgentGroup = false; + /// /// Converts a AgentInfo to a protobuf generated AgentInfoActionPairProto /// @@ -55,6 +60,22 @@ public static AgentInfoActionPairProto ToInfoActionPairProto(this AgentInfo ai) /// The protobuf version of the AgentInfo. public static AgentInfoProto ToAgentInfoProto(this AgentInfo ai) { + if(ai.groupId > 0) + { + var trainerCanHandle = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.MultiAgentGroups; + if (!trainerCanHandle) + { + if (!s_HaveWarnedTrainerCapabilitiesAgentGroup) + { + Debug.LogWarning( + $"Attached trainer doesn't support Multi Agent Groups; group rewards will be ignored." 
+ + " Please find the versions that work best together from our release page: " + + "https://github.com/Unity-Technologies/ml-agents/releases" + ); + s_HaveWarnedTrainerCapabilitiesAgentGroup = true; + } + } + } var agentInfoProto = new AgentInfoProto { Reward = ai.reward, @@ -457,6 +478,7 @@ public static UnityRLCapabilities ToRLCapabilities(this UnityRLCapabilitiesProto HybridActions = proto.HybridActions, TrainingAnalytics = proto.TrainingAnalytics, VariableLengthObservation = proto.VariableLengthObservation, + MultiAgentGroups = proto.MultiAgentGroups, }; } @@ -470,6 +492,7 @@ public static UnityRLCapabilitiesProto ToProto(this UnityRLCapabilities rlCaps) HybridActions = rlCaps.HybridActions, TrainingAnalytics = rlCaps.TrainingAnalytics, VariableLengthObservation = rlCaps.VariableLengthObservation, + MultiAgentGroups = rlCaps.MultiAgentGroups, }; } diff --git a/com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs b/com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs index bd495358d53..1914ecba18c 100644 --- a/com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs +++ b/com.unity.ml-agents/Runtime/Communicator/UnityRLCapabilities.cs @@ -10,6 +10,7 @@ internal class UnityRLCapabilities public bool HybridActions; public bool TrainingAnalytics; public bool VariableLengthObservation; + public bool MultiAgentGroups; /// /// A class holding the capabilities flags for Reinforcement Learning across C# and the Trainer codebase. 
This @@ -21,7 +22,8 @@ public UnityRLCapabilities( bool compressedChannelMapping = true, bool hybridActions = true, bool trainingAnalytics = true, - bool variableLengthObservation = true) + bool variableLengthObservation = true, + bool multiAgentGroups = true) { BaseRLCapabilities = baseRlCapabilities; ConcatenatedPngObservations = concatenatedPngObservations; @@ -29,6 +31,7 @@ public UnityRLCapabilities( HybridActions = hybridActions; TrainingAnalytics = trainingAnalytics; VariableLengthObservation = variableLengthObservation; + MultiAgentGroups = multiAgentGroups; } /// diff --git a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs index 2e370e596af..ac267f4c2f0 100644 --- a/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs +++ b/com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Capabilities.cs @@ -25,17 +25,18 @@ static CapabilitiesReflection() { byte[] descriptorData = global::System.Convert.FromBase64String( string.Concat( "CjVtbGFnZW50c19lbnZzL2NvbW11bmljYXRvcl9vYmplY3RzL2NhcGFiaWxp", - "dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMi0gEKGFVuaXR5UkxD", + "dGllcy5wcm90bxIUY29tbXVuaWNhdG9yX29iamVjdHMi7AEKGFVuaXR5UkxD", "YXBhYmlsaXRpZXNQcm90bxIaChJiYXNlUkxDYXBhYmlsaXRpZXMYASABKAgS", "IwobY29uY2F0ZW5hdGVkUG5nT2JzZXJ2YXRpb25zGAIgASgIEiAKGGNvbXBy", "ZXNzZWRDaGFubmVsTWFwcGluZxgDIAEoCBIVCg1oeWJyaWRBY3Rpb25zGAQg", "ASgIEhkKEXRyYWluaW5nQW5hbHl0aWNzGAUgASgIEiEKGXZhcmlhYmxlTGVu", - "Z3RoT2JzZXJ2YXRpb24YBiABKAhCJaoCIlVuaXR5Lk1MQWdlbnRzLkNvbW11", - "bmljYXRvck9iamVjdHNiBnByb3RvMw==")); + "Z3RoT2JzZXJ2YXRpb24YBiABKAgSGAoQbXVsdGlBZ2VudEdyb3VwcxgHIAEo", + "CEIlqgIiVW5pdHkuTUxBZ2VudHMuQ29tbXVuaWNhdG9yT2JqZWN0c2IGcHJv", + "dG8z")); descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, new pbr::FileDescriptor[] { }, new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { - new 
pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations", "CompressedChannelMapping", "HybridActions", "TrainingAnalytics", "VariableLengthObservation" }, null, null, null) + new pbr::GeneratedClrTypeInfo(typeof(global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto), global::Unity.MLAgents.CommunicatorObjects.UnityRLCapabilitiesProto.Parser, new[]{ "BaseRLCapabilities", "ConcatenatedPngObservations", "CompressedChannelMapping", "HybridActions", "TrainingAnalytics", "VariableLengthObservation", "MultiAgentGroups" }, null, null, null) })); } #endregion @@ -78,6 +79,7 @@ public UnityRLCapabilitiesProto(UnityRLCapabilitiesProto other) : this() { hybridActions_ = other.hybridActions_; trainingAnalytics_ = other.trainingAnalytics_; variableLengthObservation_ = other.variableLengthObservation_; + multiAgentGroups_ = other.multiAgentGroups_; _unknownFields = pb::UnknownFieldSet.Clone(other._unknownFields); } @@ -170,6 +172,20 @@ public bool VariableLengthObservation { } } + /// Field number for the "multiAgentGroups" field. 
+ public const int MultiAgentGroupsFieldNumber = 7; + private bool multiAgentGroups_; + /// + /// Support for multi agent groups and group rewards + /// + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool MultiAgentGroups { + get { return multiAgentGroups_; } + set { + multiAgentGroups_ = value; + } + } + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public override bool Equals(object other) { return Equals(other as UnityRLCapabilitiesProto); @@ -189,6 +205,7 @@ public bool Equals(UnityRLCapabilitiesProto other) { if (HybridActions != other.HybridActions) return false; if (TrainingAnalytics != other.TrainingAnalytics) return false; if (VariableLengthObservation != other.VariableLengthObservation) return false; + if (MultiAgentGroups != other.MultiAgentGroups) return false; return Equals(_unknownFields, other._unknownFields); } @@ -201,6 +218,7 @@ public override int GetHashCode() { if (HybridActions != false) hash ^= HybridActions.GetHashCode(); if (TrainingAnalytics != false) hash ^= TrainingAnalytics.GetHashCode(); if (VariableLengthObservation != false) hash ^= VariableLengthObservation.GetHashCode(); + if (MultiAgentGroups != false) hash ^= MultiAgentGroups.GetHashCode(); if (_unknownFields != null) { hash ^= _unknownFields.GetHashCode(); } @@ -238,6 +256,10 @@ public void WriteTo(pb::CodedOutputStream output) { output.WriteRawTag(48); output.WriteBool(VariableLengthObservation); } + if (MultiAgentGroups != false) { + output.WriteRawTag(56); + output.WriteBool(MultiAgentGroups); + } if (_unknownFields != null) { _unknownFields.WriteTo(output); } @@ -264,6 +286,9 @@ public int CalculateSize() { if (VariableLengthObservation != false) { size += 1 + 1; } + if (MultiAgentGroups != false) { + size += 1 + 1; + } if (_unknownFields != null) { size += _unknownFields.CalculateSize(); } @@ -293,6 +318,9 @@ public void MergeFrom(UnityRLCapabilitiesProto other) { if (other.VariableLengthObservation != false) { VariableLengthObservation = 
other.VariableLengthObservation; } + if (other.MultiAgentGroups != false) { + MultiAgentGroups = other.MultiAgentGroups; + } _unknownFields = pb::UnknownFieldSet.MergeFrom(_unknownFields, other._unknownFields); } @@ -328,6 +356,10 @@ public void MergeFrom(pb::CodedInputStream input) { VariableLengthObservation = input.ReadBool(); break; } + case 56: { + MultiAgentGroups = input.ReadBool(); + break; + } } } } diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py b/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py index 8f765fa5f8f..35b8fbdef15 100644 --- a/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py +++ b/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.py @@ -19,7 +19,7 @@ name='mlagents_envs/communicator_objects/capabilities.proto', package='communicator_objects', syntax='proto3', - serialized_pb=_b('\n5mlagents_envs/communicator_objects/capabilities.proto\x12\x14\x63ommunicator_objects\"\xd2\x01\n\x18UnityRLCapabilitiesProto\x12\x1a\n\x12\x62\x61seRLCapabilities\x18\x01 \x01(\x08\x12#\n\x1b\x63oncatenatedPngObservations\x18\x02 \x01(\x08\x12 \n\x18\x63ompressedChannelMapping\x18\x03 \x01(\x08\x12\x15\n\rhybridActions\x18\x04 \x01(\x08\x12\x19\n\x11trainingAnalytics\x18\x05 \x01(\x08\x12!\n\x19variableLengthObservation\x18\x06 \x01(\x08\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3') + serialized_pb=_b('\n5mlagents_envs/communicator_objects/capabilities.proto\x12\x14\x63ommunicator_objects\"\xec\x01\n\x18UnityRLCapabilitiesProto\x12\x1a\n\x12\x62\x61seRLCapabilities\x18\x01 \x01(\x08\x12#\n\x1b\x63oncatenatedPngObservations\x18\x02 \x01(\x08\x12 \n\x18\x63ompressedChannelMapping\x18\x03 \x01(\x08\x12\x15\n\rhybridActions\x18\x04 \x01(\x08\x12\x19\n\x11trainingAnalytics\x18\x05 \x01(\x08\x12!\n\x19variableLengthObservation\x18\x06 \x01(\x08\x12\x18\n\x10multiAgentGroups\x18\x07 \x01(\x08\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3') ) @@ -74,6 
+74,13 @@ message_type=None, enum_type=None, containing_type=None, is_extension=False, extension_scope=None, options=None, file=DESCRIPTOR), + _descriptor.FieldDescriptor( + name='multiAgentGroups', full_name='communicator_objects.UnityRLCapabilitiesProto.multiAgentGroups', index=6, + number=7, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None, file=DESCRIPTOR), ], extensions=[ ], @@ -87,7 +94,7 @@ oneofs=[ ], serialized_start=80, - serialized_end=290, + serialized_end=316, ) DESCRIPTOR.message_types_by_name['UnityRLCapabilitiesProto'] = _UNITYRLCAPABILITIESPROTO diff --git a/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi b/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi index ae49d4d4386..1c6a1f7030b 100644 --- a/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi +++ b/ml-agents-envs/mlagents_envs/communicator_objects/capabilities_pb2.pyi @@ -31,6 +31,7 @@ class UnityRLCapabilitiesProto(google___protobuf___message___Message): hybridActions = ... # type: builtin___bool trainingAnalytics = ... # type: builtin___bool variableLengthObservation = ... # type: builtin___bool + multiAgentGroups = ... # type: builtin___bool def __init__(self, *, @@ -40,12 +41,13 @@ class UnityRLCapabilitiesProto(google___protobuf___message___Message): hybridActions : typing___Optional[builtin___bool] = None, trainingAnalytics : typing___Optional[builtin___bool] = None, variableLengthObservation : typing___Optional[builtin___bool] = None, + multiAgentGroups : typing___Optional[builtin___bool] = None, ) -> None: ... @classmethod def FromString(cls, s: builtin___bytes) -> UnityRLCapabilitiesProto: ... def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... 
if sys.version_info >= (3,): - def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",u"compressedChannelMapping",u"concatenatedPngObservations",u"hybridActions",u"trainingAnalytics",u"variableLengthObservation"]) -> None: ... + def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",u"compressedChannelMapping",u"concatenatedPngObservations",u"hybridActions",u"multiAgentGroups",u"trainingAnalytics",u"variableLengthObservation"]) -> None: ... else: - def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",b"baseRLCapabilities",u"compressedChannelMapping",b"compressedChannelMapping",u"concatenatedPngObservations",b"concatenatedPngObservations",u"hybridActions",b"hybridActions",u"trainingAnalytics",b"trainingAnalytics",u"variableLengthObservation",b"variableLengthObservation"]) -> None: ... + def ClearField(self, field_name: typing_extensions___Literal[u"baseRLCapabilities",b"baseRLCapabilities",u"compressedChannelMapping",b"compressedChannelMapping",u"concatenatedPngObservations",b"concatenatedPngObservations",u"hybridActions",b"hybridActions",u"multiAgentGroups",b"multiAgentGroups",u"trainingAnalytics",b"trainingAnalytics",u"variableLengthObservation",b"variableLengthObservation"]) -> None: ... diff --git a/ml-agents-envs/mlagents_envs/environment.py b/ml-agents-envs/mlagents_envs/environment.py index 1411d9180f4..633f298bb45 100644 --- a/ml-agents-envs/mlagents_envs/environment.py +++ b/ml-agents-envs/mlagents_envs/environment.py @@ -63,7 +63,7 @@ class UnityEnvironment(BaseEnv): # * 1.2.0 - support compression mapping for stacked compressed observations. # * 1.3.0 - support action spaces with both continuous and discrete actions. # * 1.4.0 - support training analytics sent from python trainer to the editor. - # * 1.5.0 - support variable length observation training. + # * 1.5.0 - support variable length observation training and multi-agent groups. 
API_VERSION = "1.5.0" # Default port that the editor listens on. If an environment executable @@ -124,6 +124,7 @@ def _get_capabilities_proto() -> UnityRLCapabilitiesProto: capabilities.hybridActions = True capabilities.trainingAnalytics = True capabilities.variableLengthObservation = True + capabilities.multiAgentGroups = True return capabilities @staticmethod diff --git a/protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto b/protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto index 7459e0cae3d..4a7690cd7f4 100644 --- a/protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto +++ b/protobuf-definitions/proto/mlagents_envs/communicator_objects/capabilities.proto @@ -25,4 +25,7 @@ message UnityRLCapabilitiesProto { // Support for variable length observations of rank 2 bool variableLengthObservation = 6; + + // Support for multi agent groups and group rewards + bool multiAgentGroups = 7; }