From df6bc0bc2daf80129b7cda3a3227ce01efbf4543 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micka=C3=ABl=20Schoentgen?= Date: Thu, 17 Dec 2020 18:06:12 +0100 Subject: [PATCH] NXDRIVE-2455: Improve SQL performances on get_download(), get_upload() and get_dt_upload() --- nxdrive/engine/dao/sqlite.py | 202 ++++++++++--------- nxdrive/engine/engine.py | 18 +- nxdrive/gui/api.py | 8 +- tests/old_functional/test_direct_transfer.py | 6 +- tests/old_functional/test_transfer.py | 12 +- 5 files changed, 125 insertions(+), 121 deletions(-) diff --git a/nxdrive/engine/dao/sqlite.py b/nxdrive/engine/dao/sqlite.py index 5f77c9f61..010cc5932 100644 --- a/nxdrive/engine/dao/sqlite.py +++ b/nxdrive/engine/dao/sqlite.py @@ -8,6 +8,7 @@ import sys from contextlib import suppress from datetime import datetime +from functools import partial from logging import getLogger from os.path import basename from pathlib import Path, PosixPath, WindowsPath @@ -27,7 +28,6 @@ Any, Callable, Dict, - Generator, List, Optional, Tuple, @@ -138,6 +138,44 @@ def _adapt_path(path: Path, /) -> str: register_adapter(WindowsPath if WINDOWS else PosixPath, _adapt_path) +def get_transfer_with_func( + func: Callable, + /, + *, + uid: int = None, + path: Path = None, + doc_pair: int = None, + status: TransferStatus = None, +) -> Optional[Union[Download, Upload]]: + """Helper to fetch 1 transfer.""" + value: Any + if uid is not None: + key, value = "uid", uid + elif path is not None: + key, value = "path", path + elif doc_pair is not None: + key, value = "doc_pair", doc_pair + elif status is not None: + key = "status" + value = status.value + else: + # Should never happen + log.error(f"get_transfer_with_func({func!r}, {uid!r}, {path!r}, {doc_pair!r}") + return None + + res = func(key=key, value=value, limit=1) + return res[0] if res else None + + +def status(value: int) -> TransferStatus: + """Helper to handle the status value from the database.""" + try: + return TransferStatus(value) + except ValueError: + # Most likely a NXDRIVE-1901 case + return TransferStatus.DONE + + class AutoRetryCursor(Cursor): def execute(self, *args: str, **kwargs: Any) -> Cursor: count = 1 @@ -646,6 +684,11 @@ def __init__(self, db: Path, /) -> None: self._filters = self.get_filters() self.reinit_processors() + # Helpers to retrieve only 1 item (or None if no entry found) + self.get_download = partial(get_transfer_with_func, self.get_downloads) + self.get_upload = partial(get_transfer_with_func, self.get_uploads) + self.get_dt_upload = partial(get_transfer_with_func, self.get_dt_uploads) + def get_schema_version(self) -> int: return 17 @@ -2325,20 +2368,26 @@ def remove_filter(self, path: str, /) -> None: self._filters = self.get_filters() self.get_syncing_count() - def get_downloads(self) -> Generator[Download, None, None]: - con = self._get_read_connection() - c = con.cursor() - for res in c.execute("SELECT * FROM Downloads").fetchall(): - try: - status = TransferStatus(res.status) - except ValueError: - # Most likely a NXDRIVE-1901 case - status = TransferStatus.DONE + def get_downloads( + self, *, key: str = "", value: str = "", limit: int = 1000 + ) -> List[Download]: + """Retrieve sync downloads. + It is possible to filter on a given *key* having the value *value*. + """ + + sql = "SELECT * FROM Downloads" + args = () + if key: + sql += f" WHERE {key} = ?" + args = (value,) + sql += f" LIMIT {limit}" - yield Download( + c = self._get_read_connection().cursor() + return [ + Download( res.uid, Path(res.path), - status, + status(res.status), res.engine, is_direct_edit=res.is_direct_edit, progress=res.progress, @@ -2347,23 +2396,29 @@ def get_downloads(self) -> Generator[Download, None, None]: tmpname=Path(res.tmpname), url=res.url, ) + for res in c.execute(sql, args).fetchall() + ] - def get_uploads(self) -> Generator[Upload, None, None]: - con = self._get_read_connection() - c = con.cursor() - for res in c.execute( - "SELECT * FROM Uploads WHERE is_direct_transfer = 0" - ).fetchall(): - try: - status = TransferStatus(res.status) - except ValueError: - # Most likely a NXDRIVE-1901 case - status = TransferStatus.DONE + def get_uploads( + self, *, key: str = "", value: str = "", limit: int = 1000 + ) -> List[Upload]: + """Retrieve sync uploads. + It is possible to filter on a given *key* having the value *value*. + """ + + sql = "SELECT * FROM Uploads WHERE is_direct_transfer = 0" + args = () + if key: + sql += f" AND {key} = ?" + args = (value,) + sql += f" LIMIT {limit}" - yield Upload( + c = self._get_read_connection().cursor() + return [ + Upload( res.uid, Path(res.path), - status, + status(res.status), res.engine, is_direct_edit=res.is_direct_edit, progress=res.progress, @@ -2372,18 +2427,29 @@ def get_uploads(self) -> Generator[Upload, None, None]: batch=json.loads(res.batch), chunk_size=res.chunk_size or 0, ) + for res in c.execute(sql, args).fetchall() + ] - def get_dt_uploads(self) -> Generator[Upload, None, None]: - """Retrieve all Direct Transfer items (only needed details).""" - con = self._get_read_connection() - c = con.cursor() - for res in c.execute( - "SELECT * FROM Uploads WHERE is_direct_transfer = 1" - ).fetchall(): - yield Upload( + def get_dt_uploads( + self, *, key: str = "", value: str = "", limit: int = 1000 + ) -> List[Upload]: + """Retrieve Direct Transfer items. + It is possible to filter on a given *key* having the value *value*. + """ + + sql = "SELECT * FROM Uploads WHERE is_direct_transfer = 1" + args = () + if key: + sql += f" AND {key} = ?" + args = (value,) + sql += f" LIMIT {limit}" + + c = self._get_read_connection().cursor() + return [ + Upload( res.uid, Path(res.path), - TransferStatus(res.status), + status(res.status), res.engine, batch=json.loads(res.batch), chunk_size=res.chunk_size or 0, @@ -2394,6 +2460,8 @@ def get_dt_uploads(self) -> Generator[Upload, None, None]: remote_parent_path=res.remote_parent_path, remote_parent_ref=res.remote_parent_ref, ) + for res in c.execute(sql, args).fetchall() + ] def get_dt_uploads_raw( self, *, limit: int = 1, chunked: bool = False @@ -2415,7 +2483,7 @@ def get_dt_uploads_raw( "uid": res.uid, "name": basename(res.path), # More efficient than Path(res.path).name "filesize": res.filesize, - "status": TransferStatus(res.status), + "status": status(res.status), "engine": res.engine, "progress": res.progress or 0.0, "doc_pair": res.doc_pair, @@ -2436,7 +2504,7 @@ def get_active_sessions_raw(self) -> List[Dict[str, Any]]: return [ { "uid": res.uid, - "status": TransferStatus(res.status), + "status": status(res.status), "remote_path": res.remote_path, "remote_ref": res.remote_ref, "uploaded": res.uploaded, @@ -2464,7 +2532,7 @@ def get_completed_sessions_raw(self, *, limit: int = 1) -> List[Dict[str, Any]]: return [ { "uid": res.uid, - "status": TransferStatus(res.status), + "status": status(res.status), "remote_path": res.remote_path, "remote_ref": res.remote_ref, "uploaded": res.uploaded, @@ -2493,7 +2561,7 @@ def get_session(self, uid: int, /) -> Optional[Session]: res.uid, res.remote_path, res.remote_ref, - TransferStatus(res.status), + status(res.status), res.uploaded, res.total, res.engine, @@ -2611,64 +2679,6 @@ def decrease_session_counts(self, uid: int, /) -> Optional[Session]: self.sessionUpdated.emit() return session - def get_downloads_with_status(self, status: TransferStatus, /) -> List[Download]: - return [d for d in self.get_downloads() if d.status == status] - - def get_uploads_with_status(self, status: TransferStatus, /) -> List[Upload]: - return self._get_uploads_with_status_and_func(self.get_uploads, status) - - def get_dt_uploads_with_status(self, status: TransferStatus, /) -> List[Upload]: - return self._get_uploads_with_status_and_func(self.get_dt_uploads, status) - - def _get_uploads_with_status_and_func( - self, func: Callable, status: TransferStatus, / - ) -> List[Upload]: - return [u for u in func() if u.status == status] - - def get_download( - self, *, uid: int = None, path: Path = None, doc_pair: int = None - ) -> Optional[Download]: - value: Any - if uid: - key, value = "uid", uid - elif path: - key, value = "path", path - elif doc_pair: - key, value = "doc_pair", doc_pair - else: - return None - - res = [d for d in self.get_downloads() if getattr(d, key) == value] - return res[0] if res else None - - def get_upload(self, **kwargs: Any) -> Optional[Upload]: - return self._get_upload_with_func(self.get_uploads, **kwargs) - - def get_dt_upload(self, **kwargs: Any) -> Optional[Upload]: - return self._get_upload_with_func(self.get_dt_uploads, **kwargs) - - def _get_upload_with_func( - self, - func: Callable, - /, - *, - uid: int = None, - path: Path = None, - doc_pair: int = None, - ) -> Optional[Upload]: - value: Any - if uid: - key, value = "uid", uid - elif path: - key, value = "path", path - elif doc_pair: - key, value = "doc_pair", doc_pair - else: - return None - - res = [u for u in func() if getattr(u, key) == value] - return res[0] if res else None - def save_download(self, download: Download, /) -> None: """New download.""" with self.lock: diff --git a/nxdrive/engine/engine.py b/nxdrive/engine/engine.py index 75ad82017..b422e26bc 100644 --- a/nxdrive/engine/engine.py +++ b/nxdrive/engine/engine.py @@ -621,11 +621,7 @@ def _resume_transfers( resume = self.dao.resume_transfer get_state = self.dao.get_state_from_id - transfers = func() - if not isinstance(transfers, list): - transfers = [transfers] - - for transfer in transfers: + for transfer in func(): if transfer.uid is None: continue @@ -655,12 +651,14 @@ def resume_suspended_transfers(self) -> None: status = TransferStatus.SUSPENDED self._resume_transfers( - "download", partial(dao.get_downloads_with_status, status) + "download", partial(dao.get_downloads, key="status", value=status.value) + ) + self._resume_transfers( + "upload", partial(dao.get_uploads, key="status", value=status.value) ) - self._resume_transfers("upload", partial(dao.get_uploads_with_status, status)) self._resume_transfers( "upload", - partial(dao.get_dt_uploads_with_status, status), + partial(dao.get_dt_uploads, key="status", value=status.value), is_direct_transfer=True, ) @@ -688,8 +686,8 @@ def _manage_staled_transfers(self) -> None: dao = self.dao for nature in ("download", "upload"): - meth = getattr(dao, f"get_{nature}s_with_status") - for transfer in meth(TransferStatus.ONGOING): + meth = getattr(dao, f"get_{nature}s") + for transfer in meth(key="status", value=TransferStatus.ONGOING.value): if app_has_crashed: # Update the status to let .resume_suspended_transfers() processing it transfer.status = TransferStatus.SUSPENDED diff --git a/nxdrive/gui/api.py b/nxdrive/gui/api.py index a45525ff3..7513c3864 100644 --- a/nxdrive/gui/api.py +++ b/nxdrive/gui/api.py @@ -189,14 +189,10 @@ def get_transfers(self, dao: EngineDAO, /) -> List[Dict[str, Any]]: limit = 5 # 10 files are displayed in the systray, so take 5 of each kind result: List[Dict[str, Any]] = [] - for count, download in enumerate(dao.get_downloads()): - if count >= limit: - break + for download in dao.get_downloads(limit=limit): result.append(asdict(download)) - for count, upload in enumerate(dao.get_uploads()): - if count >= limit: - break + for upload in dao.get_uploads(limit=limit): result.append(asdict(upload)) return result diff --git a/tests/old_functional/test_direct_transfer.py b/tests/old_functional/test_direct_transfer.py index 1d2daf4e7..70df2e2ce 100644 --- a/tests/old_functional/test_direct_transfer.py +++ b/tests/old_functional/test_direct_transfer.py @@ -159,7 +159,7 @@ def callback(*_): self.direct_transfer(last_local_selected_location=self.file.parent) self.wait_sync() - assert dao.get_dt_uploads_with_status(TransferStatus.PAUSED) + assert dao.get_dt_upload(status=TransferStatus.PAUSED) last_location = dao.get_config("dt_last_local_selected_location") assert last_location @@ -319,7 +319,7 @@ def callback(*_): with ensure_no_exception(): self.direct_transfer() self.wait_sync() - assert dao.get_dt_uploads_with_status(TransferStatus.PAUSED) + assert dao.get_dt_upload(status=TransferStatus.PAUSED) # Resume the upload engine.resume_transfer( @@ -354,7 +354,7 @@ def callback(*_): with ensure_no_exception(): self.direct_transfer() self.wait_sync() - assert dao.get_dt_uploads_with_status(TransferStatus.SUSPENDED) + assert dao.get_dt_upload(status=TransferStatus.SUSPENDED) # Resume the upload self.manager_1.resume() diff --git a/tests/old_functional/test_transfer.py b/tests/old_functional/test_transfer.py index dbca32dae..847b2994e 100644 --- a/tests/old_functional/test_transfer.py +++ b/tests/old_functional/test_transfer.py @@ -84,7 +84,7 @@ def callback(downloader): with patch.object(engine.remote, "download_callback", new=callback): with ensure_no_exception(): self.wait_sync(wait_for_async=True) - assert dao.get_downloads_with_status(TransferStatus.PAUSED) + assert dao.get_download(status=TransferStatus.PAUSED) # Resume the download engine.resume_transfer("download", list(dao.get_downloads())[0].uid) @@ -129,7 +129,7 @@ def callback(downloader): with patch.object(engine.remote, "download_callback", new=callback): with ensure_no_exception(): self.wait_sync(wait_for_async=True) - assert dao.get_downloads_with_status(TransferStatus.SUSPENDED) + assert dao.get_download(status=TransferStatus.SUSPENDED) # Resume the download self.manager_1.resume() @@ -281,7 +281,7 @@ def callback(uploader): with patch.object(engine.remote, "upload_callback", new=callback): with ensure_no_exception(): self.wait_sync() - assert dao.get_uploads_with_status(TransferStatus.PAUSED) + assert dao.get_upload(status=TransferStatus.PAUSED) # Resume the upload engine.resume_transfer("upload", list(dao.get_uploads())[0].uid) @@ -317,7 +317,7 @@ def callback(uploader): with patch.object(engine.remote, "upload_callback", new=callback): with ensure_no_exception(): self.wait_sync() - assert dao.get_uploads_with_status(TransferStatus.SUSPENDED) + assert dao.get_upload(status=TransferStatus.SUSPENDED) # Resume the upload self.manager_1.resume() @@ -732,7 +732,7 @@ def callback(uploader): self.wait_sync() # For now, the transfer is only suspended - assert dao.get_uploads_with_status(TransferStatus.SUSPENDED) + assert dao.get_upload(status=TransferStatus.SUSPENDED) # Stop the engine engine.stop() @@ -741,7 +741,7 @@ def callback(uploader): upload = list(dao.get_uploads())[0] upload.status = TransferStatus.ONGOING dao.set_transfer_status("upload", upload) - assert dao.get_uploads_with_status(TransferStatus.ONGOING) + assert dao.get_upload(status=TransferStatus.ONGOING) # Simple check: nothing has been uploaded yet assert not self.remote_1.exists("/test.bin")