11import os
2-
2+ from types import SimpleNamespace
33import threading
44from unittest import TestCase
55from unittest .mock import PropertyMock , create_autospec , patch , call , Mock , mock_open
2020from calrissian .k8s import (
2121 CompletionResult ,
2222)
23- from kubernetes .client .models import V1Pod
23+ from kubernetes .client .models import V1Pod , V1Container , V1ContainerStatus
2424
2525
2626class ValidateExtensionTestCase (TestCase ):
@@ -478,7 +478,7 @@ def test_init(self, mock_get_namespace, mock_client):
478478 self .assertIsNone (kc .pod )
479479 self .assertIsNone (kc .completion_result )
480480
481-
481+
482482 @patch ('calrissian.dask.DaskPodMonitor' )
483483 def test_submit_pod (self , mock_podmonitor , mock_get_namespace , mock_client ):
484484 mock_get_namespace .return_value = 'namespace'
@@ -506,27 +506,51 @@ def setup_mock_watch(self, mock_watch, event_objects=[]):
506506
507507 def make_mock_pod (self , name ):
508508 mock_metadata = Mock ()
509- # Cannot mock name attribute without a propertymock
510- name_property = PropertyMock (return_value = name )
511- type(mock_metadata ).name = name_property
509+ type(mock_metadata ).name = PropertyMock (return_value = name )
512510 mock_pod = create_autospec (V1Pod , metadata = mock_metadata )
513511 return mock_pod
514512
513+ def mock_k8s_obj_with_name (self , cls , name : str ):
514+ obj = create_autospec (cls , instance = True )
515+ type(obj ).name = PropertyMock (return_value = name )
516+ return obj
517+
518+ def make_mock_container (self , name = "main-container" ):
519+ # Make something that behaves like a k8s container for _extract_cpu_memory_requests
520+ # It expects container.resources.requests to exist.
521+ return SimpleNamespace (
522+ name = name ,
523+ resources = SimpleNamespace (
524+ requests = {"cpu" : "1" , "memory" : "1Mi" }
525+ )
526+ )
527+
515528
516529 @patch ('calrissian.dask.watch' , autospec = True )
517530 def test_wait_calls_watch_pod_with_pod_name_field_selector (self , mock_watch , mock_get_namespace , mock_client ):
518531 mock_pod = self .make_mock_pod ('test123' )
532+
533+ mock_pod .status = Mock ()
534+ mock_pod .status .init_container_statuses = None
535+ mock_pod .status .container_statuses = [Mock ()]
536+
537+ mock_pod .spec = Mock ()
538+ mock_pod .spec .containers = [self .make_mock_container ("main-container" )]
539+
519540 mock_pod .status .container_statuses [0 ].state = Mock (running = None , waiting = None , terminated = Mock (exit_code = 0 ))
541+
520542 self .setup_mock_watch (mock_watch , [mock_pod ])
543+
521544 kc = KubernetesDaskClient ()
522545 kc ._set_pod (mock_pod )
523546 kc .wait_for_completion (cm_name = 'dask-cm-random' )
524547 mock_stream = mock_watch .Watch .return_value .stream
525548 self .assertEqual (mock_stream .call_args , call (kc .core_api_instance .list_namespaced_pod , kc .namespace ,
526549 field_selector = 'metadata.name=test123' ))
527-
550+
551+
528552 @patch ('calrissian.dask.watch' , autospec = True )
529- def test_wait_calls_watch_pod_with_imcomplete_status (self , mock_watch , mock_get_namespace , mock_client ):
553+ def test_wait_calls_watch_pod_with_incomplete_status (self , mock_watch , mock_get_namespace , mock_client ):
530554 self .setup_mock_watch (mock_watch )
531555 mock_pod = self .make_mock_pod ('test123' )
532556 kc = KubernetesDaskClient ()
@@ -572,6 +596,12 @@ def test_wait_finishes_when_pod_state_is_terminated(self, mock_cpu_memory,
572596 mock_podmonitor , mock_watch , mock_get_namespace ,
573597 mock_client ):
574598 mock_pod = create_autospec (V1Pod )
599+ mock_pod .status = Mock ()
600+ mock_pod .status .init_container_statuses = None
601+ mock_pod .status .container_statuses = [Mock ()]
602+
603+ mock_pod .spec = Mock ()
604+ mock_pod .spec .containers = [self .make_mock_container ("main-container" )]
575605 mock_pod .status .container_statuses [0 ].state = Mock (running = None , waiting = None , terminated = Mock (exit_code = 123 ))
576606 mock_cpu_memory .return_value = ('1' , '1Mi' )
577607 self .setup_mock_watch (mock_watch , [mock_pod ])
@@ -584,3 +614,46 @@ def test_wait_finishes_when_pod_state_is_terminated(self, mock_cpu_memory,
584614 self .assertIsNone (kc .pod )
585615 # This is to inspect `with PodMonitor() as monitor`:
586616 self .assertTrue (mock_podmonitor .return_value .__enter__ .return_value .remove .called )
617+
618+
619+ def test_get_container_by_name_returns_none_when_list_is_none (self , mock_get_namespace , mock_client ):
620+ res = KubernetesDaskClient .get_container_by_name (None , "whatever" )
621+ self .assertIsNone (res )
622+
623+
624+ def test_get_container_by_name_returns_none_when_list_is_empty (self , mock_get_namespace , mock_client ):
625+ res = KubernetesDaskClient .get_container_by_name ([], "whatever" )
626+ self .assertIsNone (res )
627+
628+
629+ def test_get_container_by_name_returns_matching_v1container (self , mock_get_namespace , mock_client ):
630+ c1 = self .mock_k8s_obj_with_name (V1Container , "a" )
631+ c2 = self .mock_k8s_obj_with_name (V1Container , "target" )
632+ c3 = self .mock_k8s_obj_with_name (V1Container , "b" )
633+
634+ res = KubernetesDaskClient .get_container_by_name ([c1 , c2 , c3 ], "target" )
635+ self .assertIs (res , c2 )
636+
637+
638+ def test_get_container_by_name_returns_matching_v1containerstatus (self , mock_get_namespace , mock_client ):
639+ s1 = self .mock_k8s_obj_with_name (V1ContainerStatus , "x" )
640+ s2 = self .mock_k8s_obj_with_name (V1ContainerStatus , "target" )
641+
642+ res = KubernetesDaskClient .get_container_by_name ([s1 , s2 ], "target" )
643+ self .assertIs (res , s2 )
644+
645+
646+ def test_get_container_by_name_returns_none_when_not_found (self , mock_get_namespace , mock_client ):
647+ c1 = self .mock_k8s_obj_with_name (V1Container , "a" )
648+ c2 = self .mock_k8s_obj_with_name (V1ContainerStatus , "b" )
649+
650+ res = KubernetesDaskClient .get_container_by_name ([c1 , c2 ], "missing" )
651+ self .assertIsNone (res )
652+
653+
654+ def test_get_container_by_name_returns_first_match_if_duplicates (self , mock_get_namespace , mock_client ):
655+ first = self .mock_k8s_obj_with_name (V1Container , "dup" )
656+ second = self .mock_k8s_obj_with_name (V1ContainerStatus , "dup" )
657+
658+ res = KubernetesDaskClient .get_container_by_name ([first , second ], "dup" )
659+ self .assertIs (res , first )
0 commit comments