@@ -20,37 +20,38 @@ use crate::{
2020 class_factory:: ClassFactory , rd_pipe_plugin:: RdPipePlugin , registry:: CLSID_RD_PIPE_PLUGIN ,
2121} ;
2222use rd_pipe_plugin:: REG_PATH ;
23- #[ cfg( target_arch = "x86" ) ]
24- use registry:: { ctx_add_to_registry, ctx_delete_from_registry} ;
2523use registry:: {
26- delete_from_registry , inproc_server_add_to_registry , msts_add_to_registry , COM_CLS_FOLDER ,
27- TS_ADD_INS_FOLDER , TS_ADD_IN_RD_PIPE_FOLDER_NAME ,
24+ COM_CLS_FOLDER , TS_ADD_IN_RD_PIPE_FOLDER_NAME , TS_ADD_INS_FOLDER , delete_from_registry ,
25+ inproc_server_add_to_registry , msts_add_to_registry ,
2826} ;
27+ #[ cfg( target_arch = "x86" ) ]
28+ use registry:: { ctx_add_to_registry, ctx_delete_from_registry} ;
2929use std:: { ffi:: c_void, io, mem:: transmute, panic, str:: FromStr } ;
3030use tokio:: runtime:: Runtime ;
3131use tracing:: { debug, error, instrument, trace} ;
3232use windows:: {
33- core:: { Interface , PCWSTR } ,
3433 Win32 :: {
35- Foundation :: { ERROR_INVALID_FUNCTION , ERROR_INVALID_PARAMETER , HMODULE , WIN32_ERROR } ,
36- System :: LibraryLoader :: GetModuleFileNameW ,
37- } ,
38- } ;
39- use windows:: {
40- core:: { GUID , HRESULT } ,
41- Win32 :: {
42- Foundation :: { BOOL , CLASS_E_CLASSNOTAVAILABLE , E_UNEXPECTED , S_OK } ,
34+ Foundation :: { CLASS_E_CLASSNOTAVAILABLE , E_UNEXPECTED , S_OK } ,
4335 System :: {
4436 Com :: IClassFactory ,
4537 LibraryLoader :: DisableThreadLibraryCalls ,
4638 RemoteDesktop :: IWTSPlugin ,
4739 SystemServices :: { DLL_PROCESS_ATTACH , DLL_PROCESS_DETACH } ,
4840 } ,
4941 } ,
42+ core:: { GUID , HRESULT } ,
5043} ;
44+ use windows:: {
45+ Win32 :: {
46+ Foundation :: { ERROR_INVALID_PARAMETER , HMODULE , WIN32_ERROR } ,
47+ System :: LibraryLoader :: GetModuleFileNameW ,
48+ } ,
49+ core:: { Interface , PCWSTR } ,
50+ } ;
51+ use windows_core:: BOOL ;
5152use winreg:: {
53+ HKEY , RegKey ,
5254 enums:: { HKEY_CURRENT_USER , HKEY_LOCAL_MACHINE } ,
53- RegKey , HKEY ,
5455} ;
5556
5657lazy_static:: lazy_static! {
@@ -70,7 +71,7 @@ fn get_log_level_from_registry(parent_key: HKEY) -> io::Result<u32> {
7071
7172static mut INSTANCE : Option < HMODULE > = None ;
7273
73- #[ no_mangle]
74+ #[ unsafe ( no_mangle) ]
7475pub extern "stdcall" fn DllMain ( hinst : HMODULE , reason : u32 , _reserved : * mut c_void ) -> BOOL {
7576 match reason {
7677 DLL_PROCESS_ATTACH => {
@@ -112,7 +113,7 @@ pub extern "stdcall" fn DllMain(hinst: HMODULE, reason: u32, _reserved: *mut c_v
112113 true . into ( )
113114}
114115
115- #[ no_mangle]
116+ #[ unsafe ( no_mangle) ]
116117#[ instrument]
117118pub extern "stdcall" fn DllGetClassObject (
118119 rclsid : * const GUID ,
@@ -143,7 +144,7 @@ pub extern "stdcall" fn DllGetClassObject(
143144 S_OK
144145}
145146
146- #[ no_mangle]
147+ #[ unsafe ( no_mangle) ]
147148#[ instrument]
148149pub extern "stdcall" fn VirtualChannelGetInstance (
149150 riid : * const GUID ,
@@ -160,9 +161,13 @@ pub extern "stdcall" fn VirtualChannelGetInstance(
160161 return E_UNEXPECTED ;
161162 }
162163 let pnumobjs = unsafe { & mut * pnumobjs } ;
163- trace ! ( "Checking whether result pointer is null (i.e. whether this call is a query for number of plugins or a query for the plugins itself)" ) ;
164+ trace ! (
165+ "Checking whether result pointer is null (i.e. whether this call is a query for number of plugins or a query for the plugins itself)"
166+ ) ;
164167 if ppo. is_null ( ) {
165- debug ! ( "Result pointer is null, client is querying for number of objects. Setting pnumobjs to 1, since we only support one plugin" ) ;
168+ debug ! (
169+ "Result pointer is null, client is querying for number of objects. Setting pnumobjs to 1, since we only support one plugin"
170+ ) ;
166171 * pnumobjs = 1 ;
167172 } else {
168173 debug ! ( "{} plugins requested" , * pnumobjs) ;
@@ -172,9 +177,9 @@ pub extern "stdcall" fn VirtualChannelGetInstance(
172177 }
173178 let ppo = unsafe { & mut * ppo } ;
174179 trace ! ( "Constructing the plugin" ) ;
175- let plugin: IWTSPlugin = RdPipePlugin :: new ( ) . into ( ) ;
180+ let plugin = RdPipePlugin :: new ( ) ;
176181 trace ! ( "Setting result pointer to plugin" ) ;
177- * ppo = unsafe { transmute ( plugin) } ;
182+ * ppo = unsafe { transmute ( & plugin) } ;
178183 }
179184 S_OK
180185}
@@ -184,7 +189,7 @@ const CMD_MSTS: char = 'r'; // Registers/unregisters RDP/MSTS support
184189const CMD_CITRIX : char = 'x' ; // Registers/unregisters Citrix support
185190const CMD_LOCAL_MACHINE : char = 'm' ; // If omitted, registers to HKEY_CURRENT_USER
186191
187- #[ no_mangle]
192+ #[ unsafe ( no_mangle) ]
188193#[ instrument]
189194pub extern "stdcall" fn DllInstall ( install : bool , cmd_line : PCWSTR ) -> HRESULT {
190195 debug ! ( "DllInstall called" ) ;
@@ -224,37 +229,29 @@ pub extern "stdcall" fn DllInstall(install: bool, cmd_line: PCWSTR) -> HRESULT {
224229 error ! ( "No channel names provided" ) ;
225230 return ERROR_INVALID_PARAMETER . into ( ) ;
226231 }
227- match unsafe { INSTANCE } {
228- Some ( h) => {
229- let mut file_name = [ 0u16 ; 256 ] ;
230- let path_string: String ;
231- match unsafe { GetModuleFileNameW ( h, file_name. as_mut ( ) ) } > 0 {
232- true => {
233- path_string = String :: from_utf16_lossy ( & file_name) ;
234- }
235- false => {
236- let e = windows:: core:: Error :: from_win32 ( ) ;
237- error ! ( "Error calling GetModuleFileNameW: {}" , e) ;
238- return e. into ( ) ;
239- }
240- }
241- if let Err ( e) = inproc_server_add_to_registry (
242- scope_hkey,
243- & COM_CLS_FOLDER ,
244- & path_string,
245- & arguments[ 1 ..] ,
246- ) {
247- let e: windows:: core:: Error =
248- WIN32_ERROR ( e. raw_os_error ( ) . unwrap ( ) as u32 ) . into ( ) ;
249- error ! ( "Error calling inproc_server_add_to_registry: {}" , e) ;
250- return e. into ( ) ;
251- }
232+ let mut file_name = [ 0u16 ; 256 ] ;
233+ let path_string: String ;
234+ match unsafe { GetModuleFileNameW ( INSTANCE , file_name. as_mut ( ) ) } > 0 {
235+ true => {
236+ path_string = String :: from_utf16_lossy ( & file_name) ;
252237 }
253- None => {
254- error ! ( "No hinstance to calculate dll path" ) ;
255- return ERROR_INVALID_FUNCTION . into ( ) ;
238+ false => {
239+ let e = windows:: core:: Error :: from_win32 ( ) ;
240+ error ! ( "Error calling GetModuleFileNameW: {}" , e) ;
241+ return e. into ( ) ;
256242 }
257243 }
244+ if let Err ( e) = inproc_server_add_to_registry (
245+ scope_hkey,
246+ & COM_CLS_FOLDER ,
247+ & path_string,
248+ & arguments[ 1 ..] ,
249+ ) {
250+ let e: windows:: core:: Error =
251+ WIN32_ERROR ( e. raw_os_error ( ) . unwrap ( ) as u32 ) . into ( ) ;
252+ error ! ( "Error calling inproc_server_add_to_registry: {}" , e) ;
253+ return e. into ( ) ;
254+ }
258255 }
259256 if commands. contains ( CMD_MSTS ) {
260257 if let Err ( e) = msts_add_to_registry ( scope_hkey) {
0 commit comments