Skip to content

Commit 2da9e49

Browse files
committed
Instantiate specialized classes nested in templates
Signed-off-by: Dimitar Dobrev <[email protected]>
1 parent eca0db1 commit 2da9e49

File tree

6 files changed

+60
-20
lines changed

6 files changed

+60
-20
lines changed

src/AST/ClassExtensions.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,20 @@ public static Class GetInterface(this Class @class)
233233
return @interface;
234234
}
235235

236+
public static ClassTemplateSpecialization GetParentSpecialization(this Class @class)
237+
{
238+
Class currentClass = @class;
239+
do
240+
{
241+
if (currentClass is ClassTemplateSpecialization specialization)
242+
{
243+
return specialization;
244+
}
245+
currentClass = currentClass.Namespace as Class;
246+
} while (currentClass != null);
247+
return null;
248+
}
249+
236250
public static bool HasDependentValueFieldInLayout(this Class @class)
237251
{
238252
if (@class.Fields.Any(f => IsValueDependent(f.Type)))

src/CppParser/Parser.cpp

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
#include <clang/Parse/ParseAST.h>
4343
#include <clang/Sema/Sema.h>
4444
#include <clang/Sema/SemaConsumer.h>
45+
#include <clang/Sema/Template.h>
4546
#include <clang/Frontend/Utils.h>
4647
#include <clang/Driver/Driver.h>
4748
#include <clang/Driver/ToolChain.h>
@@ -3053,8 +3054,7 @@ void Parser::CompleteIfSpecializationType(const clang::QualType& QualType)
30533054
RD = const_cast<CXXRecordDecl*>(Type->getPointeeCXXRecordDecl());
30543055
ClassTemplateSpecializationDecl* CTS;
30553056
if (!RD ||
3056-
!(CTS = llvm::dyn_cast<ClassTemplateSpecializationDecl>(RD)) ||
3057-
CTS->isCompleteDefinition())
3057+
!(CTS = llvm::dyn_cast<ClassTemplateSpecializationDecl>(RD)))
30583058
return;
30593059

30603060
auto existingClient = c->getSema().getDiagnostics().getClient();
@@ -3065,8 +3065,7 @@ void Parser::CompleteIfSpecializationType(const clang::QualType& QualType)
30653065
Scope Scope(nullptr, Scope::ScopeFlags::ClassScope, c->getSema().getDiagnostics());
30663066
c->getSema().TUScope = &Scope;
30673067

3068-
c->getSema().InstantiateClassTemplateSpecialization(CTS->getBeginLoc(),
3069-
CTS, TSK_ImplicitInstantiation, false);
3068+
InstantiateSpecialization(CTS);
30703069

30713070
c->getSema().getDiagnostics().setClient(existingClient, false);
30723071
c->getSema().TUScope = nullptr;
@@ -3082,6 +3081,32 @@ void Parser::CompleteIfSpecializationType(const clang::QualType& QualType)
30823081
}
30833082
}
30843083

3084+
void Parser::InstantiateSpecialization(clang::ClassTemplateSpecializationDecl* CTS)
3085+
{
3086+
using namespace clang;
3087+
3088+
if (!CTS->isCompleteDefinition())
3089+
{
3090+
c->getSema().InstantiateClassTemplateSpecialization(CTS->getBeginLoc(),
3091+
CTS, TSK_ImplicitInstantiation, false);
3092+
}
3093+
3094+
for (auto Decl : CTS->decls())
3095+
{
3096+
if (Decl->getKind() == Decl::Kind::CXXRecord)
3097+
{
3098+
CXXRecordDecl* Nested = cast<CXXRecordDecl>(Decl);
3099+
CXXRecordDecl* Template = Nested->getInstantiatedFromMemberClass();
3100+
if (Template && !Nested->isCompleteDefinition() && !Nested->hasDefinition())
3101+
{
3102+
c->getSema().InstantiateClass(Nested->getBeginLoc(), Nested, Template,
3103+
MultiLevelTemplateArgumentList(CTS->getTemplateArgs()),
3104+
TSK_ImplicitInstantiation, false);
3105+
}
3106+
}
3107+
}
3108+
}
3109+
30853110
Parameter* Parser::WalkParameter(const clang::ParmVarDecl* PVD,
30863111
const clang::SourceLocation& ParamStartLoc)
30873112
{

src/CppParser/Parser.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ class Parser
137137
std::string GetTypeName(const clang::Type* Type);
138138
bool CanCheckCodeGenInfo(clang::Sema & S, const clang::Type * Ty);
139139
void CompleteIfSpecializationType(const clang::QualType& QualType);
140+
void InstantiateSpecialization(clang::ClassTemplateSpecializationDecl* CTS);
140141
Parameter* WalkParameter(const clang::ParmVarDecl* PVD,
141142
const clang::SourceLocation& ParamStartLoc);
142143
void SetBody(const clang::FunctionDecl* FD, Function* F);

src/Generator/Generators/CSharp/CSharpSources.cs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -361,13 +361,11 @@ private void GenerateNestedInternals(string name, IEnumerable<Class> nestedClass
361361
private IEnumerable<Class> GetGeneratedClasses(
362362
Class dependentClass, IEnumerable<Class> specializedClasses)
363363
{
364-
var specialization = specializedClasses.FirstOrDefault(s => s.IsGenerated) ??
365-
specializedClasses.First();
366-
367364
if (dependentClass.HasDependentValueFieldInLayout())
368-
return specializedClasses;
365+
return specializedClasses.KeepSingleAllPointersSpecialization();
369366

370-
return new[] { specialization };
367+
return new[] { specializedClasses.FirstOrDefault(s => s.IsGenerated) ??
368+
specializedClasses.First()};
371369
}
372370

373371
public override void GenerateDeclarationCommon(Declaration decl)

src/Generator/Generators/CSharp/CSharpSourcesExtensions.cs

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,9 @@ public static void GenerateNativeConstructorsByValue(
2424
var printedClass = @class.Visit(gen.TypePrinter);
2525
if (@class.IsDependent)
2626
{
27-
IEnumerable<Class> specializations =
28-
@class.GetSpecializedClassesToGenerate().Where(s => s.IsGenerated);
29-
if (@class.IsTemplate)
30-
specializations = specializations.KeepSingleAllPointersSpecialization();
31-
foreach (var specialization in specializations)
27+
foreach (var specialization in (from s in @class.GetSpecializedClassesToGenerate()
28+
where s.IsGenerated
29+
select s).KeepSingleAllPointersSpecialization())
3230
gen.GenerateNativeConstructorByValue(specialization, printedClass);
3331
}
3432
else
@@ -40,10 +38,10 @@ public static void GenerateNativeConstructorsByValue(
4038
public static IEnumerable<Class> KeepSingleAllPointersSpecialization(
4139
this IEnumerable<Class> specializations)
4240
{
43-
Func<TemplateArgument, bool> allPointers = (TemplateArgument a) =>
44-
a.Type.Type?.Desugar().IsAddress() == true;
45-
var groups = (from ClassTemplateSpecialization spec in specializations
46-
group spec by spec.Arguments.All(allPointers)
41+
static bool allPointers(TemplateArgument a) => a.Type.Type?.Desugar().IsAddress() == true;
42+
var groups = (from @class in specializations
43+
let spec = @class.GetParentSpecialization()
44+
group @class by spec.Arguments.All(allPointers)
4745
into @group
4846
select @group).ToList();
4947
foreach (var group in groups)

tests/CSharp/CSharpTemplates.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,12 @@ void forceUseSpecializations(IndependentFields<int> _1, IndependentFields<bool>
109109
VirtualTemplate<int> _6, VirtualTemplate<bool> _7,
110110
HasDefaultTemplateArgument<int, int> _8, DerivedChangesTypeName<T1> _9,
111111
TemplateWithIndexer<int> _10, TemplateWithIndexer<T1> _11,
112-
TemplateWithIndexer<T2*> _12, TemplateDerivedFromRegularDynamic<RegularDynamic> _13,
113-
IndependentFields<OnlySpecialisedInTypeArg<double> > _14, std::string s)
112+
TemplateWithIndexer<void*> _12, TemplateWithIndexer<UsedInTemplatedIndexer> _13,
113+
TemplateDerivedFromRegularDynamic<RegularDynamic> _14,
114+
IndependentFields<OnlySpecialisedInTypeArg<double>> _15,
115+
DependentPointerFields<float> _16, IndependentFields<const T1&> _17,
116+
TemplateWithIndexer<T2*> _18, IndependentFields<int(*)(int)> _19,
117+
TemplateWithIndexer<const char*> _20, std::string s)
114118
{
115119
}
116120

0 commit comments

Comments
 (0)