@@ -290,17 +290,13 @@ public void TestTensorDefaultPrint()
290
290
Tensor t = torch . zeros ( 2 , 2 ) ;
291
291
string expectedOutput = t . ToString ( TensorStringStyle . Default ) + Environment . NewLine ;
292
292
var originalOut = Console . Out ;
293
- using ( var sw = new StringWriter ( ) )
294
- {
295
- try
296
- {
293
+ using ( var sw = new StringWriter ( ) ) {
294
+ try {
297
295
Console . SetOut ( sw ) ;
298
296
t . print ( ) ;
299
297
var result = sw . ToString ( ) ;
300
298
Assert . Equal ( expectedOutput , result ) ;
301
- }
302
- finally
303
- {
299
+ } finally {
304
300
Console . SetOut ( originalOut ) ;
305
301
}
306
302
}
@@ -807,7 +803,7 @@ public void FromArrayFactory()
807
803
( ) => Assert . Equal ( 1 , t . ndim ) ,
808
804
( ) => Assert . Equal ( ScalarType . Byte , t . dtype ) ) ;
809
805
}
810
-
806
+
811
807
{
812
808
var array = new Memory < long > ( new long [ 8 ] ) ;
813
809
using var t = torch . tensor ( array , new long [ ] { 8 } , device : device ) ;
@@ -816,11 +812,11 @@ public void FromArrayFactory()
816
812
( ) => Assert . Equal ( 1 , t . ndim ) ,
817
813
( ) => Assert . Equal ( ScalarType . Int64 , t . dtype ) ) ;
818
814
}
819
-
815
+
820
816
{
821
817
var array = new long [ 18 ] ;
822
818
array [ 5 ] = 17 ;
823
- var mem = new Memory < long > ( array , 4 , 10 ) ;
819
+ var mem = new Memory < long > ( array , 4 , 10 ) ;
824
820
using var t = torch . tensor ( mem , new long [ ] { 8 } , device : device ) ;
825
821
Assert . Multiple (
826
822
( ) => Assert . Equal ( device . type , t . device_type ) ,
@@ -3165,6 +3161,86 @@ public void IndexFill2()
3165
3161
( ) => Assert . Equal ( 1.0 , x [ 2 , 2 ] . ToSingle ( ) ) ) ;
3166
3162
}
3167
3163
3164
+ [ Fact ]
3165
+ [ TestOf ( nameof ( Tensor . index_put_ ) ) ]
3166
+ public void IndexPutOneValueOneIndex ( )
3167
+ {
3168
+ using var _ = NewDisposeScope ( ) ;
3169
+
3170
+ var tensor = ones ( 5 ) ;
3171
+ var indices = new TensorIndex [ ] { TensorIndex . Tensor ( 1 ) } ;
3172
+ var values = torch . tensor ( 5.0f ) ;
3173
+
3174
+ // default accumulate value is false, should only replace value at index 1 with 5
3175
+ tensor . index_put_ ( values , indices ) ;
3176
+ Assert . True ( tensor . Equals ( torch . tensor ( new float [ ] { 1.0f , 5.0f , 1.0f , 1.0f , 1.0f } ) ) ) ;
3177
+
3178
+ tensor = ones ( 5 ) ;
3179
+ // accumulate value is false, explicitly set, should only replace value at index 1 with 5
3180
+ tensor . index_put_ ( values , indices , accumulate : false ) ;
3181
+ Assert . True ( tensor . Equals ( torch . tensor ( new float [ ] { 1.0f , 5.0f , 1.0f , 1.0f , 1.0f } ) ) ) ;
3182
+
3183
+ tensor = ones ( 5 ) ;
3184
+ // accumulate value is true, should add value to index 1, 1 + 5 = 6
3185
+ tensor . index_put_ ( values , indices , accumulate : true ) ;
3186
+ Assert . True ( tensor . Equals ( torch . tensor ( new float [ ] { 1.0f , 6.0f , 1.0f , 1.0f , 1.0f } ) ) ) ;
3187
+ }
3188
+
3189
+ [ Fact ]
3190
+ [ TestOf ( nameof ( Tensor . index_put_ ) ) ]
3191
+ public void IndexPutOneValueMultipleIndexes ( )
3192
+ {
3193
+ using var _ = NewDisposeScope ( ) ;
3194
+
3195
+ var tensor = ones ( 5 ) ;
3196
+ var indices = new TensorIndex [ ] { TensorIndex . Tensor ( new long [ ] { 1 , 2 } ) } ;
3197
+ var values = torch . tensor ( 10.0f ) ;
3198
+
3199
+ // default accumulate value is false, should only replace value at given indexes
3200
+ tensor . index_put_ ( values , indices ) ;
3201
+ Assert . True ( tensor . Equals ( torch . tensor ( new float [ ] { 1.0f , 10.0f , 10.0f , 1.0f , 1.0f } ) ) ) ;
3202
+
3203
+ tensor = ones ( 5 ) ;
3204
+ // accumulate value is true, should add value to given indexes
3205
+ tensor . index_put_ ( values , indices , true ) ;
3206
+ Assert . True ( tensor . Equals ( torch . tensor ( new float [ ] { 1.0f , 11.0f , 11.0f , 1.0f , 1.0f } ) ) ) ;
3207
+
3208
+ // accumulate value is false, explicitly set, should replace value at given indexes
3209
+ tensor . index_put_ ( values , indices , false ) ;
3210
+ Assert . True ( tensor . Equals ( torch . tensor ( new float [ ] { 1.0f , 10.0f , 10.0f , 1.0f , 1.0f } ) ) ) ;
3211
+ }
3212
+
3213
+ [ Fact ]
3214
+ [ TestOf ( nameof ( Tensor . index_put_ ) ) ]
3215
+ public void IndexPutMultipleValuesMultipleIndexes ( )
3216
+ {
3217
+ using var _ = NewDisposeScope ( ) ;
3218
+
3219
+ var tensor = ones ( 5 , 2 ) ;
3220
+ var indices = new TensorIndex [ ]
3221
+ {
3222
+ TensorIndex . Tensor ( new long [ ] { 1 , 2 , 0 , 3 } ) , // for first tensor dimension (row)
3223
+ TensorIndex . Tensor ( new long [ ] { 0 , 1 , 0 , 0 } ) // for second tensor dimension (column)
3224
+ } ;
3225
+ var values = torch . tensor ( new float [ ] { 3.0f , 4.0f , 5.0f , 10f } ) ;
3226
+
3227
+ // default accumulate value is false, should only replace values at given indices with 3, 4, 5, 10
3228
+ // Indexes to be replaced: (1, 0) -> 3.0, (2, 1) -> 4.0, (0, 0) -> 5.0, (3, 0) -> 10.0
3229
+ tensor . index_put_ ( values , indices ) ;
3230
+ Assert . True ( tensor . Equals ( torch . tensor ( new float [ , ] { { 5.0f , 1.0f } , { 3.0f , 1.0f } , { 1.0f , 4.0f } , { 10.0f , 1.0f } , { 1.0f , 1.0f } } ) ) ) ;
3231
+
3232
+ tensor = ones ( 5 , 2 ) ;
3233
+ // accumulate value is true, should perform addition at given indices, 1 + 3 = 4, 1 + 4 = 5, 1 + 5 = 6, 1 + 10 = 11
3234
+ // Indexes to be replaced: (1, 0) -> 4.0, (2, 1) -> 5.0, (0, 0) -> 6.0, (3, 0) -> 11.0
3235
+ tensor . index_put_ ( values , indices , true ) ;
3236
+ Assert . True ( tensor . Equals ( torch . tensor ( new float [ , ] { { 6.0f , 1.0f } , { 4.0f , 1.0f } , { 1.0f , 5.0f } , { 11.0f , 1.0f } , { 1.0f , 1.0f } } ) ) ) ;
3237
+
3238
+ // accumulate value is false, explicitly set, should only replace values at given indices with 3, 4, 5, 10
3239
+ // Indexes to be replaced: (1, 0) -> 3.0, (2, 1) -> 4.0, (0, 0) -> 5.0, (3, 0) -> 10.0
3240
+ tensor . index_put_ ( values , indices , false ) ;
3241
+ Assert . True ( tensor . Equals ( torch . tensor ( new float [ , ] { { 5.0f , 1.0f } , { 3.0f , 1.0f } , { 1.0f , 4.0f } , { 10.0f , 1.0f } , { 1.0f , 1.0f } } ) ) ) ;
3242
+ }
3243
+
3168
3244
[ Fact ]
3169
3245
[ TestOf ( nameof ( TensorExtensionMethods . ToTensor ) ) ]
3170
3246
public void ScalarToTensor ( )
@@ -3257,7 +3333,7 @@ public void ScalarToTensor3()
3257
3333
[ TestOf ( nameof ( Tensor ) ) ]
3258
3334
public void ScalarToTensorDoesNotLeakMemory ( )
3259
3335
{
3260
- AssertTensorDoesNotLeak ( ( ) => {
3336
+ AssertTensorDoesNotLeak ( ( ) => {
3261
3337
Tensor tensor = 1 ;
3262
3338
return tensor ;
3263
3339
} ) ;
@@ -3273,20 +3349,20 @@ public void ScalarToTensorDoesNotLeakMemory()
3273
3349
[ TestOf ( nameof ( Tensor ) ) ]
3274
3350
public void ScalarArrayToTensorDoesNotLeakMemory ( )
3275
3351
{
3276
- AssertTensorDoesNotLeak ( ( ) => ( new byte [ ] { 1 } ) . ToTensor ( new long [ ] { 1 } ) ) ;
3277
- AssertTensorDoesNotLeak ( ( ) => ( new sbyte [ ] { - 1 } ) . ToTensor ( new long [ ] { 1 } ) ) ;
3278
- AssertTensorDoesNotLeak ( ( ) => ( new short [ ] { - 1 } ) . ToTensor ( new long [ ] { 1 } ) ) ;
3279
- AssertTensorDoesNotLeak ( ( ) => ( new long [ ] { - 1 } ) . ToTensor ( new long [ ] { 1 } ) ) ;
3280
- AssertTensorDoesNotLeak ( ( ) => ( new float [ ] { - 1 } ) . ToTensor ( new long [ ] { 1 } ) ) ;
3281
- AssertTensorDoesNotLeak ( ( ) => ( new double [ ] { - 1 } ) . ToTensor ( new long [ ] { 1 } ) ) ;
3352
+ AssertTensorDoesNotLeak ( ( ) => ( new byte [ ] { 1 } ) . ToTensor ( new long [ ] { 1 } ) ) ;
3353
+ AssertTensorDoesNotLeak ( ( ) => ( new sbyte [ ] { - 1 } ) . ToTensor ( new long [ ] { 1 } ) ) ;
3354
+ AssertTensorDoesNotLeak ( ( ) => ( new short [ ] { - 1 } ) . ToTensor ( new long [ ] { 1 } ) ) ;
3355
+ AssertTensorDoesNotLeak ( ( ) => ( new long [ ] { - 1 } ) . ToTensor ( new long [ ] { 1 } ) ) ;
3356
+ AssertTensorDoesNotLeak ( ( ) => ( new float [ ] { - 1 } ) . ToTensor ( new long [ ] { 1 } ) ) ;
3357
+ AssertTensorDoesNotLeak ( ( ) => ( new double [ ] { - 1 } ) . ToTensor ( new long [ ] { 1 } ) ) ;
3282
3358
}
3283
3359
3284
3360
[ Fact ]
3285
3361
[ TestOf ( nameof ( Tensor ) ) ]
3286
3362
public void ComplexNumberOfDoubleDoesNotLeakMemory ( )
3287
3363
{
3288
- AssertTensorDoesNotLeak ( ( ) => ( torch . tensor ( ( double ) - 1 , ( double ) - 2 ) ) ) ;
3289
- AssertTensorDoesNotLeak ( ( ) => ( torch . tensor ( ( ( double ) - 1 , ( double ) - 2 ) ) ) ) ;
3364
+ AssertTensorDoesNotLeak ( ( ) => ( torch . tensor ( ( double ) - 1 , ( double ) - 2 ) ) ) ;
3365
+ AssertTensorDoesNotLeak ( ( ) => ( torch . tensor ( ( ( double ) - 1 , ( double ) - 2 ) ) ) ) ;
3290
3366
}
3291
3367
3292
3368
[ Fact ]
@@ -4106,7 +4182,7 @@ public void CastMoveAndDisposeAfter()
4106
4182
Assert . True ( input . IsInvalid ) ;
4107
4183
Assert . False ( cast . IsInvalid ) ;
4108
4184
// make sure we can access the values
4109
- Assert . Equal ( 1 , cast [ 0 ] . ToInt32 ( ) ) ;
4185
+ Assert . Equal ( 1 , cast [ 0 ] . ToInt32 ( ) ) ;
4110
4186
}
4111
4187
if ( torch . cuda . is_available ( ) ) {
4112
4188
{
@@ -8517,28 +8593,27 @@ public void DefaultDTypeCreation()
8517
8593
{
8518
8594
var dt = torch . get_default_dtype ( ) ;
8519
8595
8520
- var t = torch . zeros ( 5 , 5 ) ;
8596
+ var t = torch . zeros ( 5 , 5 ) ;
8521
8597
Assert . Equal ( torch . float32 , t . dtype ) ;
8522
8598
8523
8599
try {
8524
- torch . set_default_dtype ( torch . float64 ) ;
8525
-
8526
- t = torch . zeros ( 5 , 5 ) ;
8600
+ torch . set_default_dtype ( torch . float64 ) ;
8601
+
8602
+ t = torch . zeros ( 5 , 5 ) ;
8527
8603
Assert . Equal ( torch . float64 , t . dtype ) ;
8528
8604
8529
- t = torch . ones ( 5 , 5 ) ;
8605
+ t = torch . ones ( 5 , 5 ) ;
8530
8606
Assert . Equal ( torch . float64 , t . dtype ) ;
8531
8607
8532
- t = torch . rand ( 5 , 5 ) ;
8608
+ t = torch . rand ( 5 , 5 ) ;
8533
8609
Assert . Equal ( torch . float64 , t . dtype ) ;
8534
8610
8535
- t = torch . randn ( 5 , 5 ) ;
8611
+ t = torch . randn ( 5 , 5 ) ;
8536
8612
Assert . Equal ( torch . float64 , t . dtype ) ;
8537
8613
8538
8614
t = torch . logspace ( 5 , 15 , 20 ) ;
8539
8615
Assert . Equal ( torch . float64 , t . dtype ) ;
8540
- }
8541
- finally {
8616
+ } finally {
8542
8617
torch . set_default_dtype ( dt ) ;
8543
8618
}
8544
8619
}
@@ -8548,28 +8623,27 @@ public void DefaultDeviceCreation()
8548
8623
{
8549
8624
var dt = torch . get_default_device ( ) ;
8550
8625
8551
- var t = torch . zeros ( 5 , 5 ) ;
8626
+ var t = torch . zeros ( 5 , 5 ) ;
8552
8627
Assert . Equal ( DeviceType . CPU , t . device_type ) ;
8553
8628
8554
8629
try {
8555
- torch . set_default_device ( torch . META ) ;
8556
-
8557
- t = torch . zeros ( 5 , 5 ) ;
8630
+ torch . set_default_device ( torch . META ) ;
8631
+
8632
+ t = torch . zeros ( 5 , 5 ) ;
8558
8633
Assert . Equal ( DeviceType . META , t . device_type ) ;
8559
8634
8560
- t = torch . ones ( 5 , 5 ) ;
8635
+ t = torch . ones ( 5 , 5 ) ;
8561
8636
Assert . Equal ( DeviceType . META , t . device_type ) ;
8562
8637
8563
- t = torch . rand ( 5 , 5 ) ;
8638
+ t = torch . rand ( 5 , 5 ) ;
8564
8639
Assert . Equal ( DeviceType . META , t . device_type ) ;
8565
8640
8566
- t = torch . randn ( 5 , 5 ) ;
8641
+ t = torch . randn ( 5 , 5 ) ;
8567
8642
Assert . Equal ( DeviceType . META , t . device_type ) ;
8568
8643
8569
8644
t = torch . logspace ( 5 , 15 , 20 ) ;
8570
8645
Assert . Equal ( DeviceType . META , t . device_type ) ;
8571
- }
8572
- finally {
8646
+ } finally {
8573
8647
torch . set_default_device ( dt ) ;
8574
8648
}
8575
8649
}
0 commit comments