Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions piccolo/apps/migrations/auto/migration_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ async def _run_alter_columns(self, backwards=False):

index = params.get("index")
index_method = params.get("index_method")
sharded = params.get("sharded")
if index is None:
if index_method is not None:
# If the index value hasn't changed, but the
Expand All @@ -513,6 +514,7 @@ async def _run_alter_columns(self, backwards=False):
_Table.create_index(
[column],
method=index_method,
sharded=sharded,
if_not_exists=True,
)
)
Expand All @@ -525,9 +527,11 @@ async def _run_alter_columns(self, backwards=False):
column._meta.db_column_name = alter_column.db_column_name

if index is True:
kwargs = (
{"method": index_method} if index_method else {}
)
kwargs = {}
if index_method:
kwargs["method"] = index_method
if sharded:
kwargs["sharded"] = sharded
await self._run_query(
_Table.create_index(
[column], if_not_exists=True, **kwargs
Expand Down
13 changes: 11 additions & 2 deletions piccolo/apps/schema/commands/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,22 +195,30 @@ def __post_init__(self):
"""
pat = re.compile(
r"""^CREATE[ ](?:(?P<unique>UNIQUE)[ ])?INDEX[ ]\w+?[ ]
ON[ ].+?[ ]USING[ ](?P<method>\w+?)[ ]
\(\"?(?P<column_name>\w+?\"?)\)""",
ON[ ].+?[ ]USING[ ](?P<method>\w+?)[ ]
\(\"?(?P<column_name>\w+?\"?)(?P<sorting>[ ]\w+?)?
\)(?P<sharded>[ ]USING[ ]HASH)?""",
re.VERBOSE,
)

match = re.match(pat, self.indexdef)
if match is None:
self.column_name = None
self.unique = None
self.method = None
self.sorting = None
self.sharded = None
self.warnings = [f"{self.indexdef};"]
else:
groups = match.groupdict()

self.column_name = groups["column_name"].lstrip('"').rstrip('"')
self.unique = "unique" in groups
self.method = INDEX_METHOD_MAP[groups["method"]]
self.sorting = groups[
"sorting"
] # ASC or DESC. Not currently used but it does sometimes exist so we should capture it.
self.sharded = "sharded" in groups
self.warnings = []


Expand Down Expand Up @@ -720,6 +728,7 @@ async def create_table_class_from_db(
if index is not None:
kwargs["index"] = True
kwargs["index_method"] = index.method
kwargs["sharded"] = index.sharded

if constraints.is_primary_key(column_name=column_name):
kwargs["primary_key"] = True
Expand Down
17 changes: 17 additions & 0 deletions piccolo/columns/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ class ColumnMeta:
help_text: t.Optional[str] = None
choices: t.Optional[t.Type[Enum]] = None
secret: bool = False
sharded: bool = False

# Used for representing the table in migrations and the playground.
params: t.Dict[str, t.Any] = field(default_factory=dict)
Expand Down Expand Up @@ -437,6 +438,12 @@ class Band(Table):
>>> await Band.select(exclude_secrets=True)
[{'name': 'Pythonistas'}]

:param sharded:
If ``True`` and primary_key or index is also set ``True``, this index
will automatically use sharding across a cluster. Highly recommended
for sequence columns, such as: Serial, Timestamp.
Also known as Hash Sharded Index.
Comment on lines +441 to +445
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add something here saying 'CockroachDB only'.

If we're likely to have several database specific options, maybe we should namespace them. For example, cockroach_kwargs, or something.


"""

value_type: t.Type = int
Expand All @@ -453,6 +460,7 @@ def __init__(
choices: t.Optional[t.Type[Enum]] = None,
db_column_name: t.Optional[str] = None,
secret: bool = False,
sharded: bool = False,
**kwargs,
) -> None:
# This is for backwards compatibility - originally there were two
Expand All @@ -476,6 +484,7 @@ def __init__(
"choices": choices,
"db_column_name": db_column_name,
"secret": secret,
"sharded": sharded,
}
)

Expand All @@ -494,6 +503,7 @@ def __init__(
choices=choices,
_db_column_name=db_column_name,
secret=secret,
sharded=sharded,
)

self._alias: t.Optional[str] = None
Expand Down Expand Up @@ -823,6 +833,13 @@ def ddl(self) -> str:
query += " PRIMARY KEY"
if self._meta.unique:
query += " UNIQUE"

# Sharded Indexes for sequence columns defined as PRIMARY KEY at table creation time.
# Currently Cockroach only. Must be before NOT NULL!
if self._meta.engine_type in ("cockroach"):
if self._meta.sharded and (self._meta.primary_key):
query += f" USING HASH"

if not self._meta.null:
query += " NOT NULL"

Expand Down
1 change: 1 addition & 0 deletions piccolo/query/methods/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def default_ddl(self) -> t.Sequence[str]:
columns=[column],
method=column._meta.index_method,
if_not_exists=self.if_not_exists,
sharded=column._meta.sharded,
).ddl
)

Expand Down
6 changes: 6 additions & 0 deletions piccolo/query/methods/create_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ def __init__(
columns: t.List[t.Union[Column, str]],
method: IndexMethod = IndexMethod.btree,
if_not_exists: bool = False,
sharded: bool = False,
**kwargs,
):
self.columns = columns
self.method = method
self.if_not_exists = if_not_exists
self.sharded = sharded
super().__init__(table, **kwargs)

@property
Expand Down Expand Up @@ -59,10 +61,14 @@ def cockroach_ddl(self) -> t.Sequence[str]:
tablename = self.table._meta.tablename
method_name = self.method.value
column_names_str = ", ".join([f'"{i}"' for i in self.column_names])
sharded = ""
if self.sharded:
sharded = " USING HASH "
return [
(
f"{self.prefix} {index_name} ON {tablename} USING "
f"{method_name} ({column_names_str})"
f"{sharded}"
)
]

Expand Down
2 changes: 2 additions & 0 deletions piccolo/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1129,6 +1129,7 @@ def create_index(
columns: t.List[t.Union[Column, str]],
method: IndexMethod = IndexMethod.btree,
if_not_exists: bool = False,
sharded: bool = False,
) -> CreateIndex:
"""
Create a table index. If multiple columns are specified, this refers
Expand All @@ -1144,6 +1145,7 @@ def create_index(
columns=columns,
method=method,
if_not_exists=if_not_exists,
sharded=sharded,
)

@classmethod
Expand Down
8 changes: 4 additions & 4 deletions tests/apps/migrations/auto/test_schema_differ.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_add_table(self):
self.assertTrue(len(new_table_columns.statements) == 1)
self.assertEqual(
new_table_columns.statements[0],
"manager.add_column(table_class_name='Band', tablename='band', column_name='name', db_column_name='name', column_class_name='Varchar', column_class=Varchar, params={'length': 255, 'default': '', 'null': False, 'primary_key': False, 'unique': False, 'index': False, 'index_method': IndexMethod.btree, 'choices': None, 'db_column_name': None, 'secret': False})", # noqa
"manager.add_column(table_class_name='Band', tablename='band', column_name='name', db_column_name='name', column_class_name='Varchar', column_class=Varchar, params={'length': 255, 'default': '', 'null': False, 'primary_key': False, 'unique': False, 'index': False, 'index_method': IndexMethod.btree, 'choices': None, 'db_column_name': None, 'secret': False, 'sharded': False})", # noqa
)

def test_drop_table(self):
Expand Down Expand Up @@ -123,7 +123,7 @@ def test_add_column(self):
self.assertTrue(len(schema_differ.add_columns.statements) == 1)
self.assertEqual(
schema_differ.add_columns.statements[0],
"manager.add_column(table_class_name='Band', tablename='band', column_name='genre', db_column_name='genre', column_class_name='Varchar', column_class=Varchar, params={'length': 255, 'default': '', 'null': False, 'primary_key': False, 'unique': False, 'index': False, 'index_method': IndexMethod.btree, 'choices': None, 'db_column_name': None, 'secret': False})", # noqa
"manager.add_column(table_class_name='Band', tablename='band', column_name='genre', db_column_name='genre', column_class_name='Varchar', column_class=Varchar, params={'length': 255, 'default': '', 'null': False, 'primary_key': False, 'unique': False, 'index': False, 'index_method': IndexMethod.btree, 'choices': None, 'db_column_name': None, 'secret': False, 'sharded': False})", # noqa
)

def test_drop_column(self):
Expand Down Expand Up @@ -207,7 +207,7 @@ def test_rename_column(self):
self.assertEqual(
schema_differ.add_columns.statements,
[
"manager.add_column(table_class_name='Band', tablename='band', column_name='name', db_column_name='name', column_class_name='Varchar', column_class=Varchar, params={'length': 255, 'default': '', 'null': False, 'primary_key': False, 'unique': False, 'index': False, 'index_method': IndexMethod.btree, 'choices': None, 'db_column_name': None, 'secret': False})" # noqa: E501
"manager.add_column(table_class_name='Band', tablename='band', column_name='name', db_column_name='name', column_class_name='Varchar', column_class=Varchar, params={'length': 255, 'default': '', 'null': False, 'primary_key': False, 'unique': False, 'index': False, 'index_method': IndexMethod.btree, 'choices': None, 'db_column_name': None, 'secret': False, 'sharded': False})" # noqa: E501
],
)
self.assertEqual(
Expand Down Expand Up @@ -349,7 +349,7 @@ def mock_input(value: str):
self.assertEqual(
schema_differ.add_columns.statements,
[
"manager.add_column(table_class_name='Band', tablename='band', column_name='b2', db_column_name='b2', column_class_name='Varchar', column_class=Varchar, params={'length': 255, 'default': '', 'null': False, 'primary_key': False, 'unique': False, 'index': False, 'index_method': IndexMethod.btree, 'choices': None, 'db_column_name': None, 'secret': False})" # noqa: E501
"manager.add_column(table_class_name='Band', tablename='band', column_name='b2', db_column_name='b2', column_class_name='Varchar', column_class=Varchar, params={'length': 255, 'default': '', 'null': False, 'primary_key': False, 'unique': False, 'index': False, 'index_method': IndexMethod.btree, 'choices': None, 'db_column_name': None, 'secret': False, 'sharded': False})" # noqa: E501
],
)
self.assertEqual(
Expand Down
6 changes: 3 additions & 3 deletions tests/apps/migrations/auto/test_serialisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def test_lazy_table_reference(self):
'class Manager(Table, tablename="manager"): '
"id = Serial(null=False, primary_key=True, unique=False, " # noqa: E501
"index=False, index_method=IndexMethod.btree, "
"choices=None, db_column_name='id', secret=False)"
"choices=None, db_column_name='id', secret=False, sharded=False)"
),
)

Expand All @@ -261,7 +261,7 @@ def test_lazy_table_reference(self):
'class Manager(Table, tablename="manager"): '
"id = Serial(null=False, primary_key=True, unique=False, " # noqa: E501
"index=False, index_method=IndexMethod.btree, "
"choices=None, db_column_name='id', secret=False)"
"choices=None, db_column_name='id', secret=False, sharded=False)"
),
)

Expand Down Expand Up @@ -312,7 +312,7 @@ def test_column_instance(self):

self.assertEqual(
serialised.params["base_column"].__repr__(),
"Varchar(length=255, default='', null=False, primary_key=False, unique=False, index=False, index_method=IndexMethod.btree, choices=None, db_column_name=None, secret=False)", # noqa: E501
"Varchar(length=255, default='', null=False, primary_key=False, unique=False, index=False, index_method=IndexMethod.btree, choices=None, db_column_name=None, secret=False, sharded=False)", # noqa: E501
)

self.assertEqual(
Expand Down
37 changes: 37 additions & 0 deletions tests/apps/schema/commands/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,43 @@ def test_index(self):
)


class ConcertSharded(Table):
id = Serial(primary_key=True, sharded=True)
name = Varchar(index=True, sharded=True)
time = Timestamp(index=True, sharded=True)
capacity = Integer(sharded=True)


# Sharded indexes only supported on Cockroach for now.
@engines_only("cockroach")
class TestGenerateWithShardedIndexes(TestCase):
def setUp(self):
ConcertSharded.create_table().run_sync()

def tearDown(self):
ConcertSharded.alter().drop_table(if_exists=True).run_sync()

def test_index(self):
"""
Make sure that a table with an index is reflected correctly.
"""
output_schema: OutputSchema = run_sync(get_output_schema())
Concert_ = output_schema.tables[0]

self.assertEqual(Concert_.id._meta.primary_key, True)
self.assertEqual(Concert_.id._meta.sharded, True)

self.assertEqual(Concert_.name._meta.index, True)
self.assertEqual(Concert_.name._meta.sharded, True)

self.assertEqual(Concert_.time._meta.index, True)
self.assertEqual(Concert_.time._meta.sharded, True)

# Should not shard a non-index.
self.assertEqual(Concert_.capacity._meta.index, False)
self.assertEqual(Concert_.capacity._meta.sharded, False)


###############################################################################


Expand Down
8 changes: 4 additions & 4 deletions tests/table/test_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@ def test_str(self):
Manager._table_str(),
(
"class Manager(Table, tablename='manager'):\n"
" id = Serial(null=False, primary_key=True, unique=False, index=False, index_method=IndexMethod.btree, choices=None, db_column_name='id', secret=False)\n" # noqa: E501
" name = Varchar(length=50, default='', null=False, primary_key=False, unique=False, index=False, index_method=IndexMethod.btree, choices=None, db_column_name=None, secret=False)\n" # noqa: E501
" id = Serial(null=False, primary_key=True, unique=False, index=False, index_method=IndexMethod.btree, choices=None, db_column_name='id', secret=False, sharded=False)\n" # noqa: E501
" name = Varchar(length=50, default='', null=False, primary_key=False, unique=False, index=False, index_method=IndexMethod.btree, choices=None, db_column_name=None, secret=False, sharded=False)\n" # noqa: E501
),
)
else:
self.assertEqual(
Manager._table_str(),
(
"class Manager(Table, tablename='manager'):\n"
" id = Serial(null=False, primary_key=True, unique=False, index=False, index_method=IndexMethod.btree, choices=None, db_column_name='id', secret=False)\n" # noqa: E501
" name = Varchar(length=50, default='', null=False, primary_key=False, unique=False, index=False, index_method=IndexMethod.btree, choices=None, db_column_name=None, secret=False)\n" # noqa: E501
" id = Serial(null=False, primary_key=True, unique=False, index=False, index_method=IndexMethod.btree, choices=None, db_column_name='id', secret=False, sharded=False)\n" # noqa: E501
" name = Varchar(length=50, default='', null=False, primary_key=False, unique=False, index=False, index_method=IndexMethod.btree, choices=None, db_column_name=None, secret=False, sharded=False)\n" # noqa: E501
),
)

Expand Down