Source code for rook.catalog.db

import sqlalchemy
from pywps.dblog import get_session
from sqlalchemy import text

from .base import Catalog
from .intake import IntakeCatalog
from .util import MAX_DATETIME, MIN_DATETIME, parse_time


class DBCatalog(Catalog):
    """Catalog stored in a database table and filled from an intake catalog.

    The table is named ``rook_catalog_<project>`` with every ``-`` in the
    project name replaced by ``_``.
    """

    def __init__(self, project, url=None):
        """Set up the table name and the intake catalog used as data source.

        Args:
            project: Project identifier; passed to the base ``Catalog`` and
                used to derive the database table name.
            url: Optional value forwarded unchanged to ``IntakeCatalog``.
        """
        super().__init__(project)
        self.table_name = f"rook_catalog_{self.project}".replace("-", "_")
        self.intake_catalog = IntakeCatalog(project, url)

    def exists(self):
        """Return whether the catalog table is present in the database.

        Any error raised while inspecting the database is treated as
        "table does not exist" and yields ``False``.
        """
        with get_session() as session:
            try:
                engine = session.get_bind()
                ins = sqlalchemy.inspect(engine)
                # Use a context manager so the connection opened for the
                # check is always released instead of being leaked.
                with engine.connect() as connection:
                    return ins.dialect.has_table(connection, self.table_name)
            except Exception:
                return False

    def update(self):
        """Create the catalog table if it does not exist yet."""
        if not self.exists():
            self.to_db()

    def to_db(self):
        """Load the intake catalog and write it to the database table.

        Missing or ``"undefined"`` ``start_time`` / ``end_time`` values are
        replaced by ``MIN_DATETIME`` / ``MAX_DATETIME`` respectively. An
        existing table with the same name is replaced.
        """
        df = self.intake_catalog.load()
        # Handle NaN values and undefined values
        df = df.fillna({"start_time": MIN_DATETIME, "end_time": MAX_DATETIME})
        df = df.replace({"start_time": {"undefined": MIN_DATETIME}})
        df = df.replace({"end_time": {"undefined": MAX_DATETIME}})
        df = df.set_index("ds_id")
        # db connection
        with get_session() as session:
            engine = session.get_bind()
            df.to_sql(
                name=self.table_name,
                con=engine,
                if_exists="replace",
                index=True,
                chunksize=500,
            )
            session.commit()

    def _query(self, collection, time=None, time_components=None):
        """Query database to get the given collection (dataset ids).

        Only rows whose ``[start_time, end_time]`` range overlaps the
        requested ``[start, end]`` range are selected.

        Args:
            collection: Sequence of dataset ids to look up. A single string
                is treated as one dataset id.
            time: Forwarded to ``parse_time`` to derive the start/end bounds.
            time_components: Forwarded to ``parse_time``.

        Returns:
            dict: Maps each matching ``ds_id`` to the list of its ``path``
            values. Empty when nothing matches or when the query fails.
        """
        self.update()
        start, end = parse_time(time, time_components)

        # A bare string would otherwise be iterated character by character.
        if isinstance(collection, str):
            collection = [collection]

        # The dataset ids are passed as an "expanding" bound parameter, so
        # they are never interpolated into the SQL text. The table name
        # cannot be bound; it is derived from the project name in __init__.
        query_ = text(
            f"SELECT * FROM {self.table_name} "  # noqa: S608
            "WHERE ds_id IN :ds_ids "
            "AND end_time >= :start AND start_time <= :end"
        ).bindparams(sqlalchemy.bindparam("ds_ids", expanding=True))

        with get_session() as session:
            try:
                result = session.execute(
                    query_,
                    {"ds_ids": list(collection), "start": start, "end": end},
                ).fetchall()
            except Exception:
                # Best effort: a failing query yields no records.
                result = []

        # Group the file paths by dataset id.
        records = {}
        for row in result:
            records.setdefault(row.ds_id, []).append(row.path)
        return records