@@ -83,6 +83,14 @@ struct PackedVec3::State {
83
83
// / A map from type to the name of a helper function used to unpack that type.
84
84
Hashmap<const core::type::Type*, Symbol, 4 > unpack_helpers;
85
85
86
+ // / @returns true if @p addrspace requires vec3 types to be packed
87
+ bool AddressSpaceNeedsPacking (core::AddressSpace addrspace) {
88
+ // Host-shareable address spaces need to be packed to match the memory layout on the host.
89
+ // The workgroup address space needs to be packed so that the size of generated threadgroup
90
+ // variables matches the size of the original WGSL declarations.
91
+ return core::IsHostShareable (addrspace) || addrspace == core::AddressSpace::kWorkgroup ;
92
+ }
93
+
86
94
// / @param ty the type to test
87
95
// / @returns true if `ty` is a vec3, false otherwise
88
96
bool IsVec3 (const core::type::Type* ty) {
@@ -374,7 +382,7 @@ struct PackedVec3::State {
374
382
// if the transform is necessary.
375
383
for (auto * decl : src.AST ().GlobalVariables ()) {
376
384
auto * var = sem.Get <sem::GlobalVariable>(decl);
377
- if (var && core::IsHostShareable (var->AddressSpace ()) &&
385
+ if (var && AddressSpaceNeedsPacking (var->AddressSpace ()) &&
378
386
ContainsVec3 (var->Type ()->UnwrapRef ())) {
379
387
return true ;
380
388
}
@@ -411,7 +419,7 @@ struct PackedVec3::State {
411
419
[&](const sem::TypeExpression* type) {
412
420
// Rewrite pointers to types that contain vec3s.
413
421
auto * ptr = type->Type ()->As <core::type::Pointer>();
414
- if (ptr && core::IsHostShareable (ptr->AddressSpace ())) {
422
+ if (ptr && AddressSpaceNeedsPacking (ptr->AddressSpace ())) {
415
423
auto new_store_type = RewriteType (ptr->StoreType ());
416
424
if (new_store_type) {
417
425
auto access = ptr->AddressSpace () == core::AddressSpace::kStorage
@@ -424,7 +432,7 @@ struct PackedVec3::State {
424
432
}
425
433
},
426
434
[&](const sem::Variable* var) {
427
- if (!core::IsHostShareable (var->AddressSpace ())) {
435
+ if (!AddressSpaceNeedsPacking (var->AddressSpace ())) {
428
436
return ;
429
437
}
430
438
@@ -440,7 +448,7 @@ struct PackedVec3::State {
440
448
auto * lhs = sem.GetVal (assign->lhs );
441
449
auto * rhs = sem.GetVal (assign->rhs );
442
450
if (!ContainsVec3 (rhs->Type ()) ||
443
- !core::IsHostShareable (
451
+ !AddressSpaceNeedsPacking (
444
452
lhs->Type ()->As <core::type::Reference>()->AddressSpace ())) {
445
453
// Skip assignments to address spaces that are not host-shareable, or
446
454
// that do not contain vec3 types.
@@ -468,7 +476,7 @@ struct PackedVec3::State {
468
476
[&](const sem::Load* load) {
469
477
// Unpack loads of types that contain vec3s in host-shareable address spaces.
470
478
if (ContainsVec3 (load->Type ()) &&
471
- core::IsHostShareable (load->ReferenceType ()->AddressSpace ())) {
479
+ AddressSpaceNeedsPacking (load->ReferenceType ()->AddressSpace ())) {
472
480
to_unpack.Add (load);
473
481
}
474
482
},
@@ -478,7 +486,7 @@ struct PackedVec3::State {
478
486
// struct.
479
487
if (auto * ref = accessor->Type ()->As <core::type::Reference>()) {
480
488
if (IsVec3 (ref->StoreType ()) &&
481
- core::IsHostShareable (ref->AddressSpace ())) {
489
+ AddressSpaceNeedsPacking (ref->AddressSpace ())) {
482
490
ctx.Replace (node, b.MemberAccessor (ctx.Clone (accessor->Declaration ()),
483
491
kStructMemberName ));
484
492
}
0 commit comments