@@ -784,9 +784,26 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
784
784
ASR::call_arg_t c_arg;
785
785
c_arg.loc = args[i].loc ;
786
786
c_arg.m_value = args[i].m_value ;
787
- cast_helper (m_args[i], c_arg.m_value , true );
788
787
ASR::ttype_t * left_type = ASRUtils::expr_type (m_args[i]);
789
788
ASR::ttype_t * right_type = ASRUtils::expr_type (c_arg.m_value );
789
+ if ( ASR::is_a<ASR::StructType_t>(*left_type) && ASR::is_a<ASR::StructType_t>(*right_type) ) {
790
+ ASR::StructType_t *l_type = ASR::down_cast<ASR::StructType_t>(left_type);
791
+ ASR::StructType_t *r_type = ASR::down_cast<ASR::StructType_t>(right_type);
792
+ ASR::Struct_t *l2_type = ASR::down_cast<ASR::Struct_t>(
793
+ ASRUtils::symbol_get_past_external (
794
+ l_type->m_derived_type ));
795
+ ASR::Struct_t *r2_type = ASR::down_cast<ASR::Struct_t>(
796
+ ASRUtils::symbol_get_past_external (
797
+ r_type->m_derived_type ));
798
+ if ( ASRUtils::is_derived_type_similar (l2_type, r2_type) ) {
799
+ cast_helper (m_args[i], c_arg.m_value , true , true );
800
+ check_type_equality = false ;
801
+ } else {
802
+ cast_helper (m_args[i], c_arg.m_value , true );
803
+ }
804
+ } else {
805
+ cast_helper (m_args[i], c_arg.m_value , true );
806
+ }
790
807
if ( check_type_equality && !ASRUtils::check_equal_type (left_type, right_type) ) {
791
808
std::string ltype = ASRUtils::type_to_str_python (left_type);
792
809
std::string rtype = ASRUtils::type_to_str_python (right_type);
@@ -2962,9 +2979,8 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
2962
2979
std::string obj_name = x.m_args .m_args ->m_arg ;
2963
2980
for (size_t i = 0 ; i < x.n_body ; i++) {
2964
2981
std::string var_name;
2965
- if (! AST::is_a<AST::AnnAssign_t>(*x.m_body [i]) ){
2966
- throw SemanticError (" Only AnnAssign implemented in __init__ " ,
2967
- x.m_body [i]->base .loc );
2982
+ if ( !AST::is_a<AST::AnnAssign_t>(*x.m_body [i]) ){
2983
+ continue ;
2968
2984
}
2969
2985
AST::AnnAssign_t ann_assign = *AST::down_cast<AST::AnnAssign_t>(x.m_body [i]);
2970
2986
if (AST::is_a<AST::Attribute_t>(*ann_assign.m_target )){
@@ -3301,10 +3317,21 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
3301
3317
current_scope->add_symbol (x_m_name, class_type);
3302
3318
}
3303
3319
} else {
3304
- if ( x.n_bases > 0 ) {
3305
- throw SemanticError (" Inheritance in classes isn't supported yet." ,
3320
+ ASR::symbol_t * parent = nullptr ;
3321
+ if ( x.n_bases > 1 ) {
3322
+ throw SemanticError (" Multiple inheritance in classes isn't supported yet." ,
3306
3323
x.base .base .loc );
3307
3324
}
3325
+ else if (x.n_bases == 1 ) {
3326
+ std::string b_name = " " ;
3327
+ if ( AST::is_a<AST::Name_t>(*x.m_bases [0 ]) ) {
3328
+ b_name = AST::down_cast<AST::Name_t>(x.m_bases [0 ])->m_id ;
3329
+ } else {
3330
+ throw SemanticError (" Expected a Name here" , x.base .base .loc );
3331
+ }
3332
+ parent = current_scope->resolve_symbol (b_name);
3333
+ LCOMPILERS_ASSERT (ASR::is_a<ASR::Struct_t>(*parent));
3334
+ }
3308
3335
SymbolTable *parent_scope = current_scope;
3309
3336
if ( ASR::symbol_t * sym = current_scope->resolve_symbol (x_m_name) ) {
3310
3337
LCOMPILERS_ASSERT (ASR::is_a<ASR::Struct_t>(*sym));
@@ -3316,7 +3343,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
3316
3343
f = AST::down_cast<AST::FunctionDef_t>(x.m_body [i]);
3317
3344
init_self_type (*f, sym, x.base .base .loc );
3318
3345
if ( std::string (f->m_name ) == std::string (" __init__" ) ) {
3319
- this ->visit_init_body (*f);
3346
+ this ->visit_init_body (*f, st-> m_parent , x. m_body [i]-> base . loc );
3320
3347
} else {
3321
3348
this ->visit_stmt (*x.m_body [i]);
3322
3349
}
@@ -3344,7 +3371,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
3344
3371
member_names.p , member_names.size (), member_fn_names.p ,
3345
3372
member_fn_names.size (), class_abi, ASR::accessType::Public,
3346
3373
false , false , member_init.p , member_init.size (),
3347
- nullptr , nullptr ));
3374
+ nullptr , parent ));
3348
3375
parent_scope->add_symbol (x.m_name , class_sym);
3349
3376
visit_ClassMembers (x, member_names, member_fn_names,
3350
3377
struct_dependencies, member_init, false , class_abi, true );
@@ -3387,7 +3414,7 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
3387
3414
current_scope = parent_scope;
3388
3415
}
3389
3416
3390
- virtual void visit_init_body (const AST::FunctionDef_t &/* x*/ ) = 0;
3417
+ virtual void visit_init_body (const AST::FunctionDef_t &/* x*/ , ASR:: symbol_t * /* parent_sym */ , const Location /* loc */ ) = 0;
3391
3418
3392
3419
void add_name (const Location &loc) {
3393
3420
std::string var_name = " __name__" ;
@@ -4421,7 +4448,7 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
4421
4448
// Implement visit_Global for Symbol Table visitor.
4422
4449
void visit_Global (const AST::Global_t &/* x*/ ) {}
4423
4450
4424
- void visit_init_body (const AST::FunctionDef_t &/* x*/ ) {
4451
+ void visit_init_body (const AST::FunctionDef_t &/* x*/ , ASR:: symbol_t * /* parent_sym */ , const Location /* loc */ ) {
4425
4452
// Implemented in BodyVisitor
4426
4453
}
4427
4454
@@ -5153,7 +5180,7 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
5153
5180
tmp = asr;
5154
5181
}
5155
5182
5156
- void visit_init_body (const AST::FunctionDef_t &x) {
5183
+ void visit_init_body (const AST::FunctionDef_t &x, ASR:: symbol_t * parent_sym, const Location loc ) {
5157
5184
SymbolTable *old_scope = current_scope;
5158
5185
ASR::symbol_t *t = current_scope->get_symbol (" __init__" );
5159
5186
if ( t==nullptr ) {
@@ -5163,31 +5190,77 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
5163
5190
throw SemanticError (" __init__ is not a function" , x.base .base .loc );
5164
5191
}
5165
5192
ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(t);
5193
+ current_scope = f->m_symtab ;
5166
5194
// Transform statements into correct format
5167
- Vec<AST::stmt_t *> new_body;
5168
- new_body.reserve (al, 1 );
5195
+ Vec<AST::stmt_t *> body;
5196
+ body.reserve (al, 1 );
5197
+ ASR::stmt_t * super_call_stmt = nullptr ;
5169
5198
for (size_t i=0 ; i<x.n_body ; i++) {
5170
- AST::AnnAssign_t ann_assign = *AST::down_cast<AST::AnnAssign_t>(x.m_body [i]);
5171
- if ( ann_assign.m_value != nullptr ) {
5172
- Vec<AST::expr_t *>target;
5173
- target.reserve (al, 1 );
5174
- target.push_back (al, ann_assign.m_target );
5175
- AST::ast_t * assgn_ast = AST::make_Assign_t (al, ann_assign.base .base .loc ,
5176
- target.p , 1 , ann_assign.m_value , nullptr );
5177
- AST::stmt_t * assgn = AST::down_cast<AST::stmt_t >(assgn_ast);
5178
- new_body.push_back (al, assgn);
5199
+ if (AST::is_a<AST::AnnAssign_t>(*x.m_body [i])) {
5200
+ AST::AnnAssign_t ann_assign = *AST::down_cast<AST::AnnAssign_t>(x.m_body [i]);
5201
+ if ( ann_assign.m_value != nullptr ) {
5202
+ Vec<AST::expr_t *>target;
5203
+ target.reserve (al, 1 );
5204
+ target.push_back (al, ann_assign.m_target );
5205
+ AST::ast_t * assgn_ast = AST::make_Assign_t (al, ann_assign.base .base .loc ,
5206
+ target.p , 1 , ann_assign.m_value , nullptr );
5207
+ AST::stmt_t * assgn = AST::down_cast<AST::stmt_t >(assgn_ast);
5208
+ body.push_back (al, assgn);
5209
+ }
5210
+ } else if (AST::is_a<AST::Expr_t>(*x.m_body [i]) &&
5211
+ AST::is_a<AST::Call_t>(*(AST::down_cast<AST::Expr_t>(x.m_body [i])->m_value ))) {
5212
+ AST::Call_t* c = AST::down_cast<AST::Call_t>(AST::down_cast<AST::Expr_t>(x.m_body [i])->m_value );
5213
+
5214
+ if ( !AST::is_a<AST::Attribute_t>(*(c->m_func ))
5215
+ || !AST::is_a<AST::Call_t>(*(AST::down_cast<AST::Attribute_t>(c->m_func )->m_value )) ) {
5216
+ body.push_back (al, x.m_body [i]);
5217
+ continue ;
5218
+ }
5219
+ AST::Call_t* super_call = AST::down_cast<AST::Call_t>(AST::down_cast<AST::Attribute_t>(c->m_func )->m_value );
5220
+ std::string attr = AST::down_cast<AST::Attribute_t>(c->m_func )->m_attr ;
5221
+ if ( AST::is_a<AST::Name_t>(*(super_call->m_func )) &&
5222
+ std::string (AST::down_cast<AST::Name_t>(super_call->m_func )->m_id )==" super" &&
5223
+ attr == " __init__" ) {
5224
+ if (parent_sym == nullptr ) {
5225
+ throw SemanticError (" The class doesn't have a base class" ,loc);
5226
+ }
5227
+ Vec<ASR::call_arg_t > args;
5228
+ args.reserve (al, 1 );
5229
+ parse_args (*super_call,args);
5230
+ ASR::call_arg_t first_arg;
5231
+ first_arg.loc = loc;
5232
+ ASR::symbol_t * self_sym = current_scope->get_symbol (" self" );
5233
+ first_arg.m_value = ASRUtils::EXPR (ASR::make_Var_t (al,loc,self_sym));
5234
+ ASR::ttype_t * target_type = ASRUtils::TYPE (ASRUtils::make_StructType_t_util (al,loc,parent_sym));
5235
+ cast_helper (target_type, first_arg.m_value , x.base .base .loc , true );
5236
+ Vec<ASR::call_arg_t > args_w_first; args_w_first.reserve (al,1 );
5237
+ args_w_first.push_back (al, first_arg);
5238
+ for ( size_t i = 0 ; i < args.size (); i++ ) {
5239
+ args_w_first.push_back (al,args[i]);
5240
+ }
5241
+ std::string call_name = " __init__" ;
5242
+ ASR::symbol_t * call_sym = get_struct_member (parent_sym,call_name,loc);
5243
+ super_call_stmt = ASRUtils::STMT (
5244
+ ASR::make_SubroutineCall_t (al, loc, call_sym, call_sym, args_w_first.p ,
5245
+ args_w_first.size (), nullptr ));
5246
+ }
5247
+ } else {
5248
+ body.push_back (al, x.m_body [i]);
5179
5249
}
5180
5250
}
5181
5251
current_scope = f->m_symtab ;
5182
- Vec<ASR::stmt_t *> body;
5183
- body.reserve (al, x.n_body );
5252
+ Vec<ASR::stmt_t *> body_asr;
5253
+ body_asr.reserve (al, x.n_body );
5254
+ if ( super_call_stmt ) {
5255
+ body_asr.push_back (al, super_call_stmt);
5256
+ }
5184
5257
Vec<ASR::symbol_t *> rts;
5185
5258
rts.reserve (al, 4 );
5186
5259
dependencies.clear (al);
5187
- transform_stmts (body, new_body .n , new_body .p );
5260
+ transform_stmts (body_asr, body .n , body .p );
5188
5261
for (const auto &rt: rt_vec) { rts.push_back (al, rt); }
5189
- f->m_body = body .p ;
5190
- f->n_body = body .size ();
5262
+ f->m_body = body_asr .p ;
5263
+ f->n_body = body_asr .size ();
5191
5264
ASR::FunctionType_t* func_type = ASR::down_cast<ASR::FunctionType_t>(
5192
5265
f->m_function_signature );
5193
5266
func_type->m_restrictions = rts.p ;
@@ -6239,10 +6312,14 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
6239
6312
for ( size_t i = 0 ; i < der_type->n_members && !member_found; i++ ) {
6240
6313
member_found = std::string (der_type->m_members [i]) == member_name;
6241
6314
}
6242
- if ( !member_found ) {
6315
+ if ( !member_found && !der_type-> m_parent ) {
6243
6316
throw SemanticError (" No member " + member_name +
6244
6317
" found in " + std::string (der_type->m_name ),
6245
6318
loc);
6319
+ } else if ( !member_found && der_type->m_parent ) {
6320
+ ASR::ttype_t * parent_type = ASRUtils::TYPE (ASRUtils::make_StructType_t_util (al, loc,der_type->m_parent ));
6321
+ visit_AttributeUtil (parent_type,attr_char,t,loc);
6322
+ return ;
6246
6323
}
6247
6324
ASR::expr_t *val = ASR::down_cast<ASR::expr_t >(ASR::make_Var_t (al, loc, t));
6248
6325
ASR::symbol_t * member_sym = der_type->m_symtab ->resolve_symbol (member_name);
@@ -8064,7 +8141,8 @@ we will have to use something else.
8064
8141
// TODO: Correct Class and ClassType
8065
8142
// call to struct member function
8066
8143
// modifying args to pass the object as self
8067
- ASR::symbol_t * der = ASR::down_cast<ASR::StructType_t>(var->m_type )->m_derived_type ;
8144
+ ASR::symbol_t * der_sym = ASR::down_cast<ASR::StructType_t>(var->m_type )->m_derived_type ;
8145
+ ASR::Struct_t* der = ASR::down_cast<ASR::Struct_t>(der_sym);
8068
8146
Vec<ASR::call_arg_t > new_args; new_args.reserve (al, args.n + 1 );
8069
8147
ASR::call_arg_t self_arg;
8070
8148
self_arg.loc = args[0 ].loc ;
@@ -8073,7 +8151,20 @@ we will have to use something else.
8073
8151
for (size_t i=0 ; i<args.n ; i++) {
8074
8152
new_args.push_back (al, args[i]);
8075
8153
}
8076
- st = get_struct_member (der, call_name, loc);
8154
+ if ( der->m_symtab ->get_symbol (call_name) ) {
8155
+ st = get_struct_member (der_sym, call_name, loc);
8156
+ } else if ( der->m_parent ) {
8157
+ ASR::Struct_t* parent = ASR::down_cast<ASR::Struct_t>(der->m_parent );
8158
+ if ( !parent->m_symtab ->get_symbol (call_name) ) {
8159
+ throw SemanticError (" Method not found in the class " + std::string (der->m_name ) +
8160
+ " or it's parents" ,loc);
8161
+ } else {
8162
+ st = get_struct_member (der->m_parent , call_name, loc);
8163
+ }
8164
+ } else {
8165
+ throw SemanticError (" Method not found in the class " +std::string (der->m_name )+
8166
+ " or it's parents" ,loc);
8167
+ }
8077
8168
tmp = make_call_helper (al, st, current_scope, new_args, call_name, loc);
8078
8169
return ;
8079
8170
} else {
0 commit comments