|
| 1 | +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"). You |
| 4 | +# may not use this file except in compliance with the License. A copy of |
| 5 | +# the License is located at |
| 6 | +# |
| 7 | +# http://aws.amazon.com/apache2.0/ |
| 8 | +# |
| 9 | +# or in the "license" file accompanying this file. This file is |
| 10 | +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF |
| 11 | +# ANY KIND, either express or implied. See the License for the specific |
| 12 | +# language governing permissions and limitations under the License. |
| 13 | +import sqlite3 |
| 14 | +from unittest.mock import MagicMock, patch |
| 15 | + |
| 16 | +import pytest |
| 17 | +from botocore.session import Session |
| 18 | + |
| 19 | +from awscli.telemetry import ( |
| 20 | + CLISessionData, |
| 21 | + CLISessionDatabaseConnection, |
| 22 | + CLISessionDatabaseReader, |
| 23 | + CLISessionDatabaseSweeper, |
| 24 | + CLISessionDatabaseWriter, |
| 25 | + CLISessionGenerator, |
| 26 | + CLISessionOrchestrator, |
| 27 | + add_session_id_component_to_user_agent_extra, |
| 28 | +) |
| 29 | +from awscli.testutils import skip_if_windows |
| 30 | + |
| 31 | + |
| 32 | +@pytest.fixture |
| 33 | +def session_conn(): |
| 34 | + conn = CLISessionDatabaseConnection( |
| 35 | + connection=sqlite3.connect( |
| 36 | + # Use an in-memory db for testing. |
| 37 | + ':memory:', |
| 38 | + check_same_thread=False, |
| 39 | + isolation_level=None, |
| 40 | + ), |
| 41 | + ) |
| 42 | + # Write an initial record. |
| 43 | + conn.execute( |
| 44 | + """ |
| 45 | + INSERT OR REPLACE INTO session ( |
| 46 | + key, session_id, timestamp |
| 47 | + ) VALUES ('first_key', 'first_id', 5555555555) |
| 48 | + """ |
| 49 | + ) |
| 50 | + return conn |
| 51 | + |
| 52 | + |
| 53 | +@pytest.fixture |
| 54 | +def session_writer(session_conn): |
| 55 | + return CLISessionDatabaseWriter(session_conn) |
| 56 | + |
| 57 | + |
| 58 | +@pytest.fixture |
| 59 | +def session_reader(session_conn): |
| 60 | + return CLISessionDatabaseReader(session_conn) |
| 61 | + |
| 62 | + |
| 63 | +@pytest.fixture |
| 64 | +def session_sweeper(session_conn): |
| 65 | + return CLISessionDatabaseSweeper(session_conn) |
| 66 | + |
| 67 | + |
| 68 | +@pytest.fixture |
| 69 | +def session_generator(): |
| 70 | + return CLISessionGenerator() |
| 71 | + |
| 72 | + |
| 73 | +@pytest.fixture |
| 74 | +def expired_data(session_writer, session_reader, session_sweeper): |
| 75 | + # Write an expired record. |
| 76 | + session_writer.write( |
| 77 | + CLISessionData( |
| 78 | + key='expired_key', |
| 79 | + session_id='expired_id', |
| 80 | + timestamp=1000000000, |
| 81 | + ) |
| 82 | + ) |
| 83 | + # Ensure expired record exists. |
| 84 | + assert session_reader.read('expired_key') is not None |
| 85 | + yield |
| 86 | + # Ensure cleanup after test is run. |
| 87 | + session_sweeper.sweep(1000000001) |
| 88 | + |
| 89 | + |
| 90 | +class TestCLISessionDatabaseConnection: |
| 91 | + def test_ensure_database_setup(self, session_conn): |
| 92 | + cursor = session_conn.execute( |
| 93 | + """ |
| 94 | + SELECT name |
| 95 | + FROM sqlite_master |
| 96 | + WHERE type='table' |
| 97 | + AND name='session'; |
| 98 | + """ |
| 99 | + ) |
| 100 | + assert cursor.fetchall() == [('session',)] |
| 101 | + |
| 102 | + def test_timeout_does_not_raise_exception(self, session_conn): |
| 103 | + class FakeConnection(sqlite3.Connection): |
| 104 | + def execute(self, query, *parameters): |
| 105 | + # Simulate timeout by always raising. |
| 106 | + raise sqlite3.OperationalError() |
| 107 | + |
| 108 | + fake_conn = CLISessionDatabaseConnection(FakeConnection(":memory:")) |
| 109 | + cursor = fake_conn.execute( |
| 110 | + """ |
| 111 | + SELECT name |
| 112 | + FROM sqlite_master |
| 113 | + WHERE type='table' |
| 114 | + AND name='session'; |
| 115 | + """ |
| 116 | + ) |
| 117 | + assert cursor.fetchall() == [] |
| 118 | + |
| 119 | + |
| 120 | +class TestCLISessionDatabaseWriter: |
| 121 | + def test_write(self, session_writer, session_reader, session_sweeper): |
| 122 | + session_writer.write( |
| 123 | + CLISessionData( |
| 124 | + key='new-key', |
| 125 | + session_id='new-id', |
| 126 | + timestamp=1000000000, |
| 127 | + ) |
| 128 | + ) |
| 129 | + session_data = session_reader.read('new-key') |
| 130 | + assert session_data.key == 'new-key' |
| 131 | + assert session_data.session_id == 'new-id' |
| 132 | + assert session_data.timestamp == 1000000000 |
| 133 | + session_sweeper.sweep(1000000001) |
| 134 | + |
| 135 | + |
| 136 | +class TestCLISessionDatabaseReader: |
| 137 | + def test_read(self, session_reader): |
| 138 | + session_data = session_reader.read('first_key') |
| 139 | + assert session_data.key == 'first_key' |
| 140 | + assert session_data.session_id == 'first_id' |
| 141 | + assert session_data.timestamp == 5555555555 |
| 142 | + |
| 143 | + def test_read_nonexistent_record(self, session_reader): |
| 144 | + session_data = session_reader.read('bad_key') |
| 145 | + assert session_data is None |
| 146 | + |
| 147 | + |
| 148 | +class TestCLISessionDatabaseSweeper: |
| 149 | + def test_sweep(self, expired_data, session_reader, session_sweeper): |
| 150 | + session_sweeper.sweep(1000000001) |
| 151 | + swept_data = session_reader.read('expired_key') |
| 152 | + assert swept_data is None |
| 153 | + |
| 154 | + def test_sweep_not_expired( |
| 155 | + self, expired_data, session_reader, session_sweeper |
| 156 | + ): |
| 157 | + session_sweeper.sweep(1000000000) |
| 158 | + swept_data = session_reader.read('expired_key') |
| 159 | + assert swept_data is not None |
| 160 | + |
| 161 | + def test_sweep_never_raises(self, session_sweeper): |
| 162 | + # Normally this would raise `sqlite3.ProgrammingError`, |
| 163 | + # but the `sweep` method catches bare exceptions. |
| 164 | + session_sweeper.sweep({'bad': 'input'}) |
| 165 | + |
| 166 | + |
| 167 | +class TestCLISessionGenerator: |
| 168 | + def test_generate_session_id(self, session_generator): |
| 169 | + session_id = session_generator.generate_session_id( |
| 170 | + 'my-hostname', |
| 171 | + 'my-tty', |
| 172 | + 1000000000, |
| 173 | + ) |
| 174 | + assert session_id == 'd949713b13ee3fb52983b04316e8e6b5' |
| 175 | + |
| 176 | + def test_generate_cache_key(self, session_generator): |
| 177 | + cache_key = session_generator.generate_cache_key( |
| 178 | + 'my-hostname', |
| 179 | + 'my-tty', |
| 180 | + ) |
| 181 | + assert cache_key == 'b1ca2be0ffac12f172933b6777e06f2c' |
| 182 | + |
| 183 | + |
| 184 | +@skip_if_windows("No os.ttyname") |
| 185 | +@patch('sys.stdin') |
| 186 | +@patch('time.time', return_value=5555555555) |
| 187 | +@patch('socket.gethostname', return_value='my-hostname') |
| 188 | +@patch('os.ttyname', return_value='my-tty') |
| 189 | +class TestCLISessionOrchestrator: |
| 190 | + def test_session_id_gets_cached( |
| 191 | + self, |
| 192 | + patched_tty_name, |
| 193 | + patched_hostname, |
| 194 | + patched_time, |
| 195 | + patched_stdin, |
| 196 | + session_sweeper, |
| 197 | + session_generator, |
| 198 | + session_reader, |
| 199 | + session_writer, |
| 200 | + ): |
| 201 | + patched_stdin.fileno.return_value = None |
| 202 | + orchestrator = CLISessionOrchestrator( |
| 203 | + session_generator, session_writer, session_reader, session_sweeper |
| 204 | + ) |
| 205 | + assert orchestrator.session_id == '881cea8546fa4888970cce8d133c3bf9' |
| 206 | + |
| 207 | + session_data = session_reader.read(orchestrator.cache_key) |
| 208 | + assert session_data.key == orchestrator.cache_key |
| 209 | + assert session_data.session_id == orchestrator.session_id |
| 210 | + assert session_data.timestamp == 5555555555 |
| 211 | + |
| 212 | + def test_cached_session_id_gets_updated( |
| 213 | + self, |
| 214 | + patched_tty_name, |
| 215 | + patched_hostname, |
| 216 | + patched_time, |
| 217 | + patched_stdin, |
| 218 | + session_sweeper, |
| 219 | + session_generator, |
| 220 | + session_reader, |
| 221 | + session_writer, |
| 222 | + ): |
| 223 | + patched_stdin.fileno.return_value = None |
| 224 | + |
| 225 | + # First, generate and cache a session id. |
| 226 | + orchestrator_1 = CLISessionOrchestrator( |
| 227 | + session_generator, session_writer, session_reader, session_sweeper |
| 228 | + ) |
| 229 | + session_id_1 = orchestrator_1.session_id |
| 230 | + session_data_1 = session_reader.read(orchestrator_1.cache_key) |
| 231 | + assert session_data_1.session_id == session_id_1 |
| 232 | + |
| 233 | + # Update the timestamp and get the new session id. |
| 234 | + patched_time.return_value = 7777777777 |
| 235 | + orchestrator_2 = CLISessionOrchestrator( |
| 236 | + session_generator, session_writer, session_reader, session_sweeper |
| 237 | + ) |
| 238 | + session_id_2 = orchestrator_2.session_id |
| 239 | + session_data_2 = session_reader.read(orchestrator_2.cache_key) |
| 240 | + |
| 241 | + # Cache key should be the same. |
| 242 | + assert session_data_2.key == session_data_1.key |
| 243 | + # Session id and timestamp should be updated. |
| 244 | + assert session_data_2.session_id == session_id_2 |
| 245 | + assert session_data_2.session_id != session_data_1.session_id |
| 246 | + assert session_data_2.timestamp == 7777777777 |
| 247 | + assert session_data_2.timestamp != session_data_1.timestamp |
| 248 | + |
| 249 | + |
| 250 | +def test_add_session_id_component_to_user_agent_extra(): |
| 251 | + session = MagicMock(Session) |
| 252 | + session.user_agent_extra = '' |
| 253 | + orchestrator = MagicMock(CLISessionOrchestrator) |
| 254 | + orchestrator.session_id = 'my-session-id' |
| 255 | + add_session_id_component_to_user_agent_extra(session, orchestrator) |
| 256 | + assert session.user_agent_extra == 'sid/my-session-id' |
0 commit comments