77This file is part of afscgap released under the BSD 3-Clause License. See
88LICENSE.md.
99"""
10- import io
1110import itertools
1211import os
1312import sys
1413import time
1514import typing
1615
1716import boto3 # type: ignore
17+ import coiled # type: ignore
1818import fastavro # type: ignore
1919
2020import const
@@ -103,7 +103,7 @@ def attempt_pagination() -> typing.Iterable[str]:
103103 return attempt_pagination ()
104104
105105
106- def check_file (s3_client , bucket : str , path : str , expected_fields : typing .Iterable [str ]):
106+ def check_file (s3_client , bucket : str , path : str , expected_fields : typing .Iterable [str ]) -> bool :
107107 """Read a file and ensure it is parsable with expected keys.
108108
109109 Args:
@@ -112,6 +112,21 @@ def check_file(s3_client, bucket: str, path: str, expected_fields: typing.Iterab
112112 path: The path at which the file can be found.
113113 expected_fields: The names of the fields which are required on the downloaded file.
114114 """
115+ import io
116+ import os
117+ import time
118+
119+ import boto3 # type: ignore
120+
121+ access_key = os .environ ['AWS_ACCESS_KEY' ]
122+ access_secret = os .environ ['AWS_ACCESS_SECRET' ]
123+
124+ s3_client = boto3 .client (
125+ 's3' ,
126+ aws_access_key_id = access_key ,
127+ aws_secret_access_key = access_secret
128+ )
129+
115130 def attempt_download () -> io .BytesIO :
116131 target_buffer = io .BytesIO ()
117132 s3_client .download_fileobj (bucket , path , target_buffer )
@@ -122,13 +137,19 @@ def attempt_download() -> io.BytesIO:
122137 target_buffer = attempt_download ()
123138 except :
124139 time .sleep (const .RETRY_DELAY )
125- target_buffer = attempt_download ()
140+
141+ try :
142+ target_buffer = attempt_download ()
143+ except :
144+ return False
126145
127146 results : typing .Iterable [dict ] = list (fastavro .reader (target_buffer )) # type: ignore
128147 for result in results :
129148 for field in expected_fields :
130149 if field not in result :
131- raise RuntimeError ('Could not find %s.' % field )
150+ return False
151+
152+ return True
132153
133154
134155def main ():
@@ -159,13 +180,29 @@ def main():
159180 )
160181
161182 files = list_files (s3_client , bucket , path )
162- i = 0
163- for file in files :
164- if i % 1000 == 0 :
165- print ('Checked %d files.' % i )
166183
167- check_file (s3_client , bucket , file , fields )
168- i += 1
184+ cluster = coiled .Cluster (
185+ name = 'DseProcessAfscgapCheck' ,
186+ n_workers = 10 ,
187+ worker_vm_types = ['m7a.medium' ],
188+ scheduler_vm_types = ['m7a.medium' ],
189+ environ = {
190+ 'AWS_ACCESS_KEY' : os .environ .get ('AWS_ACCESS_KEY' , '' ),
191+ 'AWS_ACCESS_SECRET' : os .environ .get ('AWS_ACCESS_SECRET' , '' ),
192+ 'SOURCE_DATA_LOC' : os .environ .get ('SOURCE_DATA_LOC' , '' )
193+ }
194+ )
195+ cluster .adapt (minimum = 10 , maximum = 500 )
196+ client = cluster .get_client ()
197+
198+ results = client .map (
199+ lambda x : check_file (s3_client , bucket , x , fields ),
200+ files
201+ )
202+ results_realized = map (lambda x : x .result (), results )
203+
204+ for result in results_realized :
205+ assert result is True
169206
170207
171208if __name__ == '__main__' :
0 commit comments