Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 13 additions & 14 deletions hail/python/hailtop/batch/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from hailtop.aiotools.validators import validate_file
from hailtop.batch.hail_genetics_images import HAIL_GENETICS_IMAGES, hailgenetics_hail_image_for_current_python_version
from hailtop.batch_client.aioclient import BatchClient as AioBatchClient
from hailtop.batch_client.client import BatchClient
from hailtop.batch_client.parse import parse_cpu_in_mcpu
from hailtop.config import ConfigVariable, configuration_of, get_deploy_config, get_remote_tmpdir
from hailtop.utils import async_to_blocking, bounded_gather, parse_docker_image_reference, url_scheme
Expand Down Expand Up @@ -527,39 +526,35 @@ class ServiceBackend(Backend[bc.Batch]):
ANY_REGION: ClassVar[List[str]] = ['any_region']
"""A special value that indicates a job may run in any region."""

@staticmethod
def supported_regions():
def supported_regions(self):
"""
Get the supported cloud regions

Examples
--------
>>> regions = ServiceBackend.supported_regions()
>>> regions = service_backend.supported_regions() # doctest: +SKIP

Returns
-------
A list of the supported cloud regions
"""
with BatchClient('dummy') as dummy_client:
return dummy_client.supported_regions()
return async_to_blocking(self._batch_client_sync().supported_regions())

@staticmethod
def default_region():
def default_region(self):
"""
Get the default cloud region

This value is "us-central1" for the Hail Team maintained Batch instance.

Examples
--------
>>> region = ServiceBackend.default_region()
>>> region = service_backend.default_region() # doctest: +SKIP

Returns
-------
The default region jobs run in when no regions are specified
"""
with BatchClient('dummy') as dummy_client:
return dummy_client.default_region()
return async_to_blocking(self._batch_client_sync().default_region())

def __init__(
self,
Expand All @@ -573,6 +568,9 @@ def __init__(
gcs_requester_pays_configuration: Optional[GCSRequesterPaysConfiguration] = None,
gcs_bucket_allow_list: Optional[List[str]] = None,
):
self.__batch_client: Optional[AioBatchClient] = None
self._token = token

if len(args) > 2:
raise TypeError(f'ServiceBackend() takes 2 positional arguments but {len(args)} were given')
if len(args) >= 1:
Expand All @@ -598,7 +596,6 @@ def __init__(
'MY_BILLING_PROJECT`'
)
self._billing_project = billing_project
self._token = token

self.remote_tmpdir = get_remote_tmpdir('ServiceBackend', bucket=bucket, remote_tmpdir=remote_tmpdir)

Expand Down Expand Up @@ -634,18 +631,20 @@ def __init__(
assert isinstance(regions_from_conf, str)
regions = regions_from_conf.split(',')
else:
regions = [ServiceBackend.default_region()]
regions = [self.default_region()]
elif regions == ServiceBackend.ANY_REGION:
regions = None

self.regions = regions
self.__batch_client: Optional[AioBatchClient] = None

async def _batch_client(self) -> AioBatchClient:
if self.__batch_client is None:
self.__batch_client = await AioBatchClient.create(self._billing_project, _token=self._token)
return self.__batch_client

def _batch_client_sync(self) -> AioBatchClient:
return async_to_blocking(self._batch_client())

@property
def _fs(self) -> RouterAsyncFS:
return self.__fs
Expand Down