diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 6f89855..d1307a2 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -50,13 +50,10 @@ }, // Add the IDs of extensions you want installed when the container is created. "extensions": [ - "bungcip.better-toml", - "janisdd.vscode-edit-csv", "charliermarsh.ruff", "njpwerner.autodocstring", "VisualStudioExptTeam.vscodeintellicode", "KevinRose.vsc-python-indent", - "eamodio.gitlens", "mhutchie.git-graph", "ms-python.python", "ms-python.vscode-pylance" diff --git a/Makefile b/Makefile index 0f757c2..33889a2 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,8 @@ run: - poetry run python main.py --enable_create_model --enable_judge_survival --configs configs + python main.py --enable_create_model --enable_judge_survival --configs configs format: - poetry run ruff check --fix . - poetry run black . + ruff check --fix . + black . test: - poetry run pytest . \ No newline at end of file + pytest . \ No newline at end of file diff --git a/main.py b/main.py index 1fb47f8..c685303 100644 --- a/main.py +++ b/main.py @@ -18,7 +18,10 @@ @click.option("--enable_create_model", is_flag=True, help="モデル作成するか") @click.option("--enable_judge_survival", is_flag=True, help="生存判定するか") @click.option( - "--configs", type=click.Path(exists=True), default="configs", help="configsフォルダのパス" + "--configs", + type=click.Path(exists=True), + default="configs", + help="configsフォルダのパス", ) def main( enable_create_model: bool, diff --git a/poetry.lock b/poetry.lock index 1b8bdfb..cb99403 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,12 +1,12 @@ -# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. [[package]] name = "antlr4-python3-runtime" version = "4.9.3" description = "ANTLR 4.9.3 runtime for Python 3.7" -category = "main" optional = false python-versions = "*" +groups = ["main"] files = [ {file = "antlr4-python3-runtime-4.9.3.tar.gz", hash = "sha256:f224469b4168294902bb1efa80a8bf7855f24c99aef99cbefc1bcd3cce77881b"}, ] @@ -15,9 +15,9 @@ files = [ name = "attrs" version = "22.2.0" description = "Classes Without Boilerplate" -category = "dev" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "attrs-22.2.0-py3-none-any.whl", hash = "sha256:29e95c7f6778868dbd49170f98f8818f78f3dc5e0e37c0b1f474e3561b240836"}, {file = "attrs-22.2.0.tar.gz", hash = "sha256:c9227bfc2f01993c03f68db37d1d15c9690188323c067c641f1a35ca58185f99"}, @@ -28,15 +28,15 @@ cov = ["attrs[tests]", "coverage-enable-subprocess", "coverage[toml] (>=5.3)"] dev = ["attrs[docs,tests]"] docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope.interface"] tests = ["attrs[tests-no-zope]", "zope.interface"] -tests-no-zope = ["cloudpickle", "cloudpickle", "hypothesis", "hypothesis", "mypy (>=0.971,<0.990)", "mypy (>=0.971,<0.990)", "pympler", "pympler", "pytest (>=4.3.0)", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-mypy-plugins", "pytest-xdist[psutil]", "pytest-xdist[psutil]"] +tests-no-zope = ["cloudpickle ; platform_python_implementation == \"CPython\"", "cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "hypothesis", "mypy (>=0.971,<0.990) ; platform_python_implementation == \"CPython\"", "mypy (>=0.971,<0.990) ; platform_python_implementation == \"CPython\"", "pympler", "pympler", "pytest (>=4.3.0)", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version < \"3.11\"", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version < \"3.11\"", "pytest-xdist[psutil]", "pytest-xdist[psutil]"] [[package]] name = "black" version = "23.3.0" description = "The uncompromising code formatter." -category = "dev" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "black-23.3.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:0945e13506be58bf7db93ee5853243eb368ace1c08a24c65ce108986eac65915"}, {file = "black-23.3.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:67de8d0c209eb5b330cce2469503de11bca4085880d62f1628bd9972cc3366b9"}, @@ -83,9 +83,9 @@ uvloop = ["uvloop (>=0.15.2)"] name = "click" version = "8.1.3" description = "Composable command line interface toolkit" -category = "main" optional = false python-versions = ">=3.7" +groups = ["main", "dev"] files = [ {file = "click-8.1.3-py3-none-any.whl", hash = "sha256:bb4d8133cb15a609f44e8213d9b391b0809795062913b383c62be0ee95b1db48"}, {file = "click-8.1.3.tar.gz", hash = "sha256:7682dc8afb30297001674575ea00d1814d808d6a36af415a82bd481d37ba7b8e"}, @@ -98,21 +98,23 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} name = "colorama" version = "0.4.6" description = "Cross-platform colored terminal text." -category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +groups = ["main", "dev"] files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +markers = {main = "platform_system == \"Windows\"", dev = "platform_system == \"Windows\" or sys_platform == \"win32\""} [[package]] name = "exceptiongroup" version = "1.1.1" description = "Backport of PEP 654 (exception groups)" -category = "dev" optional = false python-versions = ">=3.7" +groups = ["dev"] +markers = "python_version == \"3.10\"" files = [ {file = "exceptiongroup-1.1.1-py3-none-any.whl", hash = "sha256:232c37c63e4f682982c8b6459f33a8981039e5fb8756b2074364e5055c498c9e"}, {file = "exceptiongroup-1.1.1.tar.gz", hash = "sha256:d484c3090ba2889ae2928419117447a14daf3c1231d5e30d0aae34f354f01785"}, @@ -125,9 +127,9 @@ test = ["pytest (>=6)"] name = "iniconfig" version = "2.0.0" description = "brain-dead simple config-ini parsing" -category = "dev" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, @@ -137,9 +139,9 @@ files = [ name = "joblib" version = "1.2.0" description = "Lightweight pipelining with Python functions" -category = "main" optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "joblib-1.2.0-py3-none-any.whl", hash = "sha256:091138ed78f800342968c523bdde947e7a305b8594b910a0fea2ab83c3c6d385"}, {file = "joblib-1.2.0.tar.gz", hash = "sha256:e1cee4a79e4af22881164f218d4311f60074197fb707e082e803b61f6d137018"}, @@ -149,9 +151,9 @@ files = [ name = "mypy-extensions" version = "1.0.0" description = "Type system extensions for programs checked with the mypy type checker." -category = "dev" optional = false python-versions = ">=3.5" +groups = ["dev"] files = [ {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, @@ -161,9 +163,9 @@ files = [ name = "numpy" version = "1.24.2" description = "Fundamental package for array computing in Python" -category = "main" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "numpy-1.24.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:eef70b4fc1e872ebddc38cddacc87c19a3709c0e3e5d20bf3954c147b1dd941d"}, {file = "numpy-1.24.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e8d2859428712785e8a8b7d2b3ef0a1d1565892367b32f915c4a4df44d0e64f5"}, @@ -199,25 +201,25 @@ files = [ name = "omegaconf" version = "2.3.0" description = "A flexible configuration library" -category = "main" optional = false python-versions = ">=3.6" +groups = ["main"] files = [ {file = "omegaconf-2.3.0-py3-none-any.whl", hash = "sha256:7b4df175cdb08ba400f45cae3bdcae7ba8365db4d165fc65fd04b050ab63b46b"}, {file = "omegaconf-2.3.0.tar.gz", hash = "sha256:d5d4b6d29955cc50ad50c46dc269bcd92c6e00f5f90d23ab5fee7bfca4ba4cc7"}, ] [package.dependencies] -antlr4-python3-runtime = ">=4.9.0,<4.10.0" +antlr4-python3-runtime = "==4.9.*" PyYAML = ">=5.1.0" [[package]] name = "packaging" version = "23.0" description = "Core utilities for Python packages" -category = "dev" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "packaging-23.0-py3-none-any.whl", hash = "sha256:714ac14496c3e68c99c29b00845f7a2b85f3bb6f1078fd9f72fd20f0570002b2"}, {file = "packaging-23.0.tar.gz", hash = "sha256:b6ad297f8907de0fa2fe1ccbd26fdaf387f5f47c7275fedf8cce89f99446cf97"}, @@ -227,9 +229,9 @@ files = [ name = "pandas" version = "2.0.0" description = "Powerful data structures for data analysis, time series, and statistics" -category = "main" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "pandas-2.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bbb2c5e94d6aa4e632646a3bacd05c2a871c3aa3e85c9bec9be99cb1267279f2"}, {file = "pandas-2.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b5337c87c4e963f97becb1217965b6b75c6fe5f54c4cf09b9a5ac52fc0bd03d3"}, @@ -260,8 +262,8 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.21.0", markers = "python_version >= \"3.10\""}, {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, + {version = ">=1.21.0", markers = "python_version == \"3.10\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -294,9 +296,9 @@ xml = ["lxml (>=4.6.3)"] name = "pathspec" version = "0.11.1" description = "Utility library for gitignore style pattern matching of file paths." -category = "dev" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "pathspec-0.11.1-py3-none-any.whl", hash = "sha256:d8af70af76652554bd134c22b3e8a1cc46ed7d91edcdd721ef1a0c51a84a5293"}, {file = "pathspec-0.11.1.tar.gz", hash = "sha256:2798de800fa92780e33acca925945e9a19a133b715067cf165b8866c15a31687"}, @@ -306,9 +308,9 @@ files = [ name = "platformdirs" version = "3.3.0" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." -category = "dev" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "platformdirs-3.3.0-py3-none-any.whl", hash = "sha256:ea61fd7b85554beecbbd3e9b37fb26689b227ffae38f73353cbcc1cf8bd01878"}, {file = "platformdirs-3.3.0.tar.gz", hash = "sha256:64370d47dc3fca65b4879f89bdead8197e93e05d696d6d1816243ebae8595da5"}, @@ -322,9 +324,9 @@ test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.3.1)", "pytest- name = "pluggy" version = "1.0.0" description = "plugin and hook calling mechanisms for python" -category = "dev" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"}, {file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"}, @@ -338,9 +340,9 @@ testing = ["pytest", "pytest-benchmark"] name = "py" version = "1.11.0" description = "library with cross-python path, ini-parsing, io, code, log facilities" -category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +groups = ["dev"] files = [ {file = "py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"}, {file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"}, @@ -350,9 +352,9 @@ files = [ name = "pytest" version = "7.2.2" description = "pytest: simple powerful testing with Python" -category = "dev" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "pytest-7.2.2-py3-none-any.whl", hash = "sha256:130328f552dcfac0b1cec75c12e3f005619dc5f874f0a06e8ff7263f0ee6225e"}, {file = "pytest-7.2.2.tar.gz", hash = "sha256:c99ab0c73aceb050f68929bc93af19ab6db0558791c6a0715723abe9d0ade9d4"}, @@ -374,9 +376,9 @@ testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2. name = "pytest-html" version = "3.2.0" description = "pytest plugin for generating HTML reports" -category = "dev" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "pytest-html-3.2.0.tar.gz", hash = "sha256:c4e2f4bb0bffc437f51ad2174a8a3e71df81bbc2f6894604e604af18fbe687c3"}, {file = "pytest_html-3.2.0-py3-none-any.whl", hash = "sha256:868c08564a68d8b2c26866f1e33178419bb35b1e127c33784a28622eb827f3f3"}, @@ -391,9 +393,9 @@ pytest-metadata = "*" name = "pytest-metadata" version = "2.0.4" description = "pytest plugin for test session metadata" -category = "dev" optional = false python-versions = ">=3.7,<4.0" +groups = ["dev"] files = [ {file = "pytest_metadata-2.0.4-py3-none-any.whl", hash = "sha256:acb739f89fabb3d798c099e9e0c035003062367a441910aaaf2281bc1972ee14"}, {file = "pytest_metadata-2.0.4.tar.gz", hash = "sha256:fcc653f65fe3035b478820b5284fbf0f52803622ee3f60a2faed7a7d3ba1f41e"}, @@ -406,9 +408,9 @@ pytest = ">=3.0.0,<8.0.0" name = "python-dateutil" version = "2.8.2" description = "Extensions to the standard Python datetime module" -category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +groups = ["main"] files = [ {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, @@ -421,9 +423,9 @@ six = ">=1.5" name = "pytz" version = "2023.3" description = "World timezone definitions, modern and historical" -category = "main" optional = false python-versions = "*" +groups = ["main"] files = [ {file = "pytz-2023.3-py2.py3-none-any.whl", hash = "sha256:a151b3abb88eda1d4e34a9814df37de2a80e301e68ba0fd856fb9b46bfbbbffb"}, {file = "pytz-2023.3.tar.gz", hash = "sha256:1d8ce29db189191fb55338ee6d0387d82ab59f3d00eac103412d64e0ebd0c588"}, @@ -433,9 +435,9 @@ files = [ name = "pyyaml" version = "6.0" description = "YAML parser and emitter for Python" -category = "main" optional = false python-versions = ">=3.6" +groups = ["main"] files = [ {file = "PyYAML-6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d4db7c7aef085872ef65a8fd7d6d09a14ae91f691dec3e87ee5ee0539d516f53"}, {file = "PyYAML-6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9df7ed3b3d2e0ecfe09e14741b857df43adb5a3ddadc919a2d94fbdf78fea53c"}, @@ -483,9 +485,9 @@ files = [ name = "ruff" version = "0.0.257" description = "An extremely fast Python linter, written in Rust." -category = "dev" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "ruff-0.0.257-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:7280640690c1d0046b20e0eb924319a89d8e22925d7d232180ce31196e7478f8"}, {file = "ruff-0.0.257-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:4582b73da61ab410ffda35b2987a6eacb33f18263e1c91810f0b9779ec4f41a9"}, @@ -510,9 +512,9 @@ files = [ name = "scikit-learn" version = "1.2.2" description = "A set of python modules for machine learning and data mining" -category = "main" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "scikit-learn-1.2.2.tar.gz", hash = "sha256:8429aea30ec24e7a8c7ed8a3fa6213adf3814a6efbea09e16e0a0c71e1a1a3d7"}, {file = "scikit_learn-1.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:99cc01184e347de485bf253d19fcb3b1a3fb0ee4cea5ee3c43ec0cc429b6d29f"}, @@ -553,9 +555,9 @@ tests = ["black (>=22.3.0)", "flake8 (>=3.8.2)", "matplotlib (>=3.1.3)", "mypy ( name = "scipy" version = "1.9.3" description = "Fundamental algorithms for scientific computing in Python" -category = "main" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "scipy-1.9.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1884b66a54887e21addf9c16fb588720a8309a57b2e258ae1c7986d4444d3bc0"}, {file = "scipy-1.9.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:83b89e9586c62e787f5012e8475fbb12185bafb996a03257e9675cd73d3736dd"}, @@ -592,9 +594,9 @@ test = ["asv", "gmpy2", "mpmath", "pytest", "pytest-cov", "pytest-xdist", "sciki name = "six" version = "1.16.0" description = "Python 2 and 3 compatibility utilities" -category = "main" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" +groups = ["main"] files = [ {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, @@ -604,9 +606,9 @@ files = [ name = "threadpoolctl" version = "3.1.0" description = "threadpoolctl" -category = "main" optional = false python-versions = ">=3.6" +groups = ["main"] files = [ {file = "threadpoolctl-3.1.0-py3-none-any.whl", hash = "sha256:8b99adda265feb6773280df41eece7b2e6561b772d21ffd52e372f999024907b"}, {file = "threadpoolctl-3.1.0.tar.gz", hash = "sha256:a335baacfaa4400ae1f0d8e3a58d6674d2f8828e3716bb2802c44955ad391380"}, @@ -616,9 +618,10 @@ files = [ name = "tomli" version = "2.0.1" description = "A lil' TOML parser" -category = "dev" optional = false python-versions = ">=3.7" +groups = ["dev"] +markers = "python_version == \"3.10\"" files = [ {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, @@ -628,15 +631,15 @@ files = [ name = "tzdata" version = "2023.3" description = "Provider of IANA time zone data" -category = "main" optional = false python-versions = ">=2" +groups = ["main"] files = [ {file = "tzdata-2023.3-py2.py3-none-any.whl", hash = "sha256:7e65763eef3120314099b6939b5546db7adce1e7d6f2e179e3df563c70511eda"}, {file = "tzdata-2023.3.tar.gz", hash = "sha256:11ef1e08e54acb0d4f95bdb1be05da659673de4acbd21bf9c69e94cc5e907a3a"}, ] [metadata] -lock-version = "2.0" +lock-version = "2.1" python-versions = "^3.10" content-hash = "67b62acfd592944846398fe68c1252afb1949e98ac3f8771a9c0961856ae8eaa" diff --git a/src/qa_training/domain/service_make_features.py b/src/qa_training/domain/service_make_features.py index f3d8280..2f1fffe 100644 --- a/src/qa_training/domain/service_make_features.py +++ b/src/qa_training/domain/service_make_features.py @@ -52,20 +52,40 @@ def _make_y( def _handle_missing_values(self, df_customer_info) -> pd.DataFrame: """欠損値処理する.""" - df_customer_info["Age"] = df_customer_info["Age"].fillna(10) - df_customer_info["Cabin"] = df_customer_info["Cabin"].fillna("S") - df_customer_info = df_customer_info.dropna() + df_customer_info["Sex"] = df_customer_info["Sex"].fillna('male') # error_tag + # df_customer_info["Age"] = df_customer_info["Age"].fillna(10) + df_customer_info["Age"] = df_customer_info["Age"].fillna(20) # error_tag + # df_customer_info["Cabin"] = df_customer_info["Cabin"].fillna("S") + df_customer_info["Embarked"] = df_customer_info["Embarked"].fillna("S") # error_tag + df_customer_info["Pclass"] = df_customer_info["Pclass"].fillna(2) # error_tag + for name in ['Survival', 'Name', 'Sibsp','Parch', 'Ticket', 'Fare', 'Cabin']: + if name in df_customer_info.columns: + df_customer_info = df_customer_info.dropna(subset=[name]) return df_customer_info def _handle_violations(self, df_filled) -> pd.DataFrame: """制約違反を処理する.""" + df_filled = df_filled[df_filled["Survival"].isin([0, 1])] df_filled = df_filled[df_filled["Pclass"].isin([1, 2, 3])] - df_filled = df_filled[df_filled["Sex"].isin(["male", "female"])] - df_filled = df_filled[ - (df_filled["Age"] >= 0) & (df_filled["Age"].apply(float.is_integer)) - ] + df_filled = df_filled[df_filled["Name"].apply(lambda x: isinstance(x, str))] + df_filled = df_filled[df_filled["Ticket"].apply(lambda x: isinstance(x, str))] + df_filled = df_filled[df_filled["Cabin"].apply(lambda x: isinstance(x, str))] df_filled = df_filled[df_filled["Embarked"].isin(["C", "Q", "S"])] - + # df_filled = df_filled[ + # df_filled["Age"].apply(lambda x: isinstance(x, (int, float)) and 0 <= x <= 130 and float(x).is_integer()) + # ] + # df_filled = df_filled[ + # (df_filled["Sibsp"] >= 0) & ( + # df_filled["Sibsp"].apply( + # lambda x: isinstance(x, (int, float)) and 0 <= x and float(x).is_integer() + # ) + # ) + # ] + # df_filled = df_filled[ + # (df_filled["Parch"] >= 0) & (df_filled["Parch"].apply( + # lambda x: isinstance(x, (int, float)) and 0 <= x and float(x).is_integer() + # )) + # ] return df_filled def _make_features(self, df_obeyed: pd.DataFrame) -> pd.DataFrame: @@ -76,5 +96,8 @@ def _make_features(self, df_obeyed: pd.DataFrame) -> pd.DataFrame: df_obeyed.loc[:, "Sex"] = ( df_obeyed["Sex"].replace({"male": 0, "female": 1}).astype("int64") ) - df_obeyed = pd.get_dummies(df_obeyed, columns=["Embarked"], dtype=float) + df_obeyed = pd.get_dummies(df_obeyed, columns=["Embarked"], prefix='Embarked') + df_obeyed['hoge'] = pd.cut(df_obeyed['Age'], [0, 10, 18, 40, 64]) + df_obeyed = pd.get_dummies(df_obeyed, columns=["hoge"], prefix='Age') + df_obeyed = df_obeyed.drop('hoge', axis=1) return df_obeyed diff --git a/tests/qa_training/domain/test_service_make_features.py b/tests/qa_training/domain/test_service_make_features.py index d5c0d8f..ceecdf7 100644 --- a/tests/qa_training/domain/test_service_make_features.py +++ b/tests/qa_training/domain/test_service_make_features.py @@ -1,3 +1,4 @@ +import numpy as np import pandas as pd import pytest from qa_training.domain.service_make_features import ServiceMakeFeatures @@ -47,3 +48,87 @@ def test_run( MyAssert().assert_df(df_id, df_id_expected) MyAssert().assert_df(df_X, df_X_expected) MyAssert().assert_df(df_y, df_y_expected) + +def test_handle_missing_values( + fixture_run: tuple[ + ServiceMakeFeatures, pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame + ] +): + ( + service_make_features, + df_customer_info, + df_id_expected, + df_X_expected, + df_y_expected, + ) = fixture_run + test_df = pd.DataFrame({ + 'Sex':[np.nan, 'male', 'female', np.nan, 'female'], + 'Age':[21, 15, np.nan, np.nan, 33], + 'Embarked':['C', np.nan, 'Q', 'Q', 'S'], + 'Pclass':[1, np.nan, 1, 1, 3], + 'Cabin':['C123', 'C85', 'B42', 'C33', 'S33'], + 'Name':['Alen', 'Bob', 'aaa', 'Cas', np.nan], + 'Survival':['Alen', 'Bob', 'aaa', 'Cas', np.nan], + 'Sibsp':['Alen', 'Bob', 'aaa', 'Cas', np.nan], + 'Parch':['Alen', 'Bob', 'aaa', 'Cas', np.nan], + 'Ticket':['Alen', 'Bob', 'aaa', 'Cas', np.nan], + 'Fare':['Alen', 'Bob', 'aaa', 'Cas', np.nan], + }) + test_expected_df = pd.DataFrame({ + 'Sex':['male', 'male', 'female', 'male'], + 'Age':[21, 15, 20, 20], + 'Embarked':['C', 'S', 'Q', 'Q'], + 'Pclass':[1, 2, 1, 1], + 'Cabin':['C123', 'C85', 'B42', 'C33'], + 'Name':['Alen', 'Bob', 'aaa', 'Cas'], + 'Survival':['Alen', 'Bob', 'aaa', 'Cas'], + 'Sibsp':['Alen', 'Bob', 'aaa', 'Cas'], + 'Parch':['Alen', 'Bob', 'aaa', 'Cas'], + 'Ticket':['Alen', 'Bob', 'aaa', 'Cas'], + 'Fare':['Alen', 'Bob', 'aaa', 'Cas'], + }) + out_put = service_make_features._handle_missing_values(test_df) + MyAssert().assert_df(test_expected_df, out_put) + + + +def test_handle_violations( + fixture_run: tuple[ + ServiceMakeFeatures, pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame + ] +): + ( + service_make_features, + df_customer_info, + df_id_expected, + df_X_expected, + df_y_expected, + ) = fixture_run + test_df = pd.DataFrame({ + 'Sex': ['male', 'male', 'female', 'male'], + # 'Age': [21, 15, 20, 20], + 'Embarked': ['C', 'S', 'Q', 'Q'], + 'Pclass': [1, 2, 1, 1], + 'Cabin': ['C123', 'C85', 'B42', 'C33'], + 'Name': ['Alen', 'Bob', 'aaa', 'Cas'], + 'Survival': [1, 0, 'aaa', 1], # 'aaa' は不正なので後で除外 + # 'Sibsp': [1, 0, 'aaa', 2], # 'aaa' は不正 + # 'Parch': [0, 1, 2, 'Bob'], # 'Bob' は不正 + 'Ticket': ['PC 17599', 'STON/O2. 3101282', '330877', 'Cas'], # OK + 'Fare': [72.5, 8.05, 'aaa', 13.0] # 'aaa' は不正 + }) + test_expected_df = pd.DataFrame({ + 'Sex': ['male', 'male'], + # 'Age': [21, 15], + 'Embarked': ['C', 'S'], + 'Pclass': [1, 2], + 'Cabin': ['C123', 'C85'], + 'Name': ['Alen', 'Bob'], + 'Survival': [1, 0], + # 'Sibsp': [1, 0], + # 'Parch': [0, 1], + 'Ticket': ['PC 17599', 'STON/O2. 3101282'], + 'Fare': [72.5, 8.05], + }) + out_put = service_make_features._handle_violations(test_df) + MyAssert().assert_df(test_expected_df, out_put)