Commit 3701fc5

Add missing accumulate argument for index_put_ method (#1460)
Co-authored-by: Ozan Aydin <[email protected]>

1 parent: 84e227b
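In managed code the change surfaces as an optional `accumulate` argument on `Tensor.index_put_`: `false` (the default) overwrites the addressed elements, `true` adds to them. A minimal usage sketch, mirroring the tests added in TestTorchTensor.cs below (values are illustrative):

    using static TorchSharp.torch;

    var t = ones(5);                                        // [1, 1, 1, 1, 1]
    var idx = new TensorIndex[] { TensorIndex.Tensor(new long[] { 1, 2 }) };

    t.index_put_(tensor(10.0f), idx);                       // replace: [1, 10, 10, 1, 1]
    t.index_put_(tensor(10.0f), idx, accumulate: true);     // add:     [1, 20, 20, 1, 1]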

File tree: 6 files changed, +168 -41 lines


RELEASENOTES.md

+4
@@ -7,6 +7,10 @@ __Bug Fixes__:
 
 #1426 Sequential.eval() does not put model into eval mode<br/>
 `torch.optim.lr_scheduler.LinearLR` `end_factor` default has been corrected, is now 1.0.<br/>
+
+__API Changes__:
+
+#1374 Add accumulate to index_put_<br/>
 `torch.optim.lr_scheduler.PolynomialLR` `power` type has been corrected, is now double.<br/>
 
 # NuGet Version 0.105.0

src/Native/LibTorchSharp/THSTensor.cpp

+25
@@ -837,6 +837,31 @@ void THSTensor_index_put_(Tensor tensor,
     CATCH(tensor->index_put_(indices, *value););
 }
 
+void THSTensor_index_put_(Tensor tensor,
+    const int64_t* indexStarts,
+    const int64_t* indexEnds,
+    const int64_t* indexSteps,
+    const Tensor* indexTensors,
+    const int indicesLength,
+    const Tensor value,
+    const bool accumulate)
+{
+    at::indexing::TensorIndex* indicesArray = (at::indexing::TensorIndex*)alloca(indicesLength * sizeof(at::indexing::TensorIndex));
+    memset(indicesArray, 0, indicesLength * sizeof(at::indexing::TensorIndex));
+    completeTensorIndices(indexStarts, indexEnds, indexSteps, indexTensors, indicesArray, indicesLength);
+    auto indices = at::ArrayRef<at::indexing::TensorIndex>(indicesArray, indicesLength);
+    if (accumulate) {
+        c10::List<std::optional<at::Tensor>> indicesList = c10::List<std::optional<at::Tensor>>();
+        for (int i = 0; i < indicesLength; i++) {
+            indicesList.push_back(c10::optional<at::Tensor>(*indexTensors[i]));
+        }
+        CATCH(tensor->index_put_(indicesList, *value, accumulate););
+    }
+    else {
+        CATCH(tensor->index_put_(indices, *value););
+    }
+}
+
 void THSTensor_index_put_scalar_(Tensor tensor,
     const int64_t* indexStarts,
     const int64_t* indexEnds,
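When `accumulate` is true, the native wrapper switches from the `TensorIndex`-based overload to the tensor-list overload of `index_put_`, which is the one that supports accumulation. One consequence worth noting, as an assumption based on LibTorch's `index_put_` semantics rather than something the tests in this commit cover: with `accumulate: true`, an index that appears more than once should receive the sum of all its contributions. A hedged C# sketch:

    using static TorchSharp.torch;

    // Assumption: duplicate indices accumulate, as in LibTorch's index_put_ with accumulate = true.
    var t = zeros(3);                                       // [0, 0, 0]
    var idx = new TensorIndex[] { TensorIndex.Tensor(new long[] { 1, 1, 2 }) };
    var v = tensor(new float[] { 1.0f, 2.0f, 3.0f });

    t.index_put_(v, idx, accumulate: true);                 // expected: [0, 3, 3]  (1 + 2 lands on index 1)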

src/Native/LibTorchSharp/THSTensor.h

+2 -1
@@ -683,7 +683,8 @@ EXPORT_API(void) THSTensor_index_put_(Tensor tensor,
     const int64_t* indexSteps,
     const Tensor* indexTensors,
     const int indicesLength,
-    const Tensor value);
+    const Tensor value,
+    const bool accumulate = false);
 
 EXPORT_API(Tensor) THSTensor_index_select(Tensor tensor, int64_t dim, Tensor index);

src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs

+1 -1
@@ -410,7 +410,7 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
         internal static extern void THSTensor_index_put_scalar_(IntPtr tensor, IntPtr indexStarts, IntPtr indexEnds, IntPtr indexSteps, IntPtr indexTensors, int indicesLength, IntPtr value);
 
         [DllImport("LibTorchSharp")]
-        internal static extern void THSTensor_index_put_(IntPtr tensor, IntPtr indexStarts, IntPtr indexEnds, IntPtr indexSteps, IntPtr indexTensors, int indicesLength, IntPtr value);
+        internal static extern void THSTensor_index_put_(IntPtr tensor, IntPtr indexStarts, IntPtr indexEnds, IntPtr indexSteps, IntPtr indexTensors, int indicesLength, IntPtr value, [MarshalAs(UnmanagedType.U1)] bool accumulate);
 
         [DllImport("LibTorchSharp")]
         internal static extern IntPtr THSTensor_get1(IntPtr handle, long i1);

src/TorchSharp/Tensor/Tensor.cs

+24 -1
@@ -1604,7 +1604,25 @@ public Tensor index_put_(Tensor value, params TensorIndex[] indices)
            unsafe {
                fixed (long* ptrKindAndStarts = arrKindAndStarts, ptrStops = arrStops, ptrSteps = arrSteps) {
                    fixed (IntPtr* ptrTensors = arrTensors) {
-                        NativeMethods.THSTensor_index_put_(Handle, (IntPtr)ptrKindAndStarts, (IntPtr)ptrStops, (IntPtr)ptrSteps, (IntPtr)ptrTensors, indices.Length, value.Handle);
+                        NativeMethods.THSTensor_index_put_(Handle, (IntPtr)ptrKindAndStarts, (IntPtr)ptrStops, (IntPtr)ptrSteps, (IntPtr)ptrTensors, indices.Length, value.Handle, false);
+                        CheckForErrors();
+                        GC.KeepAlive(indices); // don't release or finalize Tensor indices whose handles have been put into ptrTensors
+                        GC.KeepAlive(value);
+                        return this;
+                    }
+                }
+            }
+        }
+
+        public Tensor index_put_(Tensor value, TensorIndex[] indices, bool accumulate = false)
+        {
+            EncodeIndices(indices, out var arrKindAndStarts, out var arrStops, out var arrSteps, out var arrTensors);
+            if (accumulate && arrTensors == null)
+                throw new Exception("Invalid 'indices' parameter. Must be an array of TensorIndex objects containing tensors with indices that match the shape of the tensor to update");
+            unsafe {
+                fixed (long* ptrKindAndStarts = arrKindAndStarts, ptrStops = arrStops, ptrSteps = arrSteps) {
+                    fixed (IntPtr* ptrTensors = arrTensors) {
+                        NativeMethods.THSTensor_index_put_(Handle, (IntPtr)ptrKindAndStarts, (IntPtr)ptrStops, (IntPtr)ptrSteps, (IntPtr)ptrTensors, indices.Length, value.Handle, accumulate);
                         CheckForErrors();
                         GC.KeepAlive(indices); // don't release or finalize Tensor indices whose handles have been put into ptrTensors
                         GC.KeepAlive(value);
@@ -1622,6 +1640,11 @@ public Tensor index_put_(Tensor value, params Tensor[] indices)
             return index_put_(value, indices.Select(t => TensorIndex.Tensor(t)).ToArray());
         }
 
+        public Tensor index_put_(Tensor value, Tensor[] indices, bool accumulate = false)
+        {
+            return index_put_(value, indices.Select(t => TensorIndex.Tensor(t)).ToArray(), accumulate);
+        }
+
 
         /// <summary>
         /// Index into the tensor using Python-like indexing expressions and place a scalar tensor at the index.
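Two managed overloads are added here: `index_put_(Tensor value, TensorIndex[] indices, bool accumulate = false)` and a `Tensor[]` convenience overload that wraps each index in `TensorIndex.Tensor(...)` before forwarding. Note the guard above: `accumulate: true` is rejected when the encoded indices contain no index tensors. A short sketch of the convenience overload (values illustrative):

    using static TorchSharp.torch;

    var t = ones(5);
    var idx = new Tensor[] { tensor(new long[] { 0, 4 }) };

    // The Tensor[] overload forwards accumulate to the TensorIndex[] overload.
    t.index_put_(tensor(2.0f), idx, accumulate: true);      // [3, 1, 1, 1, 3]

    // Per the guard above, accumulate: true without any tensor indices (e.g. only slices) throws.
    // t.index_put_(tensor(2.0f), new TensorIndex[] { TensorIndex.Slice(0, 2) }, accumulate: true);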

test/TorchSharpTest/TestTorchTensor.cs

+112 -38
@@ -290,17 +290,13 @@ public void TestTensorDefaultPrint()
             Tensor t = torch.zeros(2, 2);
             string expectedOutput = t.ToString(TensorStringStyle.Default) + Environment.NewLine;
             var originalOut = Console.Out;
-            using (var sw = new StringWriter())
-            {
-                try
-                {
+            using (var sw = new StringWriter()) {
+                try {
                     Console.SetOut(sw);
                     t.print();
                     var result = sw.ToString();
                     Assert.Equal(expectedOutput, result);
-                }
-                finally
-                {
+                } finally {
                     Console.SetOut(originalOut);
                 }
             }
@@ -807,7 +803,7 @@ public void FromArrayFactory()
                     () => Assert.Equal(1, t.ndim),
                     () => Assert.Equal(ScalarType.Byte, t.dtype));
             }
-
+
             {
                 var array = new Memory<long>(new long[8]);
                 using var t = torch.tensor(array, new long[] { 8 }, device: device);
@@ -816,11 +812,11 @@ public void FromArrayFactory()
                     () => Assert.Equal(1, t.ndim),
                     () => Assert.Equal(ScalarType.Int64, t.dtype));
             }
-
+
             {
                 var array = new long[18];
                 array[5] = 17;
-                var mem = new Memory<long>(array,4,10);
+                var mem = new Memory<long>(array, 4, 10);
                 using var t = torch.tensor(mem, new long[] { 8 }, device: device);
                 Assert.Multiple(
                     () => Assert.Equal(device.type, t.device_type),
@@ -3165,6 +3161,86 @@ public void IndexFill2()
                 () => Assert.Equal(1.0, x[2, 2].ToSingle()));
         }
 
+        [Fact]
+        [TestOf(nameof(Tensor.index_put_))]
+        public void IndexPutOneValueOneIndex()
+        {
+            using var _ = NewDisposeScope();
+
+            var tensor = ones(5);
+            var indices = new TensorIndex[] { TensorIndex.Tensor(1) };
+            var values = torch.tensor(5.0f);
+
+            // default accumulate value is false, should only replace value at index 1 with 5
+            tensor.index_put_(values, indices);
+            Assert.True(tensor.Equals(torch.tensor(new float[] { 1.0f, 5.0f, 1.0f, 1.0f, 1.0f })));
+
+            tensor = ones(5);
+            // accumulate value is false, explicitly set, should only replace value at index 1 with 5
+            tensor.index_put_(values, indices, accumulate: false);
+            Assert.True(tensor.Equals(torch.tensor(new float[] { 1.0f, 5.0f, 1.0f, 1.0f, 1.0f })));
+
+            tensor = ones(5);
+            // accumulate value is true, should add value to index 1, 1 + 5 = 6
+            tensor.index_put_(values, indices, accumulate: true);
+            Assert.True(tensor.Equals(torch.tensor(new float[] { 1.0f, 6.0f, 1.0f, 1.0f, 1.0f })));
+        }
+
+        [Fact]
+        [TestOf(nameof(Tensor.index_put_))]
+        public void IndexPutOneValueMultipleIndexes()
+        {
+            using var _ = NewDisposeScope();
+
+            var tensor = ones(5);
+            var indices = new TensorIndex[] { TensorIndex.Tensor(new long[] { 1, 2 }) };
+            var values = torch.tensor(10.0f);
+
+            // default accumulate value is false, should only replace value at given indexes
+            tensor.index_put_(values, indices);
+            Assert.True(tensor.Equals(torch.tensor(new float[] { 1.0f, 10.0f, 10.0f, 1.0f, 1.0f })));
+
+            tensor = ones(5);
+            // accumulate value is true, should add value to given indexes
+            tensor.index_put_(values, indices, true);
+            Assert.True(tensor.Equals(torch.tensor(new float[] { 1.0f, 11.0f, 11.0f, 1.0f, 1.0f })));
+
+            // accumulate value is false, explicitly set, should replace value at given indexes
+            tensor.index_put_(values, indices, false);
+            Assert.True(tensor.Equals(torch.tensor(new float[] { 1.0f, 10.0f, 10.0f, 1.0f, 1.0f })));
+        }
+
+        [Fact]
+        [TestOf(nameof(Tensor.index_put_))]
+        public void IndexPutMultipleValuesMultipleIndexes()
+        {
+            using var _ = NewDisposeScope();
+
+            var tensor = ones(5, 2);
+            var indices = new TensorIndex[]
+            {
+                TensorIndex.Tensor(new long[] { 1, 2, 0, 3 }), // for first tensor dimension (row)
+                TensorIndex.Tensor(new long[] { 0, 1, 0, 0 })  // for second tensor dimension (column)
+            };
+            var values = torch.tensor(new float[] { 3.0f, 4.0f, 5.0f, 10f });
+
+            // default accumulate value is false, should only replace values at given indices with 3, 4, 5, 10
+            // Indexes to be replaced: (1, 0) -> 3.0, (2, 1) -> 4.0, (0, 0) -> 5.0, (3, 0) -> 10.0
+            tensor.index_put_(values, indices);
+            Assert.True(tensor.Equals(torch.tensor(new float[,] { { 5.0f, 1.0f }, { 3.0f, 1.0f }, { 1.0f, 4.0f }, { 10.0f, 1.0f }, { 1.0f, 1.0f } })));
+
+            tensor = ones(5, 2);
+            // accumulate value is true, should perform addition at given indices, 1 + 3 = 4, 1 + 4 = 5, 1 + 5 = 6, 1 + 10 = 11
+            // Indexes to be replaced: (1, 0) -> 4.0, (2, 1) -> 5.0, (0, 0) -> 6.0, (3, 0) -> 11.0
+            tensor.index_put_(values, indices, true);
+            Assert.True(tensor.Equals(torch.tensor(new float[,] { { 6.0f, 1.0f }, { 4.0f, 1.0f }, { 1.0f, 5.0f }, { 11.0f, 1.0f }, { 1.0f, 1.0f } })));
+
+            // accumulate value is false, explicitly set, should only replace values at given indices with 3, 4, 5, 10
+            // Indexes to be replaced: (1, 0) -> 3.0, (2, 1) -> 4.0, (0, 0) -> 5.0, (3, 0) -> 10.0
+            tensor.index_put_(values, indices, false);
+            Assert.True(tensor.Equals(torch.tensor(new float[,] { { 5.0f, 1.0f }, { 3.0f, 1.0f }, { 1.0f, 4.0f }, { 10.0f, 1.0f }, { 1.0f, 1.0f } })));
+        }
+
         [Fact]
         [TestOf(nameof(TensorExtensionMethods.ToTensor))]
         public void ScalarToTensor()
@@ -3257,7 +3333,7 @@ public void ScalarToTensor3()
         [TestOf(nameof(Tensor))]
         public void ScalarToTensorDoesNotLeakMemory()
         {
-            AssertTensorDoesNotLeak(()=>{
+            AssertTensorDoesNotLeak(() => {
                 Tensor tensor = 1;
                 return tensor;
             });
@@ -3273,20 +3349,20 @@ public void ScalarToTensorDoesNotLeakMemory()
         [TestOf(nameof(Tensor))]
         public void ScalarArrayToTensorDoesNotLeakMemory()
        {
-            AssertTensorDoesNotLeak(() => (new byte[]{1}).ToTensor(new long[]{1}));
-            AssertTensorDoesNotLeak(() => (new sbyte[]{-1}).ToTensor(new long[]{1}));
-            AssertTensorDoesNotLeak(() => (new short[]{-1}).ToTensor(new long[]{1}));
-            AssertTensorDoesNotLeak(() => (new long[]{-1}).ToTensor(new long[]{1}));
-            AssertTensorDoesNotLeak(() => (new float[]{-1}).ToTensor(new long[]{1}));
-            AssertTensorDoesNotLeak(() => (new double[]{-1}).ToTensor(new long[]{1}));
+            AssertTensorDoesNotLeak(() => (new byte[] { 1 }).ToTensor(new long[] { 1 }));
+            AssertTensorDoesNotLeak(() => (new sbyte[] { -1 }).ToTensor(new long[] { 1 }));
+            AssertTensorDoesNotLeak(() => (new short[] { -1 }).ToTensor(new long[] { 1 }));
+            AssertTensorDoesNotLeak(() => (new long[] { -1 }).ToTensor(new long[] { 1 }));
+            AssertTensorDoesNotLeak(() => (new float[] { -1 }).ToTensor(new long[] { 1 }));
+            AssertTensorDoesNotLeak(() => (new double[] { -1 }).ToTensor(new long[] { 1 }));
         }
 
         [Fact]
         [TestOf(nameof(Tensor))]
         public void ComplexNumberOfDoubleDoesNotLeakMemory()
         {
-            AssertTensorDoesNotLeak(() => ( torch.tensor((double)-1, (double)-2)));
-            AssertTensorDoesNotLeak(() => ( torch.tensor(((double)-1, (double)-2))));
+            AssertTensorDoesNotLeak(() => (torch.tensor((double)-1, (double)-2)));
+            AssertTensorDoesNotLeak(() => (torch.tensor(((double)-1, (double)-2))));
         }
 
         [Fact]
@@ -4106,7 +4182,7 @@ public void CastMoveAndDisposeAfter()
                 Assert.True(input.IsInvalid);
                 Assert.False(cast.IsInvalid);
                 // make sure we can access the values
-                Assert.Equal(1, cast[0].ToInt32());
+                Assert.Equal(1, cast[0].ToInt32());
             }
             if (torch.cuda.is_available()) {
                 {
@@ -8517,28 +8593,27 @@ public void DefaultDTypeCreation()
         {
             var dt = torch.get_default_dtype();
 
-            var t = torch.zeros(5,5);
+            var t = torch.zeros(5, 5);
             Assert.Equal(torch.float32, t.dtype);
 
             try {
-                torch.set_default_dtype(torch.float64);
-
-                t = torch.zeros(5,5);
+                torch.set_default_dtype(torch.float64);
+
+                t = torch.zeros(5, 5);
                 Assert.Equal(torch.float64, t.dtype);
 
-                t = torch.ones(5,5);
+                t = torch.ones(5, 5);
                 Assert.Equal(torch.float64, t.dtype);
 
-                t = torch.rand(5,5);
+                t = torch.rand(5, 5);
                 Assert.Equal(torch.float64, t.dtype);
 
-                t = torch.randn(5,5);
+                t = torch.randn(5, 5);
                 Assert.Equal(torch.float64, t.dtype);
 
                 t = torch.logspace(5, 15, 20);
                 Assert.Equal(torch.float64, t.dtype);
-            }
-            finally {
+            } finally {
                 torch.set_default_dtype(dt);
             }
         }
@@ -8548,28 +8623,27 @@ public void DefaultDeviceCreation()
        {
             var dt = torch.get_default_device();
 
-            var t = torch.zeros(5,5);
+            var t = torch.zeros(5, 5);
             Assert.Equal(DeviceType.CPU, t.device_type);
 
             try {
-                torch.set_default_device(torch.META);
-
-                t = torch.zeros(5,5);
+                torch.set_default_device(torch.META);
+
+                t = torch.zeros(5, 5);
                 Assert.Equal(DeviceType.META, t.device_type);
 
-                t = torch.ones(5,5);
+                t = torch.ones(5, 5);
                 Assert.Equal(DeviceType.META, t.device_type);
 
-                t = torch.rand(5,5);
+                t = torch.rand(5, 5);
                 Assert.Equal(DeviceType.META, t.device_type);
 
-                t = torch.randn(5,5);
+                t = torch.randn(5, 5);
                 Assert.Equal(DeviceType.META, t.device_type);
 
                 t = torch.logspace(5, 15, 20);
                 Assert.Equal(DeviceType.META, t.device_type);
-            }
-            finally {
+            } finally {
                 torch.set_default_device(dt);
             }
         }
