Skip to content

Commit a95d87d

Browse files
authored
Merge pull request #214 from Terradue/develop
Minor fixes to Dask integration with additional test coverage
2 parents 2835be5 + 81e8c82 commit a95d87d

File tree

2 files changed

+91
-9
lines changed

2 files changed

+91
-9
lines changed

calrissian/dask.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,9 @@ def wait_for_completion(self, cm_name: str=None) -> CompletionResult:
436436
raise CalrissianJobException('Unexpected pod container status', status)
437437
elif self.state_is_terminated(last_status.state):
438438
log.info('Handling terminated pod name {} with id {}'.format(pod.metadata.name, pod.metadata.uid))
439-
container = self.get_last_or_none(pod.spec.containers)
439+
container = self.get_container_by_name(pod.spec.containers, 'main-container')
440+
if container is None:
441+
raise CalrissianJobException("Container 'main-container' not found in pod spec", pod)
440442
node_selectors = self._get_pod_node_selector()
441443
self._handle_completion(last_status.state, container, node_selectors)
442444
if self.should_delete_pod():
@@ -471,6 +473,13 @@ def get_last_or_none(container_list: Optional[List[Union[V1ContainerStatus, V1Co
471473
else:
472474
return container_list[-1]
473475

476+
@staticmethod
477+
def get_container_by_name(container_list: Optional[List[Union[V1ContainerStatus, V1Container]]], container_name: str) -> Optional[Union[V1ContainerStatus, V1Container]]:
478+
if not container_list:
479+
return None
480+
481+
return next((c for c in container_list if c.name == container_name), None)
482+
474483
@retry_exponential_if_exception_type((ApiException, HTTPError,), log)
475484
def create_dask_gateway_config_map(self, dask_gateway_url: str, cm_name: str):
476485
gateway = {'gateway': {'address': dask_gateway_url}}

tests/test_dask.py

Lines changed: 81 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
2+
from types import SimpleNamespace
33
import threading
44
from unittest import TestCase
55
from unittest.mock import PropertyMock, create_autospec, patch, call, Mock, mock_open
@@ -20,7 +20,7 @@
2020
from calrissian.k8s import (
2121
CompletionResult,
2222
)
23-
from kubernetes.client.models import V1Pod
23+
from kubernetes.client.models import V1Pod, V1Container, V1ContainerStatus
2424

2525

2626
class ValidateExtensionTestCase(TestCase):
@@ -478,7 +478,7 @@ def test_init(self, mock_get_namespace, mock_client):
478478
self.assertIsNone(kc.pod)
479479
self.assertIsNone(kc.completion_result)
480480

481-
481+
482482
@patch('calrissian.dask.DaskPodMonitor')
483483
def test_submit_pod(self, mock_podmonitor, mock_get_namespace, mock_client):
484484
mock_get_namespace.return_value = 'namespace'
@@ -506,27 +506,51 @@ def setup_mock_watch(self, mock_watch, event_objects=[]):
506506

507507
def make_mock_pod(self, name):
508508
mock_metadata = Mock()
509-
# Cannot mock name attribute without a propertymock
510-
name_property = PropertyMock(return_value=name)
511-
type(mock_metadata).name = name_property
509+
type(mock_metadata).name = PropertyMock(return_value=name)
512510
mock_pod = create_autospec(V1Pod, metadata=mock_metadata)
513511
return mock_pod
514512

513+
def mock_k8s_obj_with_name(self, cls, name: str):
514+
obj = create_autospec(cls, instance=True)
515+
type(obj).name = PropertyMock(return_value=name)
516+
return obj
517+
518+
def make_mock_container(self, name="main-container"):
519+
# Make something that behaves like a k8s container for _extract_cpu_memory_requests
520+
# It expects container.resources.requests to exist.
521+
return SimpleNamespace(
522+
name=name,
523+
resources=SimpleNamespace(
524+
requests={"cpu": "1", "memory": "1Mi"}
525+
)
526+
)
527+
515528

516529
@patch('calrissian.dask.watch', autospec=True)
517530
def test_wait_calls_watch_pod_with_pod_name_field_selector(self, mock_watch, mock_get_namespace, mock_client):
518531
mock_pod = self.make_mock_pod('test123')
532+
533+
mock_pod.status = Mock()
534+
mock_pod.status.init_container_statuses = None
535+
mock_pod.status.container_statuses = [Mock()]
536+
537+
mock_pod.spec = Mock()
538+
mock_pod.spec.containers = [self.make_mock_container("main-container")]
539+
519540
mock_pod.status.container_statuses[0].state = Mock(running=None, waiting=None, terminated=Mock(exit_code=0))
541+
520542
self.setup_mock_watch(mock_watch, [mock_pod])
543+
521544
kc = KubernetesDaskClient()
522545
kc._set_pod(mock_pod)
523546
kc.wait_for_completion(cm_name='dask-cm-random')
524547
mock_stream = mock_watch.Watch.return_value.stream
525548
self.assertEqual(mock_stream.call_args, call(kc.core_api_instance.list_namespaced_pod, kc.namespace,
526549
field_selector='metadata.name=test123'))
527-
550+
551+
528552
@patch('calrissian.dask.watch', autospec=True)
529-
def test_wait_calls_watch_pod_with_imcomplete_status(self, mock_watch, mock_get_namespace, mock_client):
553+
def test_wait_calls_watch_pod_with_incomplete_status(self, mock_watch, mock_get_namespace, mock_client):
530554
self.setup_mock_watch(mock_watch)
531555
mock_pod = self.make_mock_pod('test123')
532556
kc = KubernetesDaskClient()
@@ -572,6 +596,12 @@ def test_wait_finishes_when_pod_state_is_terminated(self, mock_cpu_memory,
572596
mock_podmonitor, mock_watch, mock_get_namespace,
573597
mock_client):
574598
mock_pod = create_autospec(V1Pod)
599+
mock_pod.status = Mock()
600+
mock_pod.status.init_container_statuses = None
601+
mock_pod.status.container_statuses = [Mock()]
602+
603+
mock_pod.spec = Mock()
604+
mock_pod.spec.containers = [self.make_mock_container("main-container")]
575605
mock_pod.status.container_statuses[0].state = Mock(running=None, waiting=None, terminated=Mock(exit_code=123))
576606
mock_cpu_memory.return_value = ('1', '1Mi')
577607
self.setup_mock_watch(mock_watch, [mock_pod])
@@ -584,3 +614,46 @@ def test_wait_finishes_when_pod_state_is_terminated(self, mock_cpu_memory,
584614
self.assertIsNone(kc.pod)
585615
# This is to inspect `with PodMonitor() as monitor`:
586616
self.assertTrue(mock_podmonitor.return_value.__enter__.return_value.remove.called)
617+
618+
619+
def test_get_container_by_name_returns_none_when_list_is_none(self, mock_get_namespace, mock_client):
620+
res = KubernetesDaskClient.get_container_by_name(None, "whatever")
621+
self.assertIsNone(res)
622+
623+
624+
def test_get_container_by_name_returns_none_when_list_is_empty(self, mock_get_namespace, mock_client):
625+
res = KubernetesDaskClient.get_container_by_name([], "whatever")
626+
self.assertIsNone(res)
627+
628+
629+
def test_get_container_by_name_returns_matching_v1container(self, mock_get_namespace, mock_client):
630+
c1 = self.mock_k8s_obj_with_name(V1Container, "a")
631+
c2 = self.mock_k8s_obj_with_name(V1Container, "target")
632+
c3 = self.mock_k8s_obj_with_name(V1Container, "b")
633+
634+
res = KubernetesDaskClient.get_container_by_name([c1, c2, c3], "target")
635+
self.assertIs(res, c2)
636+
637+
638+
def test_get_container_by_name_returns_matching_v1containerstatus(self, mock_get_namespace, mock_client):
639+
s1 = self.mock_k8s_obj_with_name(V1ContainerStatus, "x")
640+
s2 = self.mock_k8s_obj_with_name(V1ContainerStatus, "target")
641+
642+
res = KubernetesDaskClient.get_container_by_name([s1, s2], "target")
643+
self.assertIs(res, s2)
644+
645+
646+
def test_get_container_by_name_returns_none_when_not_found(self, mock_get_namespace, mock_client):
647+
c1 = self.mock_k8s_obj_with_name(V1Container, "a")
648+
c2 = self.mock_k8s_obj_with_name(V1ContainerStatus, "b")
649+
650+
res = KubernetesDaskClient.get_container_by_name([c1, c2], "missing")
651+
self.assertIsNone(res)
652+
653+
654+
def test_get_container_by_name_returns_first_match_if_duplicates(self, mock_get_namespace, mock_client):
655+
first = self.mock_k8s_obj_with_name(V1Container, "dup")
656+
second = self.mock_k8s_obj_with_name(V1ContainerStatus, "dup")
657+
658+
res = KubernetesDaskClient.get_container_by_name([first, second], "dup")
659+
self.assertIs(res, first)

0 commit comments

Comments
 (0)