Skip to content

Commit 9ed11b9

Browse files
michaelgsharptannergoodingericstj
authored
Tensor<T> select. Iteration along dimensions (#113697)
* base code in place * enum tests * Update to match the approved API refinements * Update the suppression file * Update reference assembly and suppressions --------- Co-authored-by: Tanner Gooding <tagoo@outlook.com> Co-authored-by: Eric StJohn <ericstj@microsoft.com>
1 parent b146d75 commit 9ed11b9

18 files changed

Lines changed: 924 additions & 397 deletions

eng/resolveContract.targets

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@
126126
<PropertyGroup>
127127
<GenAPIExcludeAttributesList>$(RepositoryEngineeringDir)DefaultGenApiDocIds.txt</GenAPIExcludeAttributesList>
128128
<GenAPIHeaderFile>$(RepositoryEngineeringDir)LicenseHeader.txt</GenAPIHeaderFile>
129-
<GenAPITargetPath>$([MSBuild]::NormalizePath('$(MSBuildProjectDirectory)', '..', 'ref', '$(AssemblyName).cs'))</GenAPITargetPath>
129+
<GenAPITargetPath Condition="'$(GenAPITargetPath)' == ''">$([MSBuild]::NormalizePath('$(MSBuildProjectDirectory)', '..', 'ref', '$(AssemblyName).cs'))</GenAPITargetPath>
130130
<GenAPILangVersion Condition="'$(LangVersion)' != ''">$(LangVersion)</GenAPILangVersion>
131131
<ProjectForGenAPIDocIdGeneration Condition="'$(ProjectForGenAPIDocIdGeneration)' == ''">$(CoreLibProject)</ProjectForGenAPIDocIdGeneration>
132132
</PropertyGroup>

src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.csproj

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
<PropertyGroup>
44
<TargetFrameworks>$(NetCoreAppCurrent);$(NetCoreAppPrevious);$(NetCoreAppMinimum);netstandard2.0;$(NetFrameworkMinimum)</TargetFrameworks>
55
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
6+
<!-- SYSLIB5001: Tensor<T> and related APIs in System.Numerics.Tensors are experimental in .NET 9 -->
7+
<NoWarn>$(NoWarn);SYSLIB5001</NoWarn>
68
</PropertyGroup>
79

810
<ItemGroup>
@@ -21,4 +23,4 @@
2123
<PackageReference Include="System.Memory" Version="$(SystemMemoryVersion)" />
2224
</ItemGroup>
2325

24-
</Project>
26+
</Project>

src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.net9.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,20 @@ public static partial class TensorPrimitives
1111
public static void ConvertToIntegerNative<TFrom, TTo>(System.ReadOnlySpan<TFrom> source, System.Span<TTo> destination) where TFrom : System.Numerics.IFloatingPoint<TFrom> where TTo : System.Numerics.IBinaryInteger<TTo> { }
1212
public static void ConvertToInteger<TFrom, TTo>(System.ReadOnlySpan<TFrom> source, System.Span<TTo> destination) where TFrom : System.Numerics.IFloatingPoint<TFrom> where TTo : System.Numerics.IBinaryInteger<TTo> { }
1313
}
14+
public readonly ref partial struct ReadOnlyTensorDimensionSpan<T>
15+
{
16+
public ref partial struct Enumerator : System.Collections.Generic.IEnumerator<System.Numerics.Tensors.ReadOnlyTensorSpan<T>>, System.Collections.IEnumerator, System.IDisposable
17+
{
18+
readonly object? System.Collections.IEnumerator.Current { get { throw null; } }
19+
void System.IDisposable.Dispose() { }
20+
}
21+
}
22+
public readonly ref partial struct TensorDimensionSpan<T>
23+
{
24+
public ref partial struct Enumerator : System.Collections.Generic.IEnumerator<System.Numerics.Tensors.TensorSpan<T>>, System.Collections.IEnumerator, System.IDisposable
25+
{
26+
readonly object? System.Collections.IEnumerator.Current { get { throw null; } }
27+
void System.IDisposable.Dispose() { }
28+
}
29+
}
1430
}

src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs

Lines changed: 128 additions & 95 deletions
Large diffs are not rendered by default.

src/libraries/System.Numerics.Tensors/src/CompatibilitySuppressions.xml

Lines changed: 27 additions & 251 deletions
Large diffs are not rendered by default.
Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,49 @@
1-
M:System.Numerics.Tensors.TensorPrimitives.ConvertToHalf(System.ReadOnlySpan{System.Single},System.Span{System.Half})
2-
M:System.Numerics.Tensors.TensorPrimitives.ConvertToSingle(System.ReadOnlySpan{System.Half},System.Span{System.Single})
1+
M:System.Numerics.Tensors.TensorPrimitives.Abs(System.ReadOnlySpan{System.Single},System.Span{System.Single})
2+
M:System.Numerics.Tensors.TensorPrimitives.Add(System.ReadOnlySpan{System.Single},System.ReadOnlySpan{System.Single},System.Span{System.Single})
3+
M:System.Numerics.Tensors.TensorPrimitives.Add(System.ReadOnlySpan{System.Single},System.Single,System.Span{System.Single})
4+
M:System.Numerics.Tensors.TensorPrimitives.AddMultiply(System.ReadOnlySpan{System.Single},System.ReadOnlySpan{System.Single},System.ReadOnlySpan{System.Single},System.Span{System.Single})
5+
M:System.Numerics.Tensors.TensorPrimitives.AddMultiply(System.ReadOnlySpan{System.Single},System.ReadOnlySpan{System.Single},System.Single,System.Span{System.Single})
6+
M:System.Numerics.Tensors.TensorPrimitives.AddMultiply(System.ReadOnlySpan{System.Single},System.Single,System.ReadOnlySpan{System.Single},System.Span{System.Single})
7+
M:System.Numerics.Tensors.TensorPrimitives.Cosh(System.ReadOnlySpan{System.Single},System.Span{System.Single})
8+
M:System.Numerics.Tensors.TensorPrimitives.CosineSimilarity(System.ReadOnlySpan{System.Single},System.ReadOnlySpan{System.Single})
9+
M:System.Numerics.Tensors.TensorPrimitives.Distance(System.ReadOnlySpan{System.Single},System.ReadOnlySpan{System.Single})
10+
M:System.Numerics.Tensors.TensorPrimitives.Divide(System.ReadOnlySpan{System.Single},System.ReadOnlySpan{System.Single},System.Span{System.Single})
11+
M:System.Numerics.Tensors.TensorPrimitives.Divide(System.ReadOnlySpan{System.Single},System.Single,System.Span{System.Single})
12+
M:System.Numerics.Tensors.TensorPrimitives.Dot(System.ReadOnlySpan{System.Single},System.ReadOnlySpan{System.Single})
13+
M:System.Numerics.Tensors.TensorPrimitives.Exp(System.ReadOnlySpan{System.Single},System.Span{System.Single})
14+
M:System.Numerics.Tensors.TensorPrimitives.IndexOfMax(System.ReadOnlySpan{System.Single})
15+
M:System.Numerics.Tensors.TensorPrimitives.IndexOfMaxMagnitude(System.ReadOnlySpan{System.Single})
16+
M:System.Numerics.Tensors.TensorPrimitives.IndexOfMin(System.ReadOnlySpan{System.Single})
17+
M:System.Numerics.Tensors.TensorPrimitives.IndexOfMinMagnitude(System.ReadOnlySpan{System.Single})
18+
M:System.Numerics.Tensors.TensorPrimitives.Log(System.ReadOnlySpan{System.Single},System.Span{System.Single})
19+
M:System.Numerics.Tensors.TensorPrimitives.Log2(System.ReadOnlySpan{System.Single},System.Span{System.Single})
20+
M:System.Numerics.Tensors.TensorPrimitives.Max(System.ReadOnlySpan{System.Single})
21+
M:System.Numerics.Tensors.TensorPrimitives.Max(System.ReadOnlySpan{System.Single},System.ReadOnlySpan{System.Single},System.Span{System.Single})
22+
M:System.Numerics.Tensors.TensorPrimitives.MaxMagnitude(System.ReadOnlySpan{System.Single})
23+
M:System.Numerics.Tensors.TensorPrimitives.MaxMagnitude(System.ReadOnlySpan{System.Single},System.ReadOnlySpan{System.Single},System.Span{System.Single})
24+
M:System.Numerics.Tensors.TensorPrimitives.Min(System.ReadOnlySpan{System.Single})
25+
M:System.Numerics.Tensors.TensorPrimitives.Min(System.ReadOnlySpan{System.Single},System.ReadOnlySpan{System.Single},System.Span{System.Single})
26+
M:System.Numerics.Tensors.TensorPrimitives.MinMagnitude(System.ReadOnlySpan{System.Single})
27+
M:System.Numerics.Tensors.TensorPrimitives.MinMagnitude(System.ReadOnlySpan{System.Single},System.ReadOnlySpan{System.Single},System.Span{System.Single})
28+
M:System.Numerics.Tensors.TensorPrimitives.Multiply(System.ReadOnlySpan{System.Single},System.ReadOnlySpan{System.Single},System.Span{System.Single})
29+
M:System.Numerics.Tensors.TensorPrimitives.Multiply(System.ReadOnlySpan{System.Single},System.Single,System.Span{System.Single})
30+
M:System.Numerics.Tensors.TensorPrimitives.MultiplyAdd(System.ReadOnlySpan{System.Single},System.ReadOnlySpan{System.Single},System.ReadOnlySpan{System.Single},System.Span{System.Single})
31+
M:System.Numerics.Tensors.TensorPrimitives.MultiplyAdd(System.ReadOnlySpan{System.Single},System.ReadOnlySpan{System.Single},System.Single,System.Span{System.Single})
32+
M:System.Numerics.Tensors.TensorPrimitives.MultiplyAdd(System.ReadOnlySpan{System.Single},System.Single,System.ReadOnlySpan{System.Single},System.Span{System.Single})
33+
M:System.Numerics.Tensors.TensorPrimitives.Negate(System.ReadOnlySpan{System.Single},System.Span{System.Single})
34+
M:System.Numerics.Tensors.TensorPrimitives.Norm(System.ReadOnlySpan{System.Single})
35+
M:System.Numerics.Tensors.TensorPrimitives.Product(System.ReadOnlySpan{System.Single})
36+
M:System.Numerics.Tensors.TensorPrimitives.ProductOfDifferences(System.ReadOnlySpan{System.Single},System.ReadOnlySpan{System.Single})
37+
M:System.Numerics.Tensors.TensorPrimitives.ProductOfSums(System.ReadOnlySpan{System.Single},System.ReadOnlySpan{System.Single})
38+
M:System.Numerics.Tensors.TensorPrimitives.Sigmoid(System.ReadOnlySpan{System.Single},System.Span{System.Single})
39+
M:System.Numerics.Tensors.TensorPrimitives.Sinh(System.ReadOnlySpan{System.Single},System.Span{System.Single})
40+
M:System.Numerics.Tensors.TensorPrimitives.SoftMax(System.ReadOnlySpan{System.Single},System.Span{System.Single})
41+
M:System.Numerics.Tensors.TensorPrimitives.Subtract(System.ReadOnlySpan{System.Single},System.ReadOnlySpan{System.Single},System.Span{System.Single})
42+
M:System.Numerics.Tensors.TensorPrimitives.Subtract(System.ReadOnlySpan{System.Single},System.Single,System.Span{System.Single})
43+
M:System.Numerics.Tensors.TensorPrimitives.Sum(System.ReadOnlySpan{System.Single})
44+
M:System.Numerics.Tensors.TensorPrimitives.SumOfMagnitudes(System.ReadOnlySpan{System.Single})
45+
M:System.Numerics.Tensors.TensorPrimitives.SumOfSquares(System.ReadOnlySpan{System.Single})
46+
M:System.Numerics.Tensors.TensorPrimitives.Tanh(System.ReadOnlySpan{System.Single},System.Span{System.Single})
47+
48+
M:System.Numerics.Tensors.TensorPrimitives.ConvertToIntegerNative``2(System.ReadOnlySpan{``0},System.Span{``1})
49+
M:System.Numerics.Tensors.TensorPrimitives.ConvertToInteger``2(System.ReadOnlySpan{``0},System.Span{``1})

src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
<IsPackable>true</IsPackable>
77
<PackageDescription>Provides support for operating over tensors.</PackageDescription>
88
<GenAPIExcludeApiList>ReferenceAssemblyExclusions.txt</GenAPIExcludeApiList>
9+
<GenAPITargetPath>$([MSBuild]::NormalizePath('$(MSBuildProjectDirectory)', '..', 'ref', '$(AssemblyName).netcore.cs'))</GenAPITargetPath>
910
<!-- SYSLIB5001: Tensor<T> and related APIs in System.Numerics.Tensors are experimental in .NET 9 -->
1011
<NoWarn>$(NoWarn);SYSLIB5001</NoWarn>
1112
</PropertyGroup>
@@ -36,9 +37,11 @@
3637
<Compile Include="System\Numerics\Tensors\netcore\IReadOnlyTensor_1.cs" />
3738
<Compile Include="System\Numerics\Tensors\netcore\ITensor.cs" />
3839
<Compile Include="System\Numerics\Tensors\netcore\ITensor_1.cs" />
40+
<Compile Include="System\Numerics\Tensors\netcore\ReadOnlyTensorDimensionSpan_1.cs" />
3941
<Compile Include="System\Numerics\Tensors\netcore\ReadOnlyTensorSpan_1.cs" />
4042
<Compile Include="System\Numerics\Tensors\netcore\Tensor.cs" />
4143
<Compile Include="System\Numerics\Tensors\netcore\Tensor_1.cs" />
44+
<Compile Include="System\Numerics\Tensors\netcore\TensorDimensionSpan_1.cs" />
4245
<Compile Include="System\Numerics\Tensors\netcore\TensorOperation.cs" />
4346
<Compile Include="System\Numerics\Tensors\netcore\TensorPrimitives.Abs.cs" />
4447
<Compile Include="System\Numerics\Tensors\netcore\TensorPrimitives.Acos.cs" />
@@ -167,7 +170,7 @@
167170
<Compile Include="System\Numerics\Tensors\netcore\TensorPrimitives.Truncate.cs" />
168171
<Compile Include="System\Numerics\Tensors\netcore\TensorPrimitives.Xor.cs" />
169172
<Compile Include="System\Numerics\Tensors\netcore\TensorShape.cs" />
170-
<Compile Include="System\Numerics\Tensors\netcore\TensorSpan.cs" />
173+
<Compile Include="System\Numerics\Tensors\netcore\TensorSpan_1.cs" />
171174
<Compile Include="System\Numerics\Tensors\netcore\TensorSpanDebugView.cs" />
172175
</ItemGroup>
173176

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/IReadOnlyTensor_1.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ public interface IReadOnlyTensor<TSelf, T> : IReadOnlyTensor, IEnumerable<T>
6262
/// <remarks>This method copies all of the source tensor to <paramref name="destination" /> even if they overlap.</remarks>
6363
void FlattenTo(scoped Span<T> destination);
6464

65+
/// <summary>Returns a span that can be used to access the flattened elements for a given dimension.</summary>
66+
/// <param name="dimension">The dimension for which the span should be created.</param>
67+
/// <returns>A span that can be used to access the flattened elements for a given dimension.</returns>
68+
ReadOnlyTensorDimensionSpan<T> GetDimensionSpan(int dimension);
69+
6570
/// <summary>Returns a reference to an object of type <typeparamref name="T" /> that can be used for pinning.</summary>
6671
/// <returns>A reference to the element of the tensor at index 0, or <c>null</c> if the tensor is empty.</returns>
6772
/// <remarks>This method is intended to support .NET compilers and is not intended to be called by user code.</remarks>

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/ITensor_1.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ public interface ITensor<TSelf, T> : ITensor, IReadOnlyTensor<TSelf, T>
8787
/// <inheritdoc cref="ITensor.Fill(object)" />
8888
void Fill(T value);
8989

90+
/// <inheritdoc cref="IReadOnlyTensor{TSelf, T}.GetDimensionSpan(int)" />
91+
new TensorDimensionSpan<T> GetDimensionSpan(int dimension);
92+
9093
/// <inheritdoc cref="IReadOnlyTensor{TSelf, T}.GetPinnableReference" />
9194
new ref T GetPinnableReference();
9295
}
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System.Collections;
5+
using System.Collections.Generic;
6+
using System.Runtime.CompilerServices;
7+
using System.Runtime.InteropServices;
8+
9+
namespace System.Numerics.Tensors
10+
{
11+
/// <summary>Represents the slices that exist within a dimension of a tensor span.</summary>
12+
/// <typeparam name="T">The type of the elements within the tensor span.</typeparam>
13+
public readonly ref struct ReadOnlyTensorDimensionSpan<T>
14+
{
15+
private readonly ReadOnlyTensorSpan<T> _tensor;
16+
private readonly nint _length;
17+
private readonly int _dimension;
18+
private readonly TensorShape _sliceShape;
19+
20+
internal ReadOnlyTensorDimensionSpan(ReadOnlyTensorSpan<T> tensor, int dimension)
21+
{
22+
if ((uint)dimension >= tensor.Rank)
23+
{
24+
ThrowHelper.ThrowArgumentOutOfRangeException();
25+
}
26+
dimension += 1;
27+
28+
_tensor = tensor;
29+
_length = TensorPrimitives.Product(tensor.Lengths[..dimension]);
30+
_dimension = dimension;
31+
_sliceShape = TensorShape.Create((dimension != tensor.Rank) ? tensor.Lengths[dimension..] : [1], tensor.Strides[dimension..]);
32+
}
33+
34+
/// <summary>Gets the length of the tensor dimension span.</summary>
35+
public nint Length => _length;
36+
37+
/// <summary>Gets the tensor span representing a slice of the tracked dimension using the specified index.</summary>
38+
/// <param name="index">The index of the tensor span slice to retrieve within the tracked dimension.</param>
39+
/// <returns>The tensor span representing a slice of the tracked dimension using <paramref name="index" />.</returns>
40+
public ReadOnlyTensorSpan<T> this[nint index]
41+
{
42+
get
43+
{
44+
if ((nuint)index >= (nuint)_length)
45+
{
46+
ThrowHelper.ThrowArgumentOutOfRangeException();
47+
}
48+
49+
nint linearOffset = _tensor._shape.GetLinearOffset(index, _dimension);
50+
return new ReadOnlyTensorSpan<T>(ref Unsafe.Add(ref _tensor._reference, linearOffset), _sliceShape);
51+
}
52+
}
53+
54+
/// <summary>Gets an enumerator for the readonly tensor dimension span.</summary>
55+
public Enumerator GetEnumerator() => new Enumerator(this);
56+
57+
/// <summary>Enumerates the spans of a tensor dimension span.</summary>
58+
public ref struct Enumerator
59+
#if NET9_0_OR_GREATER
60+
: IEnumerator<ReadOnlyTensorSpan<T>>
61+
#endif
62+
{
63+
private readonly ReadOnlyTensorDimensionSpan<T> _span;
64+
private nint _index;
65+
66+
internal Enumerator(ReadOnlyTensorDimensionSpan<T> span)
67+
{
68+
_span = span;
69+
_index = -1;
70+
}
71+
72+
/// <summary>Gets the span at the current position of the enumerator.</summary>
73+
public readonly ReadOnlyTensorSpan<T> Current => _span[_index];
74+
75+
/// <summary>Advances the enumerator to the next element of the tensor span.</summary>
76+
public bool MoveNext()
77+
{
78+
nint index = _index + 1;
79+
80+
if (index < _span.Length)
81+
{
82+
_index = index;
83+
return true;
84+
}
85+
return false;
86+
}
87+
88+
/// <summary>Sets the enumerator to its initial position, which is before the first element in the tensor span.</summary>
89+
public void Reset()
90+
{
91+
_index = -1;
92+
}
93+
94+
#if NET9_0_OR_GREATER
95+
//
96+
// IDisposable
97+
//
98+
99+
void IDisposable.Dispose() { }
100+
101+
//
102+
// IEnumerator
103+
//
104+
105+
readonly object? IEnumerator.Current => throw new NotSupportedException();
106+
#endif
107+
}
108+
}
109+
}

0 commit comments

Comments
 (0)