Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable multiple modules for SQLModel reflection. #587

Merged
merged 1 commit into from
Mar 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 19 additions & 78 deletions gel/orm/sqlmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,7 @@ def init_module(self, mod, modules):
try:
self.out = f
self.write(f'{COMMENT}\n')
relimport = '.' * len(dirpath)
self.write(f'from {relimport}._tables import *')
self.write(f'from {self.basemodule}._tables import *')
yield f
finally:
self.out = None
Expand Down Expand Up @@ -142,7 +141,7 @@ def spec_to_modules_dict(self, spec):

if len(spec['prop_objects']) > 0:
warnings.warn(
f"Skipping multi properties: SQLAlchemy reflection doesn't "
f"Skipping multi properties: SQLModel reflection doesn't "
f"support multi properties as they produce models without a "
f"clear identity.",
GelORMWarning,
Expand Down Expand Up @@ -186,52 +185,34 @@ def render_models(self, spec):
self.write()
self.render_link_table(rec)

if 'default' not in modules or len(modules) > 1:
skipped = ', '.join([repr(m) for m in modules if m != 'default'])
warnings.warn(
f"Skipping modules {skipped}: SQLModel reflection doesn't "
f"support multiple modules or non-default modules.",
GelORMWarning,
)

with self.init_module('default', modules):
maps = modules['default']
for mod, maps in modules.items():
if not maps:
# skip apparently empty modules
return
continue

link_objects = sorted(
maps.get('link_objects', {}).values(),
key=lambda x: x['name']
)
for lobj in link_objects:
self.write()
self.render_link_object(lobj, modules)
with self.init_module(mod, modules):
link_objects = sorted(
maps.get('link_objects', {}).values(),
key=lambda x: x['name']
)
for lobj in link_objects:
self.write()
self.render_link_object(lobj, modules)

objects = sorted(
maps.get('object_types', {}).values(),
key=lambda x: x['name']
)
for rec in maps.get('object_types', {}).values():
self.write()
self.render_type(rec, modules)
objects = sorted(
maps.get('object_types', {}).values(),
key=lambda x: x['name']
)
for rec in maps.get('object_types', {}).values():
self.write()
self.render_type(rec, modules)

def render_link_table(self, spec):
mod, source = get_mod_and_name(spec["source"])
tmod, target = get_mod_and_name(spec["target"])
s_fk = self.get_fk(mod, source, 'default')
t_fk = self.get_fk(tmod, target, 'default')

if mod != 'default' or tmod != 'default':
skipped = ', '.join(
[repr(m) for m in {mod, tmod} if m != 'default'])
warnings.warn(
f"Skipping modules {skipped}: SQLModel reflection doesn't "
f"support multiple modules or non-default modules.",
GelORMWarning,
)
return

self.write()
self.write(f'class {spec["name"]}(sm.SQLModel, table=True):')
self.indent()
Expand Down Expand Up @@ -263,14 +244,6 @@ def render_link_object(self, spec, modules):
sql_name = spec['table']
source_name, source_link = sql_name.split('.')

if mod != 'default':
warnings.warn(
f"Skipping module {mod!r}: SQLModel reflection doesn't "
f"support multiple modules or non-default modules.",
GelORMWarning,
)
return

self.write()
self.write(f'class {name}(sm.SQLModel, table=True):')
self.indent()
Expand All @@ -292,14 +265,6 @@ def render_link_object(self, spec, modules):
lname = link['name']
tmod, target = get_mod_and_name(link['target']['name'])

if tmod != 'default':
warnings.warn(
f"Skipping module {tmod!r}: SQLModel reflection doesn't "
f"support multiple modules or non-default modules.",
GelORMWarning,
)
return

fk = self.get_fk(tmod, target, mod)
sqlafk = self.get_sqla_fk(tmod, target, mod)
pyname = self.get_py_name(tmod, target, mod)
Expand Down Expand Up @@ -341,14 +306,6 @@ def render_type(self, spec, modules):
mod, name = get_mod_and_name(spec['name'])
sql_name = get_sql_name(spec['name'])

if mod != 'default':
warnings.warn(
f"Skipping module {mod!r}: SQLModel reflection doesn't "
f"support multiple modules or non-default modules.",
GelORMWarning,
)
return

self.write()
self.write(f'class {name}(sm.SQLModel, table=True):')
self.indent()
Expand Down Expand Up @@ -451,14 +408,6 @@ def render_link(self, spec, mod, parent, modules):
cardinality = spec['cardinality']
bklink = f'_{name}_{parent}'

if tmod != 'default':
warnings.warn(
f"Skipping module {tmod!r}: SQLModel reflection doesn't "
f"support multiple modules or non-default modules.",
GelORMWarning,
)
return

if spec.get('has_link_object'):
# intermediate object will have the actual source and target
# links, so the link here needs to be treated similar to a
Expand Down Expand Up @@ -526,14 +475,6 @@ def render_backlink(self, spec, mod, modules):
exclusive = spec['exclusive']
bklink = spec['fwname']

if tmod != 'default':
warnings.warn(
f"Skipping module {tmod!r}: SQLModel reflection doesn't "
f"support multiple modules or non-default modules.",
GelORMWarning,
)
return

if spec.get('has_link_object'):
# intermediate object will have the actual source and target
# links, so the link here needs to refer to the intermediate
Expand Down
71 changes: 0 additions & 71 deletions tests/dbsetup/sqlmodel.edgeql

This file was deleted.

52 changes: 0 additions & 52 deletions tests/dbsetup/sqlmodel.esdl

This file was deleted.

91 changes: 2 additions & 89 deletions tests/test_sqlmodel_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@

class TestSQLModelBasic(tb.SQLModelTestCase):
SCHEMA = os.path.join(os.path.dirname(__file__), 'dbsetup',
'sqlmodel.esdl')
'base.esdl')

SETUP = os.path.join(os.path.dirname(__file__), 'dbsetup',
'sqlmodel.edgeql')
'base.edgeql')

MODEL_PACKAGE = 'sqlmbase'

Expand Down Expand Up @@ -667,93 +667,6 @@ def test_sqlmodel_update_models_05(self):
dt.datetime.fromisoformat('2025-02-01T20:13:45'),
)

def test_sqlmodel_linkprops_01(self):
val = self.sess.exec(select(self.sm.HasLinkPropsA)).one()
self.assertEqual(val.child.target.num, 0)
self.assertEqual(val.child.a, 'single')

def test_sqlmodel_linkprops_02(self):
val = self.sess.exec(select(self.sm.HasLinkPropsA)).one()
self.assertEqual(val.child.target.num, 0)
self.assertEqual(val.child.a, 'single')

# replace the single child with a different one
ch = self.sess.exec(select(self.sm.Child).filter_by(num=1)).one()
val.child = self.sm.HasLinkPropsA_child_link(a='replaced', target=ch)
self.sess.flush()

val = self.sess.exec(select(self.sm.HasLinkPropsA)).one()
self.assertEqual(val.child.target.num, 1)
self.assertEqual(val.child.a, 'replaced')

# make sure there's only one link object still
vals = self.sess.exec(select(self.sm.HasLinkPropsA_child_link))
self.assertEqual(
[(val.a, val.target.num) for val in vals],
[('replaced', 1)]
)

def test_sqlmodel_linkprops_03(self):
val = self.sess.exec(select(self.sm.HasLinkPropsA)).one()
self.assertEqual(val.child.target.num, 0)
self.assertEqual(val.child.a, 'single')

# delete the child object
val = self.sess.exec(select(self.sm.Child).filter_by(num=0)).one()
self.sess.delete(val)
self.sess.flush()

val = self.sess.exec(select(self.sm.HasLinkPropsA)).one()
self.assertEqual(val.child, None)

# make sure the link object is removed
vals = self.sess.exec(select(self.sm.HasLinkPropsA_child_link))
self.assertEqual(list(vals), [])

def test_sqlmodel_linkprops_04(self):
val = self.sess.exec(select(self.sm.HasLinkPropsB)).one()
self.assertEqual(
{(c.b, c.target.num) for c in val.children},
{('hello', 0), ('world', 1)},
)

def test_sqlmodel_linkprops_05(self):
val = self.sess.exec(select(self.sm.HasLinkPropsB)).one()
self.assertEqual(
{(c.b, c.target.num) for c in val.children},
{('hello', 0), ('world', 1)},
)

# Remove one of the children
for t in list(val.children):
if t.b != 'hello':
val.children.remove(t)
self.sess.flush()

val = self.sess.exec(select(self.sm.HasLinkPropsB)).one()
self.assertEqual(
{(c.b, c.target.num) for c in val.children},
{('hello', 0)},
)

def test_sqlmodel_linkprops_06(self):
val = self.sess.exec(select(self.sm.HasLinkPropsB)).one()
self.assertEqual(
{(c.b, c.target.num) for c in val.children},
{('hello', 0), ('world', 1)},
)

# Remove one of the children
val = self.sess.exec(select(self.sm.Child).filter_by(num=0)).one()
self.sess.delete(val)
self.sess.flush()

val = self.sess.exec(select(self.sm.HasLinkPropsB)).one()
self.assertEqual(
{(c.b, c.target.num) for c in val.children},
{('world', 1)},
)

def test_sqlmodel_sorting(self):
# Test the natural sorting function used for ordering fields, etc.

Expand Down
Loading
Loading