aboutsummaryrefslogtreecommitdiff
path: root/scripts/common/db.py
blob: bbd220cb323b4b9a9458684d5f40321975795480 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
"""
Database access layer using psycopg2 with connection pooling.

All DB access in the v2 pipeline goes through this module.
"""

import csv
import io
import sys

import psycopg2
import psycopg2.pool
import psycopg2.extras

# Connection settings for the PostgreSQL instance on localhost used by the pool below.
DB_HOST: str = "localhost"
DB_PORT: int = 35434
DB_USER: str = "tidyindex"
DB_PASSWORD: str = "tidyindex"
DB_NAME: str = "tidyindex"

# Module-level pool; stays None until the first call to get_pool() creates it.
_pool = None


def get_pool():
    """Return the module-wide connection pool, building it on first use.

    The pool is a psycopg2 SimpleConnectionPool holding between 1 and 4
    connections to the database described by the DB_* constants. It is
    cached in the module global ``_pool`` so later calls reuse it.

    NOTE: psycopg2 documents SimpleConnectionPool as not shareable across
    threads, and this lazy initialisation is not guarded by a lock.
    """
    global _pool
    if _pool is not None:
        return _pool

    _pool = psycopg2.pool.SimpleConnectionPool(
        minconn=1,
        maxconn=4,
        host=DB_HOST,
        port=DB_PORT,
        user=DB_USER,
        password=DB_PASSWORD,
        dbname=DB_NAME,
    )
    return _pool


def get_conn():
    """Check out a connection from the shared pool (see get_pool()).

    Every use in this module pairs this with put_conn() in a ``finally``
    clause so the connection is always handed back.
    """
    pool = get_pool()
    return pool.getconn()


def put_conn(conn):
    """Hand a connection previously obtained from get_conn() back to the pool."""
    pool = get_pool()
    pool.putconn(conn)


def execute(sql, params=None):
    """Run one SQL statement for its side effects and commit it.

    Any result set is ignored and the function returns None. Connection
    checkout, commit, rollback-on-error and return to the pool are handled
    by execute_transaction().

    Args:
        sql: SQL text, optionally with psycopg2 placeholders.
        params: Values bound to the placeholders, or None.
    """
    def _run_statement(conn):
        with conn.cursor() as cur:
            cur.execute(sql, params)

    execute_transaction(_run_statement)


def execute_scalar(sql, params=None):
    """Run a query and return the first column of its first row, or None.

    None is returned when the query produces no row. The transaction is
    committed on success and rolled back on error via execute_transaction().

    Args:
        sql: SQL text, optionally with psycopg2 placeholders.
        params: Values bound to the placeholders, or None.
    """
    def _first_value(conn):
        with conn.cursor() as cur:
            cur.execute(sql, params)
            first_row = cur.fetchone()
        # Truthiness (not `is None`) matches the original contract: an empty
        # row also yields None.
        return first_row[0] if first_row else None

    return execute_transaction(_first_value)


def execute_all(sql, params=None):
    """Run a query and return every result row as a list of tuples.

    The transaction is committed on success and rolled back on error via
    execute_transaction(). psycopg2's fetchall() raises ProgrammingError for
    statements that produce no result set.

    Args:
        sql: SQL text, optionally with psycopg2 placeholders.
        params: Values bound to the placeholders, or None.
    """
    def _fetch_every_row(conn):
        with conn.cursor() as cur:
            cur.execute(sql, params)
            return cur.fetchall()

    return execute_transaction(_fetch_every_row)


def execute_transaction(fn):
    """Run fn(conn) inside a transaction. Commits on success, rolls back on error.

    fn receives a connection with autocommit off. fn should use conn.cursor()
    to execute statements. Do NOT commit inside fn — this wrapper handles it.

    Args:
        fn: Callable taking a connection checked out from the pool.

    Returns:
        Whatever fn returns, after the transaction has been committed.

    Raises:
        Any exception raised by fn or by the commit is re-raised after a
        rollback attempt. If that rollback itself fails with a psycopg2.Error
        (for example because the server already dropped the connection), the
        secondary error is suppressed so it cannot replace the original one.
        The connection is returned via put_conn() in every case.
    """
    conn = get_conn()
    try:
        result = fn(conn)
        conn.commit()
        return result
    except Exception:
        try:
            conn.rollback()
        except psycopg2.Error:
            # Best-effort: a broken connection cannot be rolled back, and
            # letting this error escape would hide the real failure. The bare
            # `raise` below still re-raises the original exception.
            pass
        raise
    finally:
        put_conn(conn)


def copy_rows(table, columns, rows):
    """Bulk-load rows into a table with COPY ... FROM STDIN (CSV).

    Args:
        table: Target table name. It is interpolated verbatim into the COPY
            statement, so it must come from trusted code, never user input.
        columns: Column names, also interpolated verbatim; they fix the CSV
            field order.
        rows: Mappings keyed by column name. Keys not listed in *columns*
            are ignored; missing keys are written as empty fields.

    Returns:
        The cursor's rowcount after the COPY, or 0 when *rows* is falsy (in
        which case the database is not touched).

    NOTE(review): None, "" and missing keys all serialise to an empty CSV
    field (Python's csv module quotes it only when it is the sole field of
    the row), and PostgreSQL's CSV COPY reads an unquoted empty field as
    NULL — so empty strings are stored as NULL.
    """
    if not rows:
        return 0

    payload = io.StringIO()
    csv.DictWriter(payload, fieldnames=columns, extrasaction="ignore").writerows(rows)
    payload.seek(0)

    def _copy_payload(conn):
        with conn.cursor() as cur:
            cur.copy_expert(
                f"COPY {table} ({','.join(columns)}) FROM STDIN WITH (FORMAT csv)",
                payload,
            )
            return cur.rowcount

    return execute_transaction(_copy_payload)


# ============================================================
# Legacy compatibility (for old parsers using parse_common.py)
# These shell-based functions are NOT used by v2 code.
# ============================================================

import subprocess

# Name of the Docker container the legacy psql helpers below `docker exec` into.
_DB_CONTAINER: str = "tidyindex-postgres"


def psql(sql):
    """Pipe SQL into psql inside the database container and return its stdout.

    Legacy — use execute() instead. On a non-zero psql exit status the
    captured stderr is printed and the whole process exits with status 1.

    NOTE(review): psql exits 0 after SQL errors in a script unless the
    ON_ERROR_STOP variable is set, so this check mostly catches connection
    or startup failures rather than failing statements.
    """
    command = ["docker", "exec", "-i", _DB_CONTAINER, "psql", "-U", DB_USER, "-d", DB_NAME]
    completed = subprocess.run(command, input=sql, capture_output=True, text=True)
    if completed.returncode == 0:
        return completed.stdout
    print(f"PSQL ERROR: {completed.stderr}", file=sys.stderr)
    sys.exit(1)


def psql_scalar(sql):
    """Run SQL via psql in tuples-only, unaligned mode and return the first value line.

    Legacy — use execute_scalar() instead. Blank lines and lines starting
    with INSERT, UPDATE or DELETE (presumably psql command-status tags such
    as "INSERT 0 1") are skipped. Returns None if nothing else remains. On a
    non-zero psql exit status stderr is printed and the process exits with 1.

    NOTE(review): other status tags (e.g. BEGIN, COMMIT) are not filtered
    and would be returned as the "scalar".
    """
    command = [
        "docker", "exec", "-i", _DB_CONTAINER,
        "psql", "-U", DB_USER, "-d", DB_NAME, "-t", "-A",
    ]
    completed = subprocess.run(command, input=sql, capture_output=True, text=True)
    if completed.returncode != 0:
        print(f"PSQL ERROR: {completed.stderr}", file=sys.stderr)
        sys.exit(1)

    status_prefixes = ("INSERT", "UPDATE", "DELETE")
    stripped = (raw.strip() for raw in completed.stdout.strip().split("\n"))
    return next(
        (text for text in stripped if text and not text.startswith(status_prefixes)),
        None,
    )


def psql_query_values(sql):
    """Run a query through psql() and return its data lines as stripped strings.

    Legacy — use execute_all() instead. Expects psql's default aligned
    output: a header line, a separator line, the data rows, then a footer
    line (presumably "(N rows)"). The first two lines and the last line are
    discarded. Output with fewer than three lines yields [].
    """
    output_lines = psql(sql).strip().split("\n")
    if len(output_lines) < 3:
        return []
    data_lines = output_lines[2:-1]
    return [entry.strip() for entry in data_lines]


def insert_rows(table, columns, rows):
    """Legacy alias kept for old parsers; forwards unchanged to copy_rows()."""
    return copy_rows(table=table, columns=columns, rows=rows)