diff --git a/spanner_orm/admin/update.py b/spanner_orm/admin/update.py index 4277e10..14fc3fc 100644 --- a/spanner_orm/admin/update.py +++ b/spanner_orm/admin/update.py @@ -24,6 +24,7 @@ from spanner_orm.admin import api from spanner_orm.admin import index_column from spanner_orm.admin import metadata +from spanner_orm.index import Index class SchemaUpdate(abc.ABC): @@ -230,19 +231,39 @@ class CreateIndex(SchemaUpdate): def __init__(self, table_name: str, - index_name: str, - columns: Iterable[str], + index_name: Optional[str] = None, + columns: Optional[Iterable[str]] = None, + model_index: Optional[Index] = None, interleaved: Optional[str] = None, - storing_columns: Optional[Iterable[str]] = None): + storing_columns: Optional[Iterable[str]] = None, + unique: Optional[bool] = None, + null_filtered: Optional[bool] = None): + if not ((model_index is not None) ^ (index_name is not None and columns is not None)): + raise error.SpannerError('Exactly one of: [model_index], [index_name, columns] is required') + if model_index and (index_name or columns or interleaved or storing_columns or unique or null_filtered): + raise error.SpannerError('Can not specify any other optional param if model_index is specified') + + if model_index: + index_name = model_index.name + columns = model_index.columns + interleaved = model_index.parent + storing_columns = model_index.storing_columns + unique = model_index.unique + null_filtered = model_index.null_filtered + self._table = table_name self._index = index_name self._columns = columns self._parent_table = interleaved self._storing_columns = storing_columns or [] + self._unique = unique + self._null_filtered = null_filtered def ddl(self) -> str: - statement = 'CREATE INDEX {} ON {} ({})'.format(self._index, self._table, - ', '.join(self._columns)) + statement = 'CREATE {}{}INDEX {} ON {} ({})'.format( + 'UNIQUE ' if self._unique else '', + 'NULL_FILTERED ' if self._null_filtered else '', + self._index, self._table, ', '.join(self._columns)) if self._storing_columns: statement += 'STORING ({})'.format(', '.join(self._storing_columns)) if self._parent_table: @@ -338,10 +359,8 @@ def model_creation_ddl(model_: Type[model.Model]) -> List[str]: continue create_index = CreateIndex( model_.table, - model_index.name, - model_index.columns, - interleaved=model_index.parent, - storing_columns=model_index.storing_columns) + model_index=model_index + ) ddl_list.append(create_index.ddl()) return ddl_list diff --git a/spanner_orm/tests/update_test.py b/spanner_orm/tests/update_test.py index 7b55965..f51fbc8 100644 --- a/spanner_orm/tests/update_test.py +++ b/spanner_orm/tests/update_test.py @@ -19,6 +19,7 @@ from spanner_orm import error from spanner_orm import field from spanner_orm.admin import update +from spanner_orm.index import Index from spanner_orm.tests import models @@ -116,6 +117,48 @@ def test_add_index(self, get_model): self.assertEqual(test_update.ddl(), 'CREATE INDEX foo ON {} (value_1)'.format(table_name)) + @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') + def test_add_index_model_index(self, get_model): + table_name = models.SmallTestModel.table + get_model.return_value = models.SmallTestModel + idx = Index(['value_1']) + idx.name = 'foo' + + test_update = update.CreateIndex(table_name, model_index=idx) + test_update.validate() + self.assertEqual(test_update.ddl(), + 'CREATE INDEX foo ON {} (value_1)'.format(table_name)) + + @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') + def test_add_unique_index(self, get_model): + table_name = models.SmallTestModel.table + get_model.return_value = models.SmallTestModel + + test_update = update.CreateIndex(table_name, 'foo', ['value_1'], unique=True) + test_update.validate() + self.assertEqual(test_update.ddl(), + 'CREATE UNIQUE INDEX foo ON {} (value_1)'.format(table_name)) + + @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') + def test_add_null_filtered_index(self, get_model): + table_name = models.SmallTestModel.table + get_model.return_value = models.SmallTestModel + + test_update = update.CreateIndex(table_name, 'foo', ['value_1'], null_filtered=True) + test_update.validate() + self.assertEqual(test_update.ddl(), + 'CREATE NULL_FILTERED INDEX foo ON {} (value_1)'.format(table_name)) + + @mock.patch('spanner_orm.admin.metadata.SpannerMetadata.model') + def test_add_null_filtered_unique_index(self, get_model): + table_name = models.SmallTestModel.table + get_model.return_value = models.SmallTestModel + + test_update = update.CreateIndex(table_name, 'foo', ['value_1'], unique=True, null_filtered=True) + test_update.validate() + self.assertEqual(test_update.ddl(), + 'CREATE UNIQUE NULL_FILTERED INDEX foo ON {} (value_1)'.format(table_name)) + if __name__ == '__main__': logging.basicConfig()