Skip to content

Commit 2bbdf4b

Browse files
committed
MySQLPersistentKeyValueCache:
* Switch from MySQLdb to pymysql * Add support for additional connection arguments * Use autocommit and remove option of using deferred commits as it's the only way to guarantee that no stale data is read due to transactions going on too long * Handle duplicate key upon insertion due to race condition by providing a more informative Exception
1 parent 1f17c46 commit 2bbdf4b

File tree

2 files changed

+78
-20
lines changed

2 files changed

+78
-20
lines changed

CHANGELOG.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,17 @@
77
* Dropped support for Python versions below 3.10
88
* Dropped support for TensorFlow (removing `sensai.tensorflow`)
99

10+
### Improvements/Changes
11+
12+
* `util`:
13+
* `util.cache`:
14+
* `cache_mysql.MySQLPersistentKeyValueCache`:
15+
* Switch from MySQLdb to pymysql
16+
* Add support for additional connection arguments
17+
* Use autocommit and remove option of using deferred commits as it's the only way to guarantee
18+
that no stale data is read due to transactions going on too long
19+
* Handle duplicate key upon insertion due to a race condition by providing a more informative exception message
20+
1021

1122
## 1.4.0 (2025-01-21)
1223

src/sensai/util/cache_mysql.py

Lines changed: 67 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,32 +9,74 @@
99

1010

1111
class MySQLPersistentKeyValueCache(PersistentKeyValueCache):
12+
"""
13+
Can cache arbitrary values in a MySQL database.
14+
The keys are always strings at the database level, i.e. if a key is not a string, it is converted to a string using str().
15+
"""
1216

1317
class ValueType(enum.Enum):
14-
DOUBLE = ("DOUBLE", False) # (SQL data type, isCachedValuePickled)
18+
"""
19+
The value type to use within the MySQL database.
20+
Note that the binary BLOB types can be used for all Python types that can be pickled, so the lack
21+
of specific types (e.g. for strings) is not a problem.
22+
"""
23+
# enum values are (SQL data type, isCachedValuePickled)
24+
DOUBLE = ("DOUBLE", False)
1525
BLOB = ("BLOB", True)
26+
"""
27+
for Python data types whose pickled representation is up to 64 KB
28+
"""
29+
MEDIUMBLOB = ("MEDIUMBLOB", True)
30+
"""
31+
for Python data types whose pickled representation is up to 16 MB
32+
"""
1633

17-
def __init__(self, host, db, user, pw, value_type: ValueType, table_name="cache", deferred_commit_delay_secs=1.0, in_memory=False):
18-
import MySQLdb
19-
self.conn = MySQLdb.connect(host=host, database=db, user=user, password=pw)
34+
def __init__(self, host: str, db: str, user: str, pw: str, value_type: ValueType, table_name="cache",
35+
connect_params: dict | None = None, in_memory=False, max_key_length: int = 255, port=3306):
36+
"""
37+
:param host:
38+
:param db:
39+
:param user:
40+
:param pw:
41+
:param value_type: the type of value to store in the cache
42+
:param table_name:
43+
:param connect_params: additional parameters to pass to the pymysql.connect() function (e.g. ssl, etc.)
44+
:param in_memory:
45+
:param max_key_length: maximal length of the cache key string (keys are always strings) stored in the DB
46+
(i.e. the MySQL type is VARCHAR[max_key_length])
47+
:param port: the MySQL server port to connect to
48+
"""
49+
import pymysql
50+
if connect_params is None:
51+
connect_params = {}
52+
self._connect = lambda: pymysql.connect(host=host, database=db, user=user, password=pw, port=port, autocommit=True,
53+
**connect_params)
54+
self._conn = self._connect()
2055
self.table_name = table_name
21-
self.max_key_length = 255
22-
self._update_hook = DelayedUpdateHook(self._commit, deferred_commit_delay_secs)
23-
self._num_entries_to_be_committed = 0
56+
self.max_key_length = max_key_length
2457

2558
cache_value_sql_type, self.is_cache_value_pickled = value_type.value
2659

27-
cursor = self.conn.cursor()
60+
cursor = self._conn.cursor()
2861
cursor.execute(f"SHOW TABLES;")
2962
if table_name not in [r[0] for r in cursor.fetchall()]:
63+
log.debug(f"Creating table {table_name}")
3064
cursor.execute(f"CREATE TABLE {table_name} (cache_key VARCHAR({self.max_key_length}) PRIMARY KEY, "
3165
f"cache_value {cache_value_sql_type});")
3266
cursor.close()
3367

3468
self._in_memory_df = None if not in_memory else self._load_table_to_data_frame()
3569

70+
def _cursor(self):
71+
try:
72+
self._conn.ping(reconnect=True)
73+
except Exception as e:
74+
log.error(f"Error while pinging MySQL server: {e}; Reconnecting ...")
75+
self._conn = self._connect()
76+
return self._conn.cursor()
77+
3678
def _load_table_to_data_frame(self):
37-
df = pd.read_sql(f"SELECT * FROM {self.table_name};", con=self.conn, index_col="cache_key")
79+
df = pd.read_sql(f"SELECT * FROM {self.table_name};", con=self._conn, index_col="cache_key")
3880
if self.is_cache_value_pickled:
3981
df["cache_value"] = df["cache_value"].apply(pickle.loads)
4082
return df
@@ -43,16 +85,25 @@ def set(self, key, value):
4385
key = str(key)
4486
if len(key) > self.max_key_length:
4587
raise ValueError(f"Key too long, maximal key length is {self.max_key_length}")
46-
cursor = self.conn.cursor()
88+
cursor = self._cursor()
4789
cursor.execute(f"SELECT COUNT(*) FROM {self.table_name} WHERE cache_key=%s", (key,))
4890
stored_value = pickle.dumps(value) if self.is_cache_value_pickled else value
4991
if cursor.fetchone()[0] == 0:
50-
cursor.execute(f"INSERT INTO {self.table_name} (cache_key, cache_value) VALUES (%s, %s)",
51-
(key, stored_value))
92+
from pymysql.err import IntegrityError
93+
try:
94+
cursor.execute(f"INSERT INTO {self.table_name} (cache_key, cache_value) VALUES (%s, %s)",
95+
(key, stored_value))
96+
except IntegrityError as e:
97+
if e.args[0] == 1062: # Duplicate entry
98+
# This can only happen when the user is inserting the same value almost simultaneously (race condition)
99+
args = list(e.args)
100+
args[1] = f"{args[1]}; The duplicate entry is due to quasi-simultaneous insertions for the same key; " \
101+
"Check your application logic!"
102+
raise IntegrityError(*args)
103+
else:
104+
raise
52105
else:
53106
cursor.execute(f"UPDATE {self.table_name} SET cache_value=%s WHERE cache_key=%s", (stored_value, key))
54-
self._num_entries_to_be_committed += 1
55-
self._update_hook.handle_update()
56107
cursor.close()
57108
if self._in_memory_df is not None:
58109
self._in_memory_df["cache_value"][str(key)] = value
@@ -64,9 +115,10 @@ def get(self, key):
64115
return value
65116

66117
def _get_from_table(self, key):
67-
cursor = self.conn.cursor()
118+
cursor = self._cursor()
68119
cursor.execute(f"SELECT cache_value FROM {self.table_name} WHERE cache_key=%s", (str(key),))
69120
row = cursor.fetchone()
121+
cursor.close()
70122
if row is None:
71123
return None
72124
stored_value = row[0]
@@ -81,8 +133,3 @@ def _get_from_in_memory_df(self, key):
81133
except Exception as e:
82134
log.debug(f"Unable to load value for key {str(key)} from in-memory dataframe: {e}")
83135
return None
84-
85-
def _commit(self):
86-
log.info(f"Committing {self._num_entries_to_be_committed} cache entries to the database")
87-
self.conn.commit()
88-
self._num_entries_to_be_committed = 0

0 commit comments

Comments
 (0)