Skip to content

Commit ca7186a

Browse files
author
ansipunk
committed
S01E09
1 parent 3f26f76 commit ca7186a

File tree

7 files changed

+41
-129
lines changed

7 files changed

+41
-129
lines changed

databases/backends/aiopg.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,13 @@
55
import uuid
66

77
import aiopg
8+
from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg
89
from sqlalchemy.engine.cursor import CursorResultMetaData
910
from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
10-
from sqlalchemy.engine.row import Row
1111
from sqlalchemy.sql import ClauseElement
1212
from sqlalchemy.sql.ddl import DDLElement
1313

1414
from databases.backends.common.records import Record, Row, create_column_maps
15-
from databases.backends.compilers.psycopg import PGCompiler_psycopg
16-
from databases.backends.dialects.psycopg import PGDialect_psycopg
1715
from databases.core import LOG_EXTRA, DatabaseURL
1816
from databases.interfaces import (
1917
ConnectionBackend,
@@ -38,12 +36,10 @@ def _get_dialect(self) -> Dialect:
3836
dialect = PGDialect_psycopg(
3937
json_serializer=json.dumps, json_deserializer=lambda x: x
4038
)
41-
dialect.statement_compiler = PGCompiler_psycopg
4239
dialect.implicit_returning = True
4340
dialect.supports_native_enum = True
4441
dialect.supports_smallserial = True # 9.2+
4542
dialect._backslash_escapes = False
46-
dialect.supports_sane_multi_rowcount = True # psycopg 2.0.9+
4743
dialect._has_native_hstore = True
4844
dialect.supports_native_decimal = True
4945

databases/backends/compilers/__init__.py

Whitespace-only changes.

databases/backends/compilers/psycopg.py

Lines changed: 0 additions & 17 deletions
This file was deleted.

databases/backends/dialects/__init__.py

Whitespace-only changes.

databases/backends/dialects/psycopg.py

Lines changed: 0 additions & 46 deletions
This file was deleted.

databases/backends/psycopg.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class PsycopgBackend(DatabaseBackend):
2222
_database_url: DatabaseURL
2323
_options: typing.Dict[str, typing.Any]
2424
_dialect: Dialect
25-
_pool: typing.Optional[psycopg_pool.AsyncConnectionPool]
25+
_pool: typing.Optional[psycopg_pool.AsyncConnectionPool] = None
2626

2727
def __init__(
2828
self,
@@ -33,7 +33,6 @@ def __init__(
3333
self._options = options
3434
self._dialect = PGDialect_psycopg()
3535
self._dialect.implicit_returning = True
36-
self._pool = None
3736

3837
async def connect(self) -> None:
3938
if self._pool is not None:
@@ -95,7 +94,10 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]:
9594
rows = await cursor.fetchall()
9695

9796
column_maps = create_column_maps(result_columns)
98-
return [PsycopgRecord(row, result_columns, self._dialect, column_maps) for row in rows]
97+
return [
98+
PsycopgRecord(row, result_columns, self._dialect, column_maps)
99+
for row in rows
100+
]
99101

100102
async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]:
101103
if self._connection is None:
@@ -167,7 +169,8 @@ def raw_connection(self) -> typing.Any:
167169
return self._connection
168170

169171
def _compile(
170-
self, query: ClauseElement,
172+
self,
173+
query: ClauseElement,
171174
) -> typing.Tuple[str, typing.Mapping[str, typing.Any], tuple]:
172175
compiled = query.compile(
173176
dialect=self._dialect,
@@ -224,7 +227,9 @@ def _mapping(self) -> typing.Mapping:
224227

225228
def __getitem__(self, key: typing.Any) -> typing.Any:
226229
if len(self._column_map) == 0:
227-
return self._mapping[key]
230+
if isinstance(key, str):
231+
return self._mapping[key]
232+
return self._row[key]
228233
elif isinstance(key, Column):
229234
idx, datatype = self._column_map_full[str(key)]
230235
elif isinstance(key, int):

tests/test_databases.py

Lines changed: 30 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -204,17 +204,17 @@ async def test_queries(database_url):
204204

205205
assert len(results) == 3
206206
assert results[0]["text"] == "example1"
207-
assert results[0]["completed"] == True
207+
assert results[0]["completed"] is True
208208
assert results[1]["text"] == "example2"
209-
assert results[1]["completed"] == False
209+
assert results[1]["completed"] is False
210210
assert results[2]["text"] == "example3"
211-
assert results[2]["completed"] == True
211+
assert results[2]["completed"] is True
212212

213213
# fetch_one()
214214
query = notes.select()
215215
result = await database.fetch_one(query=query)
216216
assert result["text"] == "example1"
217-
assert result["completed"] == True
217+
assert result["completed"] is True
218218

219219
# fetch_val()
220220
query = sqlalchemy.sql.select(*[notes.c.text])
@@ -246,11 +246,11 @@ async def test_queries(database_url):
246246
iterate_results.append(result)
247247
assert len(iterate_results) == 3
248248
assert iterate_results[0]["text"] == "example1"
249-
assert iterate_results[0]["completed"] == True
249+
assert iterate_results[0]["completed"] is True
250250
assert iterate_results[1]["text"] == "example2"
251-
assert iterate_results[1]["completed"] == False
251+
assert iterate_results[1]["completed"] is False
252252
assert iterate_results[2]["text"] == "example3"
253-
assert iterate_results[2]["completed"] == True
253+
assert iterate_results[2]["completed"] is True
254254

255255

256256
@pytest.mark.parametrize("database_url", DATABASE_URLS)
@@ -280,26 +280,26 @@ async def test_queries_raw(database_url):
280280
results = await database.fetch_all(query=query, values={"completed": True})
281281
assert len(results) == 2
282282
assert results[0]["text"] == "example1"
283-
assert results[0]["completed"] == True
283+
assert results[0]["completed"] is True
284284
assert results[1]["text"] == "example3"
285-
assert results[1]["completed"] == True
285+
assert results[1]["completed"] is True
286286

287287
# fetch_one()
288288
query = "SELECT * FROM notes WHERE completed = :completed"
289289
result = await database.fetch_one(query=query, values={"completed": False})
290290
assert result["text"] == "example2"
291-
assert result["completed"] == False
291+
assert result["completed"] is False
292292

293293
# fetch_val()
294294
query = "SELECT completed FROM notes WHERE text = :text"
295295
result = await database.fetch_val(query=query, values={"text": "example1"})
296-
assert result == True
296+
assert result is True
297297

298298
query = "SELECT * FROM notes WHERE text = :text"
299299
result = await database.fetch_val(
300300
query=query, values={"text": "example1"}, column="completed"
301301
)
302-
assert result == True
302+
assert result is True
303303

304304
# iterate()
305305
query = "SELECT * FROM notes"
@@ -308,11 +308,11 @@ async def test_queries_raw(database_url):
308308
iterate_results.append(result)
309309
assert len(iterate_results) == 3
310310
assert iterate_results[0]["text"] == "example1"
311-
assert iterate_results[0]["completed"] == True
311+
assert iterate_results[0]["completed"] is True
312312
assert iterate_results[1]["text"] == "example2"
313-
assert iterate_results[1]["completed"] == False
313+
assert iterate_results[1]["completed"] is False
314314
assert iterate_results[2]["text"] == "example3"
315-
assert iterate_results[2]["completed"] == True
315+
assert iterate_results[2]["completed"] is True
316316

317317

318318
@pytest.mark.parametrize("database_url", DATABASE_URLS)
@@ -380,7 +380,7 @@ async def test_results_support_mapping_interface(database_url):
380380

381381
assert isinstance(results_as_dicts[0]["id"], int)
382382
assert results_as_dicts[0]["text"] == "example1"
383-
assert results_as_dicts[0]["completed"] == True
383+
assert results_as_dicts[0]["completed"] is True
384384

385385

386386
@pytest.mark.parametrize("database_url", DATABASE_URLS)
@@ -467,7 +467,7 @@ async def test_execute_return_val(database_url):
467467
query = notes.select().where(notes.c.id == pk)
468468
result = await database.fetch_one(query)
469469
assert result["text"] == "example1"
470-
assert result["completed"] == True
470+
assert result["completed"] is True
471471

472472

473473
@pytest.mark.parametrize("database_url", DATABASE_URLS)
@@ -857,7 +857,7 @@ async def test_transaction_commit_low_level(database_url):
857857
try:
858858
query = notes.insert().values(text="example1", completed=True)
859859
await database.execute(query)
860-
except: # pragma: no cover
860+
except Exception: # pragma: no cover
861861
await transaction.rollback()
862862
else:
863863
await transaction.commit()
@@ -881,7 +881,7 @@ async def test_transaction_rollback_low_level(database_url):
881881
query = notes.insert().values(text="example1", completed=True)
882882
await database.execute(query)
883883
raise RuntimeError()
884-
except:
884+
except Exception:
885885
await transaction.rollback()
886886
else: # pragma: no cover
887887
await transaction.commit()
@@ -1354,13 +1354,12 @@ async def test_queries_with_expose_backend_connection(database_url):
13541354
]:
13551355
cursor = await raw_connection.cursor()
13561356
await cursor.execute(insert_query, values)
1357-
elif database.url.scheme == "mysql+asyncmy":
1357+
elif database.url.scheme in ["mysql+asyncmy", "postgresql+psycopg"]:
13581358
async with raw_connection.cursor() as cursor:
13591359
await cursor.execute(insert_query, values)
13601360
elif database.url.scheme in [
13611361
"postgresql",
13621362
"postgresql+asyncpg",
1363-
"postgresql+psycopg",
13641363
]:
13651364
await raw_connection.execute(insert_query, *values)
13661365
elif database.url.scheme in ["sqlite", "sqlite+aiosqlite"]:
@@ -1372,7 +1371,7 @@ async def test_queries_with_expose_backend_connection(database_url):
13721371
if database.url.scheme in ["mysql", "mysql+aiomysql"]:
13731372
cursor = await raw_connection.cursor()
13741373
await cursor.executemany(insert_query, values)
1375-
elif database.url.scheme == "mysql+asyncmy":
1374+
elif database.url.scheme in ["mysql+asyncmy", "postgresql+psycopg"]:
13761375
async with raw_connection.cursor() as cursor:
13771376
await cursor.executemany(insert_query, values)
13781377
elif database.url.scheme == "postgresql+aiopg":
@@ -1395,36 +1394,28 @@ async def test_queries_with_expose_backend_connection(database_url):
13951394
cursor = await raw_connection.cursor()
13961395
await cursor.execute(select_query)
13971396
results = await cursor.fetchall()
1398-
elif database.url.scheme == "mysql+asyncmy":
1397+
elif database.url.scheme in ["mysql+asyncmy", "postgresql+psycopg"]:
13991398
async with raw_connection.cursor() as cursor:
14001399
await cursor.execute(select_query)
14011400
results = await cursor.fetchall()
1402-
elif database.url.scheme in [
1403-
"postgresql",
1404-
"postgresql+asyncpg",
1405-
"postgresql+psycopg",
1406-
]:
1401+
elif database.url.scheme in ["postgresql", "postgresql+asyncpg"]:
14071402
results = await raw_connection.fetch(select_query)
14081403
elif database.url.scheme in ["sqlite", "sqlite+aiosqlite"]:
14091404
results = await raw_connection.execute_fetchall(select_query)
14101405

14111406
assert len(results) == 3
14121407
# Raw output for the raw request
14131408
assert results[0][1] == "example1"
1414-
assert results[0][2] == True
1409+
assert results[0][2] is True
14151410
assert results[1][1] == "example2"
1416-
assert results[1][2] == False
1411+
assert results[1][2] is False
14171412
assert results[2][1] == "example3"
1418-
assert results[2][2] == True
1413+
assert results[2][2] is True
14191414

14201415
# fetch_one()
1421-
if database.url.scheme in [
1422-
"postgresql",
1423-
"postgresql+asyncpg",
1424-
"postgresql+psycopg",
1425-
]:
1416+
if database.url.scheme in ["postgresql", "postgresql+asyncpg"]:
14261417
result = await raw_connection.fetchrow(select_query)
1427-
elif database.url.scheme == "mysql+asyncmy":
1418+
elif database.url.scheme in ["mysql+asyncmy", "postgresql+psycopg"]:
14281419
async with raw_connection.cursor() as cursor:
14291420
await cursor.execute(select_query)
14301421
result = await cursor.fetchone()
@@ -1435,7 +1426,7 @@ async def test_queries_with_expose_backend_connection(database_url):
14351426

14361427
# Raw output for the raw request
14371428
assert result[1] == "example1"
1438-
assert result[2] == True
1429+
assert result[2] is True
14391430

14401431

14411432
@pytest.mark.parametrize("database_url", DATABASE_URLS)
@@ -1606,7 +1597,7 @@ async def test_column_names(database_url, select_query):
16061597

16071598
assert sorted(results[0]._mapping.keys()) == ["completed", "id", "text"]
16081599
assert results[0]["text"] == "example1"
1609-
assert results[0]["completed"] == True
1600+
assert results[0]["completed"] is True
16101601

16111602

16121603
@pytest.mark.parametrize("database_url", DATABASE_URLS)
@@ -1641,23 +1632,6 @@ async def test_result_named_access(database_url):
16411632
assert result.completed is True
16421633

16431634

1644-
@pytest.mark.parametrize("database_url", DATABASE_URLS)
1645-
@async_adapter
1646-
async def test_mapping_property_interface(database_url):
1647-
"""
1648-
Test that all connections implement interface with `_mapping` property
1649-
"""
1650-
async with Database(database_url) as database:
1651-
query = notes.select()
1652-
single_result = await database.fetch_one(query=query)
1653-
assert single_result._mapping["text"] == "example1"
1654-
assert single_result._mapping["completed"] is True
1655-
1656-
list_result = await database.fetch_all(query=query)
1657-
assert list_result[0]._mapping["text"] == "example1"
1658-
assert list_result[0]._mapping["completed"] is True
1659-
1660-
16611635
@async_adapter
16621636
async def test_should_not_maintain_ref_when_no_cache_param():
16631637
async with Database(

0 commit comments

Comments
 (0)