Skip to content

Commit 2df9d1a

Browse files
committed
Updated modules to support module combinator builder pattern
1 parent 3ca8e19 commit 2df9d1a

230 files changed

Lines changed: 7015 additions & 4756 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
using System.ComponentModel;
2+
using System.Xml.Serialization;
3+
4+
namespace Bonsai.ML.Torch.NeuralNets;
5+
6+
/// <summary>
7+
/// Represents an operator that creates an activation function.
8+
/// </summary>
9+
[XmlInclude(typeof(NonLinearActivations.ContinuouslyDifferentiableExponential))]
10+
[XmlInclude(typeof(NonLinearActivations.Exponential))]
11+
[XmlInclude(typeof(NonLinearActivations.Gated))]
12+
[XmlInclude(typeof(NonLinearActivations.GaussianError))]
13+
[XmlInclude(typeof(NonLinearActivations.HardShrinkage))]
14+
[XmlInclude(typeof(NonLinearActivations.HardSigmoid))]
15+
[XmlInclude(typeof(NonLinearActivations.Hardswish))]
16+
[XmlInclude(typeof(NonLinearActivations.HardTanh))]
17+
[XmlInclude(typeof(NonLinearActivations.LeakyRectified))]
18+
[XmlInclude(typeof(NonLinearActivations.LogSigmoid))]
19+
[XmlInclude(typeof(NonLinearActivations.LogSoftmax))]
20+
[XmlInclude(typeof(NonLinearActivations.Mish))]
21+
[XmlInclude(typeof(NonLinearActivations.MultiheadAttention))]
22+
[XmlInclude(typeof(NonLinearActivations.ParametricRectified))]
23+
[XmlInclude(typeof(NonLinearActivations.RandomizedLeakyRectified))]
24+
[XmlInclude(typeof(NonLinearActivations.Rectified))]
25+
[XmlInclude(typeof(NonLinearActivations.RectifiedBounded))]
26+
[XmlInclude(typeof(NonLinearActivations.ScaledExponential))]
27+
[XmlInclude(typeof(NonLinearActivations.Sigmoid))]
28+
[XmlInclude(typeof(NonLinearActivations.SigmoidWeighted))]
29+
[XmlInclude(typeof(NonLinearActivations.Softmax))]
30+
[XmlInclude(typeof(NonLinearActivations.Softmax2d))]
31+
[XmlInclude(typeof(NonLinearActivations.Softmin))]
32+
[XmlInclude(typeof(NonLinearActivations.Softplus))]
33+
[XmlInclude(typeof(NonLinearActivations.SoftShrinkage))]
34+
[XmlInclude(typeof(NonLinearActivations.Softsign))]
35+
[XmlInclude(typeof(NonLinearActivations.Tanh))]
36+
[XmlInclude(typeof(NonLinearActivations.TanhShrinkage))]
37+
[XmlInclude(typeof(NonLinearActivations.Threshold))]
38+
[DefaultProperty(nameof(ActivationFunction))]
39+
[Combinator]
40+
[Description("Creates an activation function.")]
41+
[WorkflowElementCategory(ElementCategory.Source)]
42+
public class ActivationFunctionBuilder : ModuleCombinatorBuilder, INamedElement
43+
{
44+
/// <inheritdoc/>
45+
public override Range<int> ArgumentRange => Range.Create(0, 1);
46+
47+
string INamedElement.Name => $"ActivationFunction.{GetElementDisplayName(ActivationFunction)}";
48+
49+
/// <summary>
50+
/// Initializes a new instance of the <see cref="ActivationFunctionBuilder"/> class.
51+
/// </summary>
52+
public ActivationFunctionBuilder()
53+
{
54+
Module = new NonLinearActivations.Rectified();
55+
}
56+
57+
/// <summary>
58+
/// Gets or sets the specific activation function to create.
59+
/// </summary>
60+
[DesignOnly(true)]
61+
[DisplayName("Module")]
62+
[Externalizable(false)]
63+
[RefreshProperties(RefreshProperties.All)]
64+
[Category(nameof(CategoryAttribute.Design))]
65+
[Description("The specific activation function to create.")]
66+
[TypeConverter(typeof(ModuleTypeConverter))]
67+
public object ActivationFunction
68+
{
69+
get => Module;
70+
set => Module = value;
71+
}
72+
}

src/Bonsai.ML.Torch/NeuralNets/Backward.cs

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,14 @@
11
using System;
2-
using System.Linq;
32
using System.ComponentModel;
4-
using System.Reactive;
5-
using System.Reactive.Disposables;
63
using System.Reactive.Linq;
7-
using System.Xml.Serialization;
84
using static TorchSharp.torch;
9-
using static TorchSharp.torch.nn;
10-
using static TorchSharp.torch.optim;
11-
using static TorchSharp.torch.optim.lr_scheduler;
12-
using System.Threading.Tasks;
135

146
namespace Bonsai.ML.Torch.NeuralNets;
157

168
/// <summary>
179
/// Represents an operator that computes backward on the input tensor.
1810
/// </summary>
1911
[Combinator]
20-
[ResetCombinator]
2112
[Description("Computes backward on the input tensor.")]
2213
[WorkflowElementCategory(ElementCategory.Sink)]
2314
public class Backward

src/Bonsai.ML.Torch/NeuralNets/CollectParameters.cs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
1-
using Bonsai.Expressions;
21
using System;
32
using System.ComponentModel;
43
using System.Reactive.Linq;
54
using System.Linq;
6-
using System.Linq.Expressions;
75
using System.Collections.Generic;
8-
using System.Reflection;
9-
using TorchSharp;
106
using TorchSharp.Modules;
11-
using static TorchSharp.torch;
7+
using static TorchSharp.torch.nn;
128

139
namespace Bonsai.ML.Torch.NeuralNets;
1410

@@ -23,9 +19,13 @@ public class CollectParameters
2319
/// <summary>
2420
/// Collects the parameters from torch modules into a collection.
2521
/// </summary>
26-
public IObservable<IEnumerable<Parameter>> Process(params IObservable<nn.Module>[] sources)
22+
/// <param name="sources"></param>
23+
/// <returns></returns>
24+
public IObservable<IEnumerable<Parameter>> Process(params IObservable<Module>[] sources)
2725
{
28-
return Observable.Concat(sources)
26+
return Observable
27+
.Concat(sources.Select(source =>
28+
source.Take(1)))
2929
.SelectMany(module =>
3030
{
3131
return module.parameters(recurse: true);
Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,27 @@
11
using System;
2-
using System.Collections.Generic;
32
using System.ComponentModel;
43
using System.Reactive.Linq;
5-
using TorchSharp;
64
using static TorchSharp.torch;
75
using static TorchSharp.torch.nn;
8-
using System.Xml.Serialization;
96
using System.Linq;
107

118
namespace Bonsai.ML.Torch.NeuralNets.Container;
129

1310
/// <summary>
14-
/// Creates a sequential model from the specified modules.
11+
/// Represents an operator that creates a sequential container.
1512
/// </summary>
16-
[Combinator]
17-
[ResetCombinator]
18-
[Description("Creates a sequential model from the specified modules.")]
19-
[WorkflowElementCategory(ElementCategory.Source)]
13+
/// <remarks>
14+
/// See <see href="https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html"/> for more information.
15+
/// </remarks>
16+
[Description("Creates a sequential container.")]
2017
public class Sequential
2118
{
2219
/// <summary>
23-
/// The device on which to create the sequential model.
24-
/// </summary>
25-
[XmlIgnore]
26-
public Device? Device { get; set; } = null;
27-
28-
/// <summary>
29-
/// Generates an observable sequence that creates a sequential model from the input modules.
20+
/// Creates a sequential container from the input modules.
3021
/// </summary>
3122
/// <returns></returns>
3223
public IObservable<Module<Tensor, Tensor>> Process(IObservable<Module<Tensor, Tensor>[]> source)
3324
{
34-
return source.SelectMany(modules =>
35-
{
36-
var sequential = Sequential(modules);
37-
if (Device is not null && Device != CPU)
38-
{
39-
sequential.to(Device);
40-
}
41-
return Observable.Return(sequential);
42-
});
25+
return source.Select(modules => Sequential(modules));
4326
}
4427
}

src/Bonsai.ML.Torch/NeuralNets/Convolution/Conv1d.cs

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
using System;
22
using System.ComponentModel;
3-
using System.Collections.Generic;
43
using System.Reactive.Linq;
54
using System.Xml.Serialization;
65
using TorchSharp;
7-
using TorchSharp.Modules;
86
using static TorchSharp.torch;
97
using static TorchSharp.torch.nn;
108

@@ -13,61 +11,58 @@ namespace Bonsai.ML.Torch.NeuralNets.Convolution;
1311
/// <summary>
1412
/// Represents an operator that creates a 1D convolution module.
1513
/// </summary>
14+
/// <remarks>
15+
/// See <see href="https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html"/> for more information.
16+
/// </remarks>
1617
[Description("Creates a 1D convolution module.")]
1718
public class Conv1d
1819
{
1920
/// <summary>
20-
/// The in_channels parameter for the Conv1d module.
21+
/// The number of input channels in the input tensor.
2122
/// </summary>
22-
[Description("The in_channels parameter for the Conv1d module")]
23+
[Description("The number of input channels in the input tensor.")]
2324
public long InChannels { get; set; }
2425

2526
/// <summary>
26-
/// The out_channels parameter for the Conv1d module.
27+
/// The number of output channels produced by the convolution.
2728
/// </summary>
28-
[Description("The out_channels parameter for the Conv1d module")]
29+
[Description("The number of output channels produced by the convolution.")]
2930
public long OutChannels { get; set; }
3031

3132
/// <summary>
32-
/// The kernelsize parameter for the Conv1d module.
33+
/// The size of the convolution kernel.
3334
/// </summary>
34-
[Description("The kernelsize parameter for the Conv1d module")]
35+
[Description("The size of the convolution kernel.")]
3536
public long KernelSize { get; set; }
3637

3738
/// <summary>
38-
/// The stride parameter for the Conv1d module.
39+
/// The stride of the convolution.
3940
/// </summary>
40-
[Description("The stride parameter for the Conv1d module")]
41+
[Description("The stride of the convolution.")]
4142
public long Stride { get; set; } = 1;
4243

4344
/// <summary>
44-
/// The padding parameter for the Conv1d module.
45+
/// The padding added to both sides of the input.
4546
/// </summary>
46-
[Description("The padding parameter for the Conv1d module")]
47+
[Description("The padding added to both sides of the input.")]
4748
public long Padding { get; set; } = 0;
4849

4950
/// <summary>
50-
/// The output_padding parameter for the Conv1d module.
51+
/// The spacing between kernel elements.
5152
/// </summary>
52-
[Description("The output_padding parameter for the ConvTransposed1d module")]
53-
public long OutputPadding { get; set; } = 0;
54-
55-
/// <summary>
56-
/// The dilation parameter for the Conv1d module.
57-
/// </summary>
58-
[Description("The dilation parameter for the Conv1d module")]
53+
[Description("The spacing between kernel elements.")]
5954
public long Dilation { get; set; } = 1;
6055

6156
/// <summary>
62-
/// The padding_mode parameter for the Conv1d module.
57+
/// The mode of padding.
6358
/// </summary>
64-
[Description("The padding_mode parameter for the Conv1d module")]
59+
[Description("The mode of padding.")]
6560
public PaddingModes PaddingMode { get; set; } = PaddingModes.Zeros;
6661

6762
/// <summary>
68-
/// The groups parameter for the Conv1d module.
63+
/// The number of blocked connections from input channels to output channels.
6964
/// </summary>
70-
[Description("The groups parameter for the Conv1d module")]
65+
[Description("The number of blocked connections from input channels to output channels.")]
7166
public long Groups { get; set; } = 1;
7267

7368
/// <summary>
@@ -77,21 +72,22 @@ public class Conv1d
7772
public bool Bias { get; set; } = true;
7873

7974
/// <summary>
80-
/// The desired device of returned tensor.
75+
/// The desired device of the returned tensor.
8176
/// </summary>
8277
[XmlIgnore]
83-
[Description("The desired device of returned tensor")]
78+
[Description("The desired device of the returned tensor")]
8479
public Device Device { get; set; } = null;
8580

8681
/// <summary>
87-
/// The desired data type of returned tensor.
82+
/// The desired data type of the returned tensor.
8883
/// </summary>
89-
[Description("The desired data type of returned tensor")]
84+
[Description("The desired data type of the returned tensor")]
9085
public ScalarType? Type { get; set; } = null;
9186

9287
/// <summary>
9388
/// Creates a Conv1d module.
9489
/// </summary>
90+
/// <returns></returns>
9591
public IObservable<Module<Tensor, Tensor>> Process()
9692
{
9793
return Observable.Return(Conv1d(InChannels, OutChannels, KernelSize, Stride, Padding, Dilation, PaddingMode, Groups, Bias, Device, Type));
@@ -103,7 +99,6 @@ public IObservable<Module<Tensor, Tensor>> Process()
10399
/// <typeparam name="T"></typeparam>
104100
/// <param name="source"></param>
105101
/// <returns></returns>
106-
/// <exception cref="InvalidOperationException"></exception>
107102
public IObservable<Module<Tensor, Tensor>> Process<T>(IObservable<T> source)
108103
{
109104
return source.Select(_ => Conv1d(InChannels, OutChannels, KernelSize, Stride, Padding, Dilation, PaddingMode, Groups, Bias, Device, Type));

0 commit comments

Comments
 (0)