Skip to content

Commit 5abaca0

Browse files
committed
Make check distributed.
1 parent 1da3015 commit 5abaca0

File tree

1 file changed

+47
-10
lines changed

1 file changed

+47
-10
lines changed

snapshot/check_read.py

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77
This file is part of afscgap released under the BSD 3-Clause License. See
88
LICENSE.md.
99
"""
10-
import io
1110
import itertools
1211
import os
1312
import sys
1413
import time
1514
import typing
1615

1716
import boto3 # type: ignore
17+
import coiled # type: ignore
1818
import fastavro # type: ignore
1919

2020
import const
@@ -103,7 +103,7 @@ def attempt_pagination() -> typing.Iterable[str]:
103103
return attempt_pagination()
104104

105105

106-
def check_file(s3_client, bucket: str, path: str, expected_fields: typing.Iterable[str]):
106+
def check_file(s3_client, bucket: str, path: str, expected_fields: typing.Iterable[str]) -> bool:
107107
"""Read a file and ensure it is parsable with expected keys.
108108
109109
Args:
@@ -112,6 +112,21 @@ def check_file(s3_client, bucket: str, path: str, expected_fields: typing.Iterab
112112
path: The path at which the file can be found.
113113
expected_fields: The names of the fields which are required on the downloaded file.
114114
"""
115+
import io
116+
import os
117+
import time
118+
119+
import boto3 # type: ignore
120+
121+
access_key = os.environ['AWS_ACCESS_KEY']
122+
access_secret = os.environ['AWS_ACCESS_SECRET']
123+
124+
s3_client = boto3.client(
125+
's3',
126+
aws_access_key_id=access_key,
127+
aws_secret_access_key=access_secret
128+
)
129+
115130
def attempt_download() -> io.BytesIO:
116131
target_buffer = io.BytesIO()
117132
s3_client.download_fileobj(bucket, path, target_buffer)
@@ -122,13 +137,19 @@ def attempt_download() -> io.BytesIO:
122137
target_buffer = attempt_download()
123138
except:
124139
time.sleep(const.RETRY_DELAY)
125-
target_buffer = attempt_download()
140+
141+
try:
142+
target_buffer = attempt_download()
143+
except:
144+
return False
126145

127146
results: typing.Iterable[dict] = list(fastavro.reader(target_buffer)) # type: ignore
128147
for result in results:
129148
for field in expected_fields:
130149
if field not in result:
131-
raise RuntimeError('Could not find %s.' % field)
150+
return False
151+
152+
return True
132153

133154

134155
def main():
@@ -159,13 +180,29 @@ def main():
159180
)
160181

161182
files = list_files(s3_client, bucket, path)
162-
i = 0
163-
for file in files:
164-
if i % 1000 == 0:
165-
print('Checked %d files.' % i)
166183

167-
check_file(s3_client, bucket, file, fields)
168-
i += 1
184+
cluster = coiled.Cluster(
185+
name='DseProcessAfscgapCheck',
186+
n_workers=10,
187+
worker_vm_types=['m7a.medium'],
188+
scheduler_vm_types=['m7a.medium'],
189+
environ={
190+
'AWS_ACCESS_KEY': os.environ.get('AWS_ACCESS_KEY', ''),
191+
'AWS_ACCESS_SECRET': os.environ.get('AWS_ACCESS_SECRET', ''),
192+
'SOURCE_DATA_LOC': os.environ.get('SOURCE_DATA_LOC', '')
193+
}
194+
)
195+
cluster.adapt(minimum=10, maximum=500)
196+
client = cluster.get_client()
197+
198+
results = client.map(
199+
lambda x: check_file(s3_client, bucket, x, fields),
200+
files
201+
)
202+
results_realized = map(lambda x: x.result(), results)
203+
204+
for result in results_realized:
205+
assert result is True
169206

170207

171208
if __name__ == '__main__':

0 commit comments

Comments (0)