Skip to content

Commit 5311d2f

Browse files
authored
Merge pull request #48 from bonsai-rx/dev/torchsharp
Integrating torchsharp functionality into Bonsai
2 parents 3854f9e + 33f89c9 commit 5311d2f

72 files changed

Lines changed: 3855 additions & 18 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

Bonsai.ML.sln

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.LinearDynamicalSy
3030
EndProject
3131
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.HiddenMarkovModels.Design", "src\Bonsai.ML.HiddenMarkovModels.Design\Bonsai.ML.HiddenMarkovModels.Design.csproj", "{FC395DDC-62A4-4E14-A198-272AB05B33C7}"
3232
EndProject
33+
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Bonsai.ML.Torch", "src\Bonsai.ML.Torch\Bonsai.ML.Torch.csproj", "{06FCC9AF-CE38-44BB-92B3-0D451BE88537}"
34+
EndProject
3335
Global
3436
GlobalSection(SolutionConfigurationPlatforms) = preSolution
3537
Debug|Any CPU = Debug|Any CPU
@@ -72,6 +74,10 @@ Global
7274
{FC395DDC-62A4-4E14-A198-272AB05B33C7}.Debug|Any CPU.Build.0 = Debug|Any CPU
7375
{FC395DDC-62A4-4E14-A198-272AB05B33C7}.Release|Any CPU.ActiveCfg = Release|Any CPU
7476
{FC395DDC-62A4-4E14-A198-272AB05B33C7}.Release|Any CPU.Build.0 = Release|Any CPU
77+
{06FCC9AF-CE38-44BB-92B3-0D451BE88537}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
78+
{06FCC9AF-CE38-44BB-92B3-0D451BE88537}.Debug|Any CPU.Build.0 = Debug|Any CPU
79+
{06FCC9AF-CE38-44BB-92B3-0D451BE88537}.Release|Any CPU.ActiveCfg = Release|Any CPU
80+
{06FCC9AF-CE38-44BB-92B3-0D451BE88537}.Release|Any CPU.Build.0 = Release|Any CPU
7581
EndGlobalSection
7682
GlobalSection(SolutionProperties) = preSolution
7783
HideSolutionNode = FALSE
@@ -86,6 +92,7 @@ Global
8692
{39A4414F-52B1-42D7-82FA-E65DAD885264} = {12312384-8828-4786-AE19-EFCEDF968290}
8793
{A135C7DB-EA50-4FC6-A6CB-6A5A5CC5FA13} = {12312384-8828-4786-AE19-EFCEDF968290}
8894
{17DF50BE-F481-4904-A4C8-5DF9725B2CA1} = {12312384-8828-4786-AE19-EFCEDF968290}
95+
{06FCC9AF-CE38-44BB-92B3-0D451BE88537} = {12312384-8828-4786-AE19-EFCEDF968290}
8996
EndGlobalSection
9097
GlobalSection(ExtensibilityGlobals) = postSolution
9198
SolutionGuid = {B6468F13-97CD-45E0-9E1E-C122D7F1E09F}

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ Facilitates inference using Hidden Markov Models (HMMs). It interfaces with the
4242
### Bonsai.ML.HiddenMarkovModels.Design
4343
Visualizers and editor features for the HiddenMarkovModels package.
4444

45+
### Bonsai.ML.Torch
46+
Interfaces with the [TorchSharp](https://github.com/dotnet/TorchSharp) package, a C# wrapper around the torch library. Provides tooling for manipulating tensors, performing linear algebra, training and inference with deep neural networks, and more.
47+
4548
> [!NOTE]
4649
> Bonsai.ML packages can be installed through Bonsai's integrated package manager and are generally ready for immediate use. However, some packages may require additional installation steps. Refer to the specific package section for detailed installation guides and documentation.
4750
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Bonsai.ML.Torch - Overview
2+
3+
The Torch package provides a Bonsai interface to interact with [TorchSharp](https://github.com/dotnet/TorchSharp). This package adds powerful functionality into Bonsai, namely the ability to perform tensor manipulations, type conversions, complex linear algebra, deep neural networks, and support for GPU processing.
4+
5+
## Installation Guide
6+
7+
The Bonsai.ML.Torch package can be installed through the Bonsai Package Manager and depends on the TorchSharp library. Additionally, running the package requires installing the specific torch DLLs needed for your desired application. The steps for installing are outlined below.
8+
9+
### Running on the CPU
10+
For running the package using the CPU, the `TorchSharp-cpu` library should be installed through the NuGet package manager.
11+
12+
### Running on the GPU
13+
To run torch on the GPU, you first need to ensure that you have a CUDA compatible device installed on your system.
14+
15+
Next, you must follow the [CUDA installation guide for Windows](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html) or the [guide for Linux](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html). Please make sure to install the correct CUDA version ([v12.1](https://developer.nvidia.com/cuda-12-1-0-download-archive)), as `TorchSharp` currently only supports this version.
16+
17+
Next, you need to install the `cuDNN v9` library following the [guide for Windows](https://docs.nvidia.com/deeplearning/cudnn/latest/installation/windows.html) or the [guide for Linux](https://docs.nvidia.com/deeplearning/cudnn/latest/installation/linux.html). Again, you need to ensure you have the correct version installed (v9). You should consult [NVIDIA's support matrix](https://docs.nvidia.com/deeplearning/cudnn/latest/reference/support-matrix.html) to ensure the versions of CUDA and cuDNN you installed are compatible with your specific OS, graphics driver, and hardware.
18+
19+
Once complete, you need to install the cuda-compatible torch libraries and place them into the correct location. You can download the libraries from [the PyTorch website](https://pytorch.org/get-started/locally/) with the following options selected:
20+
21+
- PyTorch Build: Stable (2.5.1)
22+
- OS: [Your OS]
23+
- Package: LibTorch
24+
- Language: C++/Java
25+
- Compute Platform: CUDA 12.1
26+
27+
Finally, extract the zip folder and copy the contents of the `lib` folder into the `Extensions` folder of your Bonsai installation directory.
28+
29+
> [!WARNING]
30+
> You can only install one of the above CPU or GPU packages, not both. The GPU package contains all of the support of the CPU package.
31+
32+
## Getting Started
33+
34+
The `Bonsai.ML.Torch` package primarily provides tooling and functionality for users to interact with and manipulate `Tensor` objects, the core data type of torch which underlies most advanced operations. Additionally, the package provides some capabilities for defining neural network architectures, running forward inference, and learning via back propagation.
35+
36+
## Tensor Operations
37+
The package provides several ways to work with tensors. Users can initialize tensors (`Ones`, `Zeros`, etc.), create tensors from .NET data types (`ToTensor`), and define custom tensors using Python-like syntax (`CreateTensor`). Tensors can be converted back to .NET types using the `ToArray` node (for flattening tensors into a unidimensional array) or the `ToNDArray` node (for preserving multidimensional array shapes). Furthermore, the `Tensor` object contains many extension methods which can be used via scripting with `ExpressionTransform` (for example, `it.sum()` to sum a tensor, or `it.T` to transpose), and works with overloaded operators (for example, `Zip` -> `Multiply`). It is also possible to use the `ExpressionTransform` node to access individual elements of a tensor, using the syntax `it.ReadCpuT(0)` where `T` is a primitive .NET data type (e.g. `Single`, `Double`, `Int64`, etc.).
38+
39+
40+
## Running on the GPU
41+
Users must be explicit about running computations on the GPU. First, the `InitializeDeviceType` node must run with a CUDA-compatible GPU. If successful, the node will return a `Device` object representing the GPU. Afterwards, the tensors can either be created directly on the GPU by setting the `Device` property to the GPU device, or moved to the GPU using the `ToDevice` node. For most tensor operations to work, all of the tensors involved must be present on the same device.
42+
43+
## Neural Networks
44+
The package provides initial support for working with torch `Module` objects, the core type representing deep neural networks. The `LoadModuleFromArchitecture` node allows users to select from a list of common architectures, and can optionally load pretrained weights from disk. Additionally, the package supports loading `TorchScript` modules with the `LoadScriptModule` node, which enables users to use torch modules saved in the `.pt` file format. Users can then use the `Forward` node to run inference and the `Backward` node to run back propagation.

docs/articles/toc.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,7 @@
1313
- name: Overview
1414
href: HiddenMarkovModels/hmm-overview.md
1515
- name: Getting Started
16-
href: HiddenMarkovModels/hmm-getting-started.md
16+
href: HiddenMarkovModels/hmm-getting-started.md
17+
- name: Torch
18+
- name: Overview
19+
href: Torch/torch-overview.md
Lines changed: 194 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,32 +3,32 @@
33
using System.Collections;
44
using System.Collections.Generic;
55
using System.Linq;
6+
using System.Reflection;
67
using Newtonsoft.Json;
78
using Newtonsoft.Json.Linq;
89

910
namespace Bonsai.ML.Data
1011
{
1112
/// <summary>
12-
/// Provides a set of static methods for working with arrays.
13+
/// Provides a set of static methods for working with JSON data containing numeric and boolean values.
1314
/// </summary>
14-
public static class ArrayHelper
15+
public static class JsonDataHelper
1516
{
1617
/// <summary>
17-
/// Parses the input string into an object of the specified type.
18-
/// If the input is a JSON array, the method will attempt to parse it into a list or array of the specified type.
18+
/// Parses the input string into an object of the specified type.
1919
/// </summary>
2020
/// <param name="input">The string to parse.</param>
2121
/// <param name="dtype">The data type of the object.</param>
2222
/// <returns>An object of the specified type containing the parsed data.</returns>
23-
public static object ParseString(string input, Type dtype)
23+
public static object Parse(string input, Type dtype)
2424
{
2525
if (!IsValidJson(input))
2626
{
2727
throw new ArgumentException($"Parameter: {nameof(input)} is not valid JSON.");
2828
}
2929

3030
var token = JsonConvert.DeserializeObject<JToken>(input);
31-
31+
3232
if (token is JValue value)
3333
{
3434
return Convert.ChangeType(value, dtype);
@@ -67,7 +67,7 @@ public static object ParseToken(JToken token, Type dtype)
6767
}
6868
else if (token is JArray)
6969
{
70-
if (token[0] is JValue)
70+
if (token.Count() == 0 || token[0] is JValue)
7171
{
7272
if (token.All(item => item is JValue))
7373
{
@@ -78,7 +78,8 @@ public static object ParseToken(JToken token, Type dtype)
7878
}
7979
else
8080
{
81-
var subArrayDimensions = token.Cast<JArray>().Select(value => {
81+
var subArrayDimensions = token.Cast<JArray>().Select(value =>
82+
{
8283
var depth = ParseDepth(value);
8384
return ParseDimensions(value, depth);
8485
}).ToList();
@@ -147,7 +148,7 @@ private static int[] ParseDimensions(JToken token, int depth, int currentLevel =
147148
if (currentLevel > 0 && token is JArray arr && arr.All(x => x is JArray))
148149
{
149150
var subArrayDimensions = new HashSet<string>();
150-
foreach (JArray subArr in arr)
151+
foreach (JArray subArr in arr.Cast<JArray>())
151152
{
152153
int[] subDims = ParseDimensions(subArr, depth - currentLevel, currentLevel + 1);
153154
subArrayDimensions.Add(string.Join(",", subDims));
@@ -159,7 +160,7 @@ private static int[] ParseDimensions(JToken token, int depth, int currentLevel =
159160
}
160161
}
161162

162-
return dimensions.ToArray();
163+
return [.. dimensions];
163164
}
164165

165166
private static void PopulateArray(JToken token, Array array, int[] indices, Type dtype)
@@ -185,13 +186,13 @@ private static object CreateList(JToken token, Type dtype)
185186
{
186187
var listType = typeof(List<>).MakeGenericType(DetermineListType(token, dtype));
187188
var list = (IList)Activator.CreateInstance(listType);
188-
189+
189190
foreach (var item in token)
190191
{
191192
var result = ParseToken(item, dtype);
192193
list.Add(result);
193194
}
194-
195+
195196
return list;
196197
}
197198

@@ -224,5 +225,186 @@ private static Type DetermineListType(JToken token, Type type)
224225
return typeof(object);
225226
}
226227
}
228+
229+
/// <summary>
/// Formats the input object into a string representation that is consistent with JSON syntax.
/// </summary>
/// <param name="obj">The object to format.</param>
/// <returns>A string representation that is consistent with JSON syntax.</returns>
public static string Format(object obj)
{
    // Delegate to the recursive overload, starting at nesting depth zero.
    var builder = new StringBuilder();
    Format(obj, builder, 0);
    return builder.ToString();
}
241+
242+
/// <summary>
/// Dispatches formatting of <paramref name="obj"/> to the handler matching its runtime type.
/// Scalars are appended directly; arrays, lists, dictionaries, tuples, and other objects
/// recurse with an incremented depth.
/// </summary>
private static void Format(object obj, StringBuilder sb, int depth)
{
    switch (obj)
    {
        case null:
            sb.Append("null");
            break;
        case string:
        case char:
            // Escape backslashes and double quotes so embedded characters cannot
            // break the JSON-like output.
            sb.Append('"')
              .Append(obj.ToString().Replace("\\", "\\\\").Replace("\"", "\\\""))
              .Append('"');
            break;
        case bool flag:
            // JSON booleans are lowercase; avoid culture-sensitive ToLower().
            sb.Append(flag ? "true" : "false");
            break;
        case int:
        case double:
        case float:
        case long:
        case short:
        case byte:
        case ushort:
        case uint:
        case ulong:
        case sbyte:
        case decimal:
            // Format with the invariant culture so the decimal separator stays '.'
            // (JSON-compatible) regardless of the current locale.
            sb.Append(Convert.ToString(obj, System.Globalization.CultureInfo.InvariantCulture));
            break;
        case Array:
            // Must precede the IList case: arrays also implement IList.
            FormatArray(obj, sb, depth);
            break;
        case IList:
            FormatList(obj, sb, depth);
            break;
        case IDictionary:
            FormatDictionary(obj, sb, depth);
            break;
        default:
            // The original combined the tuple checks with mixed &&/|| precedence, so
            // GetGenericTypeDefinition() was invoked on non-generic types and threw
            // InvalidOperationException. Guard on IsGenericType before comparing.
            var type = obj.GetType();
            var genericDefinition = type.IsGenericType ? type.GetGenericTypeDefinition() : null;
            if (genericDefinition == typeof(Tuple<,>) ||
                genericDefinition == typeof(Tuple<,,>) ||
                genericDefinition == typeof(Tuple<,,,>) ||
                genericDefinition == typeof(Tuple<,,,,>) ||
                genericDefinition == typeof(Tuple<,,,,,>) ||
                genericDefinition == typeof(Tuple<,,,,,,>) ||
                genericDefinition == typeof(Tuple<,,,,,,,>))
            {
                FormatTuple(obj, sb, depth);
            }
            else
            {
                FormatDefault(obj, sb, depth);
            }
            break;
    }
}
293+
294+
/// <summary>
/// Appends a JSON-like rendering of an array: rank-1 arrays as a flat bracketed
/// list, higher-rank arrays as nested bracketed lists via FormatNDArray.
/// </summary>
private static void FormatArray(object obj, StringBuilder sb, int depth)
{
    var array = (Array)obj;
    if (array.Rank != 1)
    {
        FormatNDArray(array, new int[array.Rank], 0, sb, depth);
        return;
    }

    sb.Append('[');
    var separator = string.Empty;
    foreach (var element in array)
    {
        sb.Append(separator);
        separator = ", ";
        Format(element, sb, depth + 1);
    }
    sb.Append(']');
}
315+
316+
/// <summary>
/// Recursively appends one dimension of a multidimensional array as nested
/// bracketed lists. <paramref name="indices"/> accumulates the coordinates
/// fixed by the outer dimensions.
/// </summary>
private static void FormatNDArray(Array array, int[] indices, int dimension, StringBuilder sb, int depth)
{
    // Base case: every coordinate is fixed, so emit the element itself.
    if (dimension == array.Rank)
    {
        Format(array.GetValue(indices), sb, depth + 1);
        return;
    }

    sb.Append('[');
    int length = array.GetLength(dimension);
    for (int index = 0; index < length; index++)
    {
        if (index != 0)
        {
            sb.Append(", ");
        }
        indices[dimension] = index;
        FormatNDArray(array, indices, dimension + 1, sb, depth + 1);
    }
    sb.Append(']');
}
336+
337+
/// <summary>
/// Appends a JSON-like bracketed rendering of an <see cref="IList"/>,
/// formatting each element recursively.
/// </summary>
private static void FormatList(object obj, StringBuilder sb, int depth)
{
    var items = (IList)obj;
    sb.Append('[');
    var separator = string.Empty;
    foreach (var item in items)
    {
        sb.Append(separator);
        separator = ", ";
        Format(item, sb, depth + 1);
    }
    sb.Append(']');
}
351+
352+
/// <summary>
/// Appends a JSON-like braced rendering of an <see cref="IDictionary"/>,
/// formatting each key and value recursively as "key: value" pairs.
/// </summary>
private static void FormatDictionary(object obj, StringBuilder sb, int depth)
{
    var entries = (IDictionary)obj;
    sb.Append('{');
    var separator = string.Empty;
    foreach (DictionaryEntry entry in entries)
    {
        sb.Append(separator);
        separator = ", ";
        Format(entry.Key, sb, depth + 1);
        sb.Append(": ");
        Format(entry.Value, sb, depth + 1);
    }
    sb.Append('}');
}
373+
374+
/// <summary>
/// Appends a parenthesized rendering of a <see cref="Tuple{T1,T2}"/>-family object,
/// formatting each element recursively.
/// </summary>
private static void FormatTuple(object obj, StringBuilder sb, int depth)
{
    // System.Tuple exposes at most Item1..Item7, so an ordinal sort keeps them in
    // declaration order (no Item10-style names exist to break lexical ordering).
    var itemProperties = obj.GetType()
        .GetProperties()
        .Where(p => p.Name.StartsWith("Item", StringComparison.Ordinal))
        .OrderBy(p => p.Name, StringComparer.Ordinal)
        .ToArray();

    sb.Append('(');
    for (int i = 0; i < itemProperties.Length; i++)
    {
        if (i > 0)
        {
            sb.Append(", ");
        }
        Format(itemProperties[i].GetValue(obj), sb, depth + 1);
    }

    // Tuples with eight elements store the tail in the Rest property; the original
    // filter on "Item" silently dropped it, losing data for Tuple<,,,,,,,>.
    var restProperty = obj.GetType().GetProperty("Rest");
    if (restProperty != null)
    {
        sb.Append(", ");
        Format(restProperty.GetValue(obj), sb, depth + 1);
    }
    sb.Append(')');
}
393+
394+
/// <summary>
/// Fallback formatter: renders an arbitrary object as a JSON-like map of its
/// readable public instance properties.
/// </summary>
private static void FormatDefault(object obj, StringBuilder sb, int depth)
{
    // Skip indexers (which require arguments) and write-only properties; calling
    // GetValue on either would throw at runtime.
    var properties = obj.GetType()
        .GetProperties(BindingFlags.Public | BindingFlags.Instance)
        .Where(p => p.CanRead && p.GetIndexParameters().Length == 0)
        .ToArray();
    sb.Append('{');
    for (int i = 0; i < properties.Length; i++)
    {
        if (i > 0)
        {
            sb.Append(", ");
        }
        sb.Append('"').Append(properties[i].Name).Append("\": ");
        // NOTE(review): no cycle detection — a self-referencing object graph would
        // recurse without bound; confirm callers never pass cyclic graphs.
        Format(properties[i].GetValue(obj), sb, depth + 1);
    }
    sb.Append('}');
}
227409
}
228410
}

0 commit comments

Comments
 (0)