22# Copyright 2024-2025 Intel Corporation
33# Media Communications Mesh
44
5+ import copy
56import datetime
67import logging
78import os
3839 remove_result_media ,
3940)
4041from pytest_mfd_logging .amber_log_formatter import AmberLogFormatter
42+ from pytest_mfd_config .models .topology import TopologyModel
4143
4244logger = logging .getLogger (__name__ )
4345phase_report_key = pytest .StashKey [Dict [str , pytest .CollectReport ]]()
4446
47+ # Store extra host config fields that aren't in the TopologyModel schema
48+ # Key: host name, Value: dict of extra fields (e.g., dsa_device, dsa_address)
49+ _host_extra_config : Dict [str , Dict [str , Any ]] = {}
50+
51+
52+ # Known extra fields that we allow in host config but TopologyModel doesn't support
53+ HOST_EXTRA_FIELDS = ["dsa_device" , "dsa_address" , "build_path" ]
54+
55+
56+ def _extract_extra_fields (config : dict ) -> dict :
57+ """
58+ Extract extra fields from topology config that aren't supported by TopologyModel.
59+
60+ This allows us to add custom fields like dsa_device to host configs without
61+ breaking the pydantic validation.
62+
63+ Returns a cleaned config dict suitable for TopologyModel.
64+ """
65+ global _host_extra_config
66+ _host_extra_config .clear ()
67+
68+ cleaned = copy .deepcopy (config )
69+
70+ # List of top-level fields to strip (not validated by TopologyModel)
71+ TOP_LEVEL_EXTRA_FIELDS = ["host_mtl_paths" ]
72+
73+ # Remove top-level extra fields that TopologyModel doesn't understand
74+ for field in TOP_LEVEL_EXTRA_FIELDS :
75+ if field in cleaned :
76+ # Store in special key for retrieval
77+ _host_extra_config [f"__toplevel__{ field } " ] = cleaned .pop (field )
78+ logger .debug (f"Extracted top-level config: { field } " )
79+
80+ # Extract extra fields from hosts
81+ for host_cfg in cleaned .get ("hosts" , []):
82+ host_name = host_cfg .get ("name" )
83+ if host_name :
84+ extras = {}
85+ for field in HOST_EXTRA_FIELDS :
86+ if field in host_cfg :
87+ extras [field ] = host_cfg .pop (field )
88+ if extras :
89+ _host_extra_config [host_name ] = extras
90+ logger .debug (f"Extracted extra config for { host_name } : { extras } " )
91+
92+ return cleaned
93+
94+
95+ def get_host_extra_config (host_name : str ) -> Dict [str , Any ]:
96+ """Get extra configuration fields for a host (e.g., dsa_device)."""
97+ return _host_extra_config .get (host_name , {})
98+
99+
100+ def get_toplevel_extra_config (field_name : str ) -> Any :
101+ """Get top-level extra config field (e.g., host_mtl_paths)."""
102+ return _host_extra_config .get (f"__toplevel__{ field_name } " )
103+
104+
105+ def get_host_mtl_paths () -> Dict [str , str ]:
106+ """Get the host_mtl_paths config dictionary."""
107+ return get_toplevel_extra_config ("host_mtl_paths" ) or {}
108+
109+
110+ @pytest .fixture (scope = "session" )
111+ def topology (topology_config : dict ) -> TopologyModel :
112+ """
113+ Create topology model from config file data.
114+
115+ This overrides the default pytest_mfd_config topology fixture to allow
116+ extra fields like dsa_device, dsa_address in host configurations.
117+ """
118+ logger .debug ("Creating Topology model with extra field support." )
119+ cleaned_config = _extract_extra_fields (topology_config )
120+ return TopologyModel (** cleaned_config )
121+
45122
46123def _select_sniff_interface (host , capture_cfg : dict ):
47124 def _pci_device_id (nic ) -> str :
@@ -281,9 +358,16 @@ def dma_port_list(request):
281358
282359
283360@pytest .fixture (scope = "session" )
284- def nic_port_list (hosts : dict , mtl_path ) -> None :
361+ def nic_port_list (hosts : dict , mtl_path , test_config ) -> None :
362+ # Try to get host_mtl_paths from test_config first, then from our extracted config
363+ host_mtl_paths = test_config .get ("host_mtl_paths" , {})
364+ if not host_mtl_paths :
365+ host_mtl_paths = get_host_mtl_paths ()
366+
285367 for host in hosts .values ():
286- nicctl = Nicctl (mtl_path , host )
368+ # Use per-host MTL path if configured, otherwise fall back to default
369+ host_path = host_mtl_paths .get (host .name , mtl_path )
370+ nicctl = Nicctl (host_path , host )
287371 if int (host .network_interfaces [0 ].virtualization .get_current_vfs ()) == 0 :
288372 vfs = nicctl .create_vfs (host .network_interfaces [0 ].pci_address .lspci )
289373 vfs = nicctl .vfio_list (host .network_interfaces [0 ].pci_address .lspci )
@@ -293,7 +377,10 @@ def nic_port_list(hosts: dict, mtl_path) -> None:
293377
294378@pytest .fixture (scope = "function" )
295379def setup_interfaces (hosts , test_config , mtl_path ):
296- interface_setup = InterfaceSetup (hosts , mtl_path )
380+ host_mtl_paths = test_config .get ("host_mtl_paths" , {})
381+ if not host_mtl_paths :
382+ host_mtl_paths = get_host_mtl_paths ()
383+ interface_setup = InterfaceSetup (hosts , mtl_path , host_mtl_paths )
297384 yield interface_setup
298385 interface_setup .cleanup ()
299386
@@ -658,5 +745,6 @@ def fail_test(stage):
658745
659746@pytest .fixture (scope = "session" , autouse = True )
660747def init_ip_address_pools (test_config : dict [Any , Any ]) -> None :
661- session_id = int (test_config ["session_id" ])
662- ip_pools .init (session_id = session_id )
748+ # Support session_id at top level or in metadata, with default of 1
749+ session_id = test_config .get ("session_id" ) or test_config .get ("metadata" , {}).get ("session_id" , 1 )
750+ ip_pools .init (session_id = int (session_id ))
0 commit comments