diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md
index 06e19ac3c30..895564884a6 100755
--- a/com.unity.ml-agents/CHANGELOG.md
+++ b/com.unity.ml-agents/CHANGELOG.md
@@ -11,6 +11,7 @@ and this project adheres to
### Minor Changes
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
+- Added a fully connected visual encoder for environments with very small image inputs. (#5351)
### Bug Fixes
diff --git a/docs/Training-Configuration-File.md b/docs/Training-Configuration-File.md
index 114f66764d0..73aa23956ab 100644
--- a/docs/Training-Configuration-File.md
+++ b/docs/Training-Configuration-File.md
@@ -42,7 +42,7 @@ choice of the trainer (which we review on subsequent sections).
| `network_settings -> hidden_units` | (default = `128`) Number of units in the hidden layers of the neural network. Correspond to how many units are in each fully connected layer of the neural network. For simple problems where the correct action is a straightforward combination of the observation inputs, this should be small. For problems where the action is a very complex interaction between the observation variables, this should be larger.
Typical range: `32` - `512` |
| `network_settings -> num_layers` | (default = `2`) The number of hidden layers in the neural network. Corresponds to how many hidden layers are present after the observation input, or after the CNN encoding of the visual observation. For simple problems, fewer layers are likely to train faster and more efficiently. More layers may be necessary for more complex control problems.
Typical range: `1` - `3` |
| `network_settings -> normalize` | (default = `false`) Whether normalization is applied to the vector observation inputs. This normalization is based on the running average and variance of the vector observation. Normalization can be helpful in cases with complex continuous control problems, but may be harmful with simpler discrete control problems. |
-| `network_settings -> vis_encode_type` | (default = `simple`) Encoder type for encoding visual observations.
`simple` (default) uses a simple encoder which consists of two convolutional layers, `nature_cnn` uses the CNN implementation proposed by [Mnih et al.](https://www.nature.com/articles/nature14236), consisting of three convolutional layers, and `resnet` uses the [IMPALA Resnet](https://arxiv.org/abs/1802.01561) consisting of three stacked layers, each with two residual blocks, making a much larger network than the other two. `match3` is a smaller CNN ([Gudmundsoon et al.](https://www.researchgate.net/publication/328307928_Human-Like_Playtesting_with_Deep_Learning)) that is optimized for board games, and can be used down to visual observation sizes of 5x5. |
+| `network_settings -> vis_encode_type` | (default = `simple`) Encoder type for encoding visual observations.
`simple` (default) uses a simple encoder which consists of two convolutional layers, `nature_cnn` uses the CNN implementation proposed by [Mnih et al.](https://www.nature.com/articles/nature14236), consisting of three convolutional layers, and `resnet` uses the [IMPALA Resnet](https://arxiv.org/abs/1802.01561) consisting of three stacked layers, each with two residual blocks, making a much larger network than the other two. `match3` is a smaller CNN ([Gudmundsoon et al.](https://www.researchgate.net/publication/328307928_Human-Like_Playtesting_with_Deep_Learning)) that is optimized for board games, and can be used down to visual observation sizes of 5x5. `fully_connected` uses a single fully connected dense layer as encoder and should be reserved for very small inputs. |
| `network_settings -> conditioning_type` | (default = `hyper`) Conditioning type for the policy using goal observations.
`none` treats the goal observations as regular observations, `hyper` (default) uses a HyperNetwork with goal observations as input to generate some of the weights of the policy. Note that when using `hyper` the number of parameters of the network increases greatly. Therefore, it is recommended to reduce the number of `hidden_units` when using this `conditioning_type`
diff --git a/ml-agents/mlagents/trainers/settings.py b/ml-agents/mlagents/trainers/settings.py
index 6226cf9bf4e..f094a4bf770 100644
--- a/ml-agents/mlagents/trainers/settings.py
+++ b/ml-agents/mlagents/trainers/settings.py
@@ -81,6 +81,7 @@ def as_dict(self):
class EncoderType(Enum):
+ FULLY_CONNECTED = "fully_connected"
MATCH3 = "match3"
SIMPLE = "simple"
NATURE_CNN = "nature_cnn"
diff --git a/ml-agents/mlagents/trainers/tests/torch/test_encoders.py b/ml-agents/mlagents/trainers/tests/torch/test_encoders.py
index 24a4669e6e6..4de95e99554 100644
--- a/ml-agents/mlagents/trainers/tests/torch/test_encoders.py
+++ b/ml-agents/mlagents/trainers/tests/torch/test_encoders.py
@@ -5,6 +5,8 @@
from mlagents.trainers.torch.encoders import (
VectorInput,
Normalizer,
+ SmallVisualEncoder,
+ FullyConnectedVisualEncoder,
SimpleVisualEncoder,
ResNetVisualEncoder,
NatureVisualEncoder,
@@ -73,7 +75,14 @@ def test_vector_encoder(mock_normalizer):
@pytest.mark.parametrize("image_size", [(36, 36, 3), (84, 84, 4), (256, 256, 5)])
@pytest.mark.parametrize(
- "vis_class", [SimpleVisualEncoder, ResNetVisualEncoder, NatureVisualEncoder]
+ "vis_class",
+ [
+ SimpleVisualEncoder,
+ ResNetVisualEncoder,
+ NatureVisualEncoder,
+ SmallVisualEncoder,
+ FullyConnectedVisualEncoder,
+ ],
)
def test_visual_encoder(vis_class, image_size):
num_outputs = 128
@@ -82,3 +91,34 @@ def test_visual_encoder(vis_class, image_size):
sample_input = torch.ones((1, image_size[0], image_size[1], image_size[2]))
encoding = enc(sample_input)
assert encoding.shape == (1, num_outputs)
+
+
+@pytest.mark.parametrize(
+ "vis_class, size",
+ [
+ (SimpleVisualEncoder, 36),
+ (ResNetVisualEncoder, 36),
+ (NatureVisualEncoder, 36),
+ (SmallVisualEncoder, 10),
+ (FullyConnectedVisualEncoder, 36),
+ ],
+)
+def test_visual_encoder_trains(vis_class, size):
+ torch.manual_seed(0)
+ image_size = (size, size, 1)
+ batch = 100
+
+ inputs = torch.cat(
+ [torch.zeros((batch,) + image_size), torch.ones((batch,) + image_size)], dim=0
+ )
+ target = torch.cat([torch.zeros((batch,)), torch.ones((batch,))], dim=0)
+ enc = vis_class(image_size[0], image_size[1], image_size[2], 1)
+ optimizer = torch.optim.Adam(enc.parameters(), lr=0.001)
+
+ for _ in range(15):
+ prediction = enc(inputs)[:, 0]
+ loss = torch.mean((target - prediction) ** 2)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ assert loss.item() < 0.05
diff --git a/ml-agents/mlagents/trainers/torch/encoders.py b/ml-agents/mlagents/trainers/torch/encoders.py
index 676136fff42..0d9123c7871 100644
--- a/ml-agents/mlagents/trainers/torch/encoders.py
+++ b/ml-agents/mlagents/trainers/torch/encoders.py
@@ -111,6 +111,30 @@ def update_normalization(self, inputs: torch.Tensor) -> None:
self.normalizer.update(inputs)
+class FullyConnectedVisualEncoder(nn.Module):
+ def __init__(
+ self, height: int, width: int, initial_channels: int, output_size: int
+ ):
+ super().__init__()
+ self.output_size = output_size
+ self.input_size = height * width * initial_channels
+ self.dense = nn.Sequential(
+ linear_layer(
+ self.input_size,
+ self.output_size,
+ kernel_init=Initialization.KaimingHeNormal,
+ kernel_gain=1.41, # Use ReLU gain
+ ),
+ nn.LeakyReLU(),
+ )
+
+ def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
+ if not exporting_to_onnx.is_exporting():
+ visual_obs = visual_obs.permute([0, 3, 1, 2])
+ hidden = visual_obs.reshape(-1, self.input_size)
+ return self.dense(hidden)
+
+
class SmallVisualEncoder(nn.Module):
"""
CNN architecture used by King in their Candy Crush predictor
diff --git a/ml-agents/mlagents/trainers/torch/utils.py b/ml-agents/mlagents/trainers/torch/utils.py
index b73bfa0bf95..4feb70018fb 100644
--- a/ml-agents/mlagents/trainers/torch/utils.py
+++ b/ml-agents/mlagents/trainers/torch/utils.py
@@ -8,6 +8,7 @@
ResNetVisualEncoder,
NatureVisualEncoder,
SmallVisualEncoder,
+ FullyConnectedVisualEncoder,
VectorInput,
)
from mlagents.trainers.settings import EncoderType, ScheduleType
@@ -20,6 +21,7 @@ class ModelUtils:
# Minimum supported side for each encoder type. If refactoring an encoder, please
# adjust these also.
MIN_RESOLUTION_FOR_ENCODER = {
+ EncoderType.FULLY_CONNECTED: 1,
EncoderType.MATCH3: 5,
EncoderType.SIMPLE: 20,
EncoderType.NATURE_CNN: 36,
@@ -123,6 +125,7 @@ def get_encoder_for_type(encoder_type: EncoderType) -> nn.Module:
EncoderType.NATURE_CNN: NatureVisualEncoder,
EncoderType.RESNET: ResNetVisualEncoder,
EncoderType.MATCH3: SmallVisualEncoder,
+ EncoderType.FULLY_CONNECTED: FullyConnectedVisualEncoder,
}
return ENCODER_FUNCTION_BY_TYPE.get(encoder_type)