From 5bb5f0c858549d27b3cbb6385c2d89edb6aeba0d Mon Sep 17 00:00:00 2001 From: Pablo Curiel Date: Thu, 3 Jun 2021 20:19:19 -0400 Subject: [PATCH] nxdt_host.py: add CLI mode. --- host/nxdt_host.py | 364 +++++++++++++++++++++++++++++----------------- 1 file changed, 228 insertions(+), 136 deletions(-) diff --git a/host/nxdt_host.py b/host/nxdt_host.py index a7e628e..aaaeb97 100644 --- a/host/nxdt_host.py +++ b/host/nxdt_host.py @@ -31,6 +31,7 @@ # Under MacOS, use `brew install libusb` to install libusb via Homebrew. # Under Linux, you should be good to go from the start. If not, just use the package manager from your distro to install libusb. +import sys import os import platform import threading @@ -39,6 +40,7 @@ import logging import queue import shutil import time +import datetime import struct import usb.core import usb.util @@ -52,6 +54,8 @@ from tqdm import tqdm import base64 +from argparse import ArgumentParser + # Scaling factors. WINDOWS_SCALING_FACTOR = 96.0 SCALE = 1.0 @@ -61,7 +65,7 @@ WINDOW_WIDTH = 500 WINDOW_HEIGHT = 470 # Application version. -APP_VERSION = '0.2' +APP_VERSION = '0.3' # Copyright year. COPYRIGHT_YEAR = '2021' @@ -114,6 +118,16 @@ USB_STATUS_UNSUPPORTED_ABI_VERSION = 6 USB_STATUS_MALFORMED_CMD = 7 USB_STATUS_HOST_IO_ERROR = 8 +# Script title. +SCRIPT_TITLE = "{} host script v{}".format(USB_DEV_PRODUCT, APP_VERSION) + +# Copyright text. +now = datetime.datetime.now() +cur_year = now.year +COPYRIGHT_TEXT = "Copyright (c) {}".format(COPYRIGHT_YEAR) +if cur_year > int(COPYRIGHT_YEAR): COPYRIGHT_TEXT += "-{}".format(cur_year) +COPYRIGHT_TEXT += ", {}".format(USB_DEV_MANUFACTURER) + # Messages displayed as labels. SERVER_START_MSG = 'Please connect a Nintendo Switch console running {}.'.format(USB_DEV_PRODUCT) SERVER_STOP_MSG = 'Exit {} on your console or disconnect it at any time to stop the server.'.format(USB_DEV_PRODUCT) @@ -289,13 +303,17 @@ class LogQueueHandler(logging.Handler): self.log_queue = log_queue def emit(self, record): - self.log_queue.put(record) + if g_cliMode: + msg = self.format(record) + print(msg) + else: + self.log_queue.put(record) # Reference: https://beenje.github.io/blog/posts/logging-to-a-tkinter-scrolledtext-widget. class LogConsole: - def __init__(self, scrolled_text): + def __init__(self, scrolled_text=None): self.scrolled_text = scrolled_text - self.frame = self.scrolled_text.winfo_toplevel() + self.frame = (self.scrolled_text.winfo_toplevel() if self.scrolled_text else None) # Create a logging handler using a queue. self.log_queue = queue.Queue() @@ -303,17 +321,18 @@ class LogConsole: #formatter = logging.Formatter('[%(asctime)s] -> %(message)s') formatter = logging.Formatter('%(message)s') self.queue_handler.setFormatter(formatter) - g_Logger.addHandler(self.queue_handler) + g_logger.addHandler(self.queue_handler) # Start polling messages from the queue. - self.frame.after(100, self.poll_log_queue) + if self.frame: self.frame.after(100, self.poll_log_queue) def display(self, record): msg = self.queue_handler.format(record) - self.scrolled_text.configure(state='normal') - self.scrolled_text.insert(tk.END, msg + '\n', record.levelname) - self.scrolled_text.configure(state='disabled') - self.scrolled_text.yview(tk.END) + if self.scrolled_text: + self.scrolled_text.configure(state='normal') + self.scrolled_text.insert(tk.END, msg + '\n', record.levelname) + self.scrolled_text.configure(state='disabled') + self.scrolled_text.yview(tk.END) def poll_log_queue(self): # Check every 100 ms if there is a new message in the queue to display. @@ -325,15 +344,13 @@ class LogConsole: else: self.display(record) - self.frame.after(100, self.poll_log_queue) + if self.frame: self.frame.after(100, self.poll_log_queue) # Loosely based on tk.py from tqdm. class ProgressBarWindow: global g_tlb, g_taskbar def __init__(self, bar_format=None, tk_parent=None, window_title='', window_resize=False, window_protocol=None): - if tk_parent is None: raise Exception('`tk_parent` must be provided!') - self.n = 0 self.total = 0 self.divider = 1 @@ -346,30 +363,36 @@ class ProgressBarWindow: self.hwnd = 0 self.tk_parent = tk_parent + self.tk_window = (tk.Toplevel(self.tk_parent) if self.tk_parent else None) + self.withdrawn = False + self.tk_text_var = None + self.tk_n_var = None + self.tk_pbar = None - self.tk_window = tk.Toplevel(self.tk_parent) + self.pbar = None - self.tk_window.withdraw() - self.withdrawn = True - - if window_title: self.tk_window.title(window_title) - self.tk_window.resizable(window_resize, window_resize) - if window_protocol: self.tk_window.protocol('WM_DELETE_WINDOW', window_protocol) - - pbar_frame = ttk.Frame(self.tk_window, padding=5) - pbar_frame.pack() - - self.tk_text_var = tk.StringVar(self.tk_window) - tk_label = ttk.Label(pbar_frame, textvariable=self.tk_text_var, wraplength=600, anchor='center', justify='center') - tk_label.pack() - - self.tk_n_var = tk.DoubleVar(self.tk_window, value=0) - self.tk_pbar = ttk.Progressbar(pbar_frame, variable=self.tk_n_var, length=450) - self.tk_pbar.configure(maximum=100, mode='indeterminate') - self.tk_pbar.pack() + if self.tk_window: + self.tk_window.withdraw() + self.withdrawn = True + + if window_title: self.tk_window.title(window_title) + self.tk_window.resizable(window_resize, window_resize) + if window_protocol: self.tk_window.protocol('WM_DELETE_WINDOW', window_protocol) + + pbar_frame = ttk.Frame(self.tk_window, padding=5) + pbar_frame.pack() + + self.tk_text_var = tk.StringVar(self.tk_window) + tk_label = ttk.Label(pbar_frame, textvariable=self.tk_text_var, wraplength=600, anchor='center', justify='center') + tk_label.pack() + + self.tk_n_var = tk.DoubleVar(self.tk_window, value=0) + self.tk_pbar = ttk.Progressbar(pbar_frame, variable=self.tk_n_var, length=450) + self.tk_pbar.configure(maximum=100, mode='indeterminate') + self.tk_pbar.pack() def __del__(self): - self.tk_parent.after(0, self.tk_window.destroy) + if self.tk_parent: self.tk_parent.after(0, self.tk_window.destroy) def start(self, total, n=0, divider=1, prefix='', unit='B'): if (total <= 0) or (n < 0) or (divider < 1): raise Exception('Invalid arguments!') @@ -381,35 +404,42 @@ class ProgressBarWindow: self.prefix = prefix self.unit = unit - self.tk_pbar.configure(maximum=self.total_div, mode='determinate') - - self.start_time = time.time() + if self.tk_pbar: + self.tk_pbar.configure(maximum=self.total_div, mode='determinate') + self.start_time = time.time() + else: + n_div = (float(self.n) / self.divider) + self.pbar = tqdm(initial=n_div, total=self.total_div, unit=self.unit, dynamic_ncols=True, desc=self.prefix, bar_format=self.bar_format) def update(self, n): cur_n = (self.n + n) if cur_n > self.total: return - cur_n_div = (float(cur_n) / self.divider) - self.elapsed_time = (time.time() - self.start_time) - - msg = tqdm.format_meter(n=cur_n_div, total=self.total_div, elapsed=self.elapsed_time, prefix=self.prefix, unit=self.unit, bar_format=self.bar_format) - - self.tk_text_var.set(msg) - self.tk_n_var.set(cur_n_div) - - if self.withdrawn: - self.tk_window.geometry("+{}+{}".format(self.tk_parent.winfo_x(), self.tk_parent.winfo_y())) - self.tk_window.deiconify() - self.tk_window.grab_set() + if self.tk_window: + cur_n_div = (float(cur_n) / self.divider) + self.elapsed_time = (time.time() - self.start_time) - if g_taskbar: - self.hwnd = int(self.tk_window.wm_frame(), 16) - g_taskbar.ActivateTab(self.hwnd) - g_taskbar.SetProgressState(self.hwnd, g_tlb.TBPF_NORMAL) + msg = tqdm.format_meter(n=cur_n_div, total=self.total_div, elapsed=self.elapsed_time, prefix=self.prefix, unit=self.unit, bar_format=self.bar_format) - self.withdrawn = False - - if g_taskbar: g_taskbar.SetProgressValue(self.hwnd, cur_n, self.total) + self.tk_text_var.set(msg) + self.tk_n_var.set(cur_n_div) + + if self.withdrawn: + self.tk_window.geometry("+{}+{}".format(self.tk_parent.winfo_x(), self.tk_parent.winfo_y())) + self.tk_window.deiconify() + self.tk_window.grab_set() + + if g_taskbar: + self.hwnd = int(self.tk_window.wm_frame(), 16) + g_taskbar.ActivateTab(self.hwnd) + g_taskbar.SetProgressState(self.hwnd, g_tlb.TBPF_NORMAL) + + self.withdrawn = False + + if g_taskbar: g_taskbar.SetProgressValue(self.hwnd, cur_n, self.total) + else: + n_div = (float(n) / self.divider) + self.pbar.update(n_div) self.n = cur_n @@ -423,20 +453,35 @@ class ProgressBarWindow: self.start_time = 0 self.elapsed_time = 0 - if g_taskbar: - g_taskbar.SetProgressState(self.hwnd, g_tlb.TBPF_NOPROGRESS) - g_taskbar.UnregisterTab(self.hwnd) - - self.tk_window.grab_release() - - self.tk_window.withdraw() - self.withdrawn = True - - self.tk_pbar.configure(maximum=100, mode='indeterminate') + if self.tk_window: + if g_taskbar: + g_taskbar.SetProgressState(self.hwnd, g_tlb.TBPF_NOPROGRESS) + g_taskbar.UnregisterTab(self.hwnd) + + self.tk_window.grab_release() + + self.tk_window.withdraw() + self.withdrawn = True + + self.tk_pbar.configure(maximum=100, mode='indeterminate') + else: + self.pbar.close() + self.pbar = None + print() def set_prefix(self, prefix): self.prefix = prefix +def utilsGetPath(path_arg, fallback_path, is_file, create=False): + path = os.path.abspath(os.path.expanduser(os.path.expandvars(path_arg if path_arg else fallback_path))) + + if not is_file and create: os.makedirs(path, exist_ok=True) + + if not os.path.exists(path) or (is_file and os.path.isdir(path)) or (not is_file and os.path.isfile(path)): + raise Exception("Error: '%s' points to an invalid file/directory." % (path)) + + return path + def utilsIsValueAlignedToEndpointPacketSize(value): return bool((value & (g_usbEpMaxPacketSize - 1)) == 0) @@ -473,9 +518,11 @@ def usbGetDeviceEndpoints(): usb_ep_out_lambda = lambda ep: usb.util.endpoint_direction(ep.bEndpointAddress) == usb.util.ENDPOINT_OUT usb_version = 0 + if g_cliMode: g_logger.info('Please connect a Nintendo Switch console running nxdumptool.') + while True: # Check if the user decided to stop the server. - if g_stopEvent.is_set(): + if not g_cliMode and g_stopEvent.is_set(): g_stopEvent.clear() return False @@ -491,7 +538,7 @@ def usbGetDeviceEndpoints(): # Check if the product and manufacturer strings match the ones used by nxdumptool. #if (cur_dev.manufacturer != USB_DEV_MANUFACTURER) or (cur_dev.product != USB_DEV_PRODUCT): if cur_dev.manufacturer != USB_DEV_MANUFACTURER: - g_Logger.error('Invalid manufacturer/product strings! (bus %u, address %u).' % (cur_dev.bus, cur_dev.address)) + g_logger.error('Invalid manufacturer/product strings! (bus %u, address %u).' % (cur_dev.bus, cur_dev.address)) time.sleep(0.1) continue @@ -510,7 +557,7 @@ def usbGetDeviceEndpoints(): g_usbEpOut = usb.util.find_descriptor(intf, custom_match=usb_ep_out_lambda) if (g_usbEpIn is None) or (g_usbEpOut is None): - g_Logger.error('Invalid endpoint addresses! (bus %u, address %u).' % (cur_dev.bus, cur_dev.address)) + g_logger.error('Invalid endpoint addresses! (bus %u, address %u).' % (cur_dev.bus, cur_dev.address)) time.sleep(0.1) continue @@ -520,8 +567,10 @@ def usbGetDeviceEndpoints(): break - g_Logger.debug('Successfully retrieved USB endpoints! (bus %u, address %u).' % (cur_dev.bus, cur_dev.address)) - g_Logger.debug('Max packet size: 0x%X (USB %u.%u).\n' % (g_usbEpMaxPacketSize, usb_version >> 8, (usb_version & 0xFF) >> 4)) + g_logger.debug('Successfully retrieved USB endpoints! (bus %u, address %u).' % (cur_dev.bus, cur_dev.address)) + g_logger.debug('Max packet size: 0x%X (USB %u.%u).\n' % (g_usbEpMaxPacketSize, usb_version >> 8, (usb_version & 0xFF) >> 4)) + + if g_cliMode: g_logger.info('Exit nxdumptool or disconnect your console at any time to close this script.') return True @@ -531,9 +580,9 @@ def usbRead(size, timeout=-1): try: # Convert read data to a bytes object for easier handling. rd = bytes(g_usbEpIn.read(size, timeout)) - except: - traceback.print_exc() - g_Logger.error('USB timeout triggered or console disconnected.') + except usb.core.USBError: + if not g_cliMode: traceback.print_exc() + g_logger.error('\nUSB timeout triggered or console disconnected.') return rd @@ -542,9 +591,9 @@ def usbWrite(data, timeout=-1): try: wr = g_usbEpOut.write(data, timeout) - except: - traceback.print_exc() - g_Logger.error('USB timeout triggered or console disconnected.') + except usb.core.USBError: + if not g_cliMode: traceback.print_exc() + g_logger.error('\nUSB timeout triggered or console disconnected.') return wr @@ -555,18 +604,19 @@ def usbSendStatus(code): def usbHandleStartSession(cmd_block): global g_nxdtVersionMajor, g_nxdtVersionMinor, g_nxdtVersionMicro, g_nxdtAbiVersion, g_nxdtGitCommit - g_Logger.debug('Received StartSession (%02X) command.' % (USB_CMD_START_SESSION)) + if g_cliMode: print() + g_logger.debug('Received StartSession (%02X) command.' % (USB_CMD_START_SESSION)) # Parse command block. (g_nxdtVersionMajor, g_nxdtVersionMinor, g_nxdtVersionMicro, g_nxdtAbiVersion, g_nxdtGitCommit) = struct.unpack_from(' 0: dbg_str += (' | NSP header size: 0x%X' % (nsp_header_size)) dbg_str += '.' - g_Logger.debug(dbg_str) + g_logger.debug(dbg_str) file_type_str = ('file' if (not g_nspTransferMode) else 'NSP file entry') - g_Logger.info('Receiving %s: "%s".' % (file_type_str, filename)) + if g_cliMode and not g_nspTransferMode: g_logger.info('Receiving %s: "%s".' % (file_type_str, filename)) # Perform integrity checks if (not g_nspTransferMode) and file_size and (nsp_header_size >= file_size): - g_Logger.error('NSP header size must be smaller than the full NSP size!\n') + g_logger.error('NSP header size must be smaller than the full NSP size!\n') return USB_STATUS_MALFORMED_CMD if g_nspTransferMode and nsp_header_size: - g_Logger.error('Received non-zero NSP header size during NSP transfer mode!\n') + g_logger.error('Received non-zero NSP header size during NSP transfer mode!\n') return USB_STATUS_MALFORMED_CMD if (not filename_length) or (filename_length > USB_FILE_PROPERTIES_MAX_NAME_LENGTH): - g_Logger.error('Invalid filename length!\n') + g_logger.error('Invalid filename length!\n') return USB_STATUS_MALFORMED_CMD # Enable NSP transfer mode (if needed). @@ -611,7 +662,7 @@ def usbHandleSendFileProperties(cmd_block): g_nspRemainingSize = (file_size - nsp_header_size) g_nspFile = None g_nspFilePath = None - g_Logger.debug('NSP transfer mode enabled!\n') + g_logger.debug('NSP transfer mode enabled!\n') # Perform additional integrity checks and get a file object to work with. if (not g_nspTransferMode) or (g_nspFile is None): @@ -627,14 +678,14 @@ def usbHandleSendFileProperties(cmd_block): # Make sure the output filepath doesn't point to an existing directory. if os.path.exists(fullpath) and (not os.path.isfile(fullpath)): utilsResetNspInfo() - g_Logger.error('Output filepath points to an existing directory! ("%s").\n' % (fullpath)) + g_logger.error('Output filepath points to an existing directory! ("%s").\n' % (fullpath)) return USB_STATUS_HOST_IO_ERROR # Make sure we have enough free space. (total_space, used_space, free_space) = shutil.disk_usage(dirpath) if free_space <= file_size: utilsResetNspInfo() - g_Logger.error('Not enough free space available in output volume!\n') + g_logger.error('Not enough free space available in output volume!\n') return USB_STATUS_HOST_IO_ERROR # Get file object. @@ -667,7 +718,7 @@ def usbHandleSendFileProperties(cmd_block): usbSendStatus(USB_STATUS_SUCCESS) # Start data transfer stage. - g_Logger.debug('Data transfer started. Saving %s to: "%s".' % (file_type_str, fullpath)) + g_logger.debug('Data transfer started. Saving %s to: "%s".' % (file_type_str, fullpath)) offset = 0 blksize = USB_TRANSFER_BLOCK_SIZE @@ -675,11 +726,14 @@ def usbHandleSendFileProperties(cmd_block): # Check if we should use the progress bar window. use_pbar = (((not g_nspTransferMode) and (file_size > USB_TRANSFER_THRESHOLD)) or (g_nspTransferMode and (g_nspSize > USB_TRANSFER_THRESHOLD))) if use_pbar: - idx = filename.rfind(os.path.sep) - prefix_filename = (filename[idx+1:] if (idx >= 0) else filename) - - prefix = ('Current %s: "%s".\n' % (file_type_str, prefix_filename)) - prefix += 'Use your console to cancel the file transfer if you wish to do so.' + if g_cliMode: + prefix = '' + else: + idx = filename.rfind(os.path.sep) + prefix_filename = (filename[idx+1:] if (idx >= 0) else filename) + + prefix = ('Current %s: "%s".\n' % (file_type_str, prefix_filename)) + prefix += 'Use your console to cancel the file transfer if you wish to do so.' if (not g_nspTransferMode) or g_nspRemainingSize == (g_nspSize - g_nspHeaderSize): if not g_nspTransferMode: @@ -722,7 +776,7 @@ def usbHandleSendFileProperties(cmd_block): # Read current chunk. chunk = usbRead(rd_size, USB_TRANSFER_TIMEOUT) if chunk is None: - g_Logger.error('Failed to read 0x%X-byte long data chunk!' % (rd_size)) + g_logger.error('Failed to read 0x%X-byte long data chunk!' % (rd_size)) # Cancel file transfer. cancelTransfer() @@ -736,12 +790,12 @@ def usbHandleSendFileProperties(cmd_block): if chunk_size == USB_CMD_HEADER_SIZE: (magic, cmd_id, cmd_block_size) = struct.unpack_from('<4sII', chunk, 0) if (magic == USB_MAGIC_WORD) and (cmd_id == USB_CMD_CANCEL_FILE_TRANSFER): - g_Logger.debug('\nReceived CancelFileTransfer (%02X) command.' % (USB_CMD_CANCEL_FILE_TRANSFER)) - g_Logger.warning('Transfer cancelled.') - # Cancel file transfer. cancelTransfer() + g_logger.debug('Received CancelFileTransfer (%02X) command.' % (USB_CMD_CANCEL_FILE_TRANSFER)) + g_logger.warning('Transfer cancelled.') + # Let the command handler take care of sending the status response for us. return USB_STATUS_SUCCESS @@ -759,7 +813,7 @@ def usbHandleSendFileProperties(cmd_block): if use_pbar: g_progressBarWindow.update(chunk_size) elapsed_time = round(time.time() - start_time) - g_Logger.debug('File transfer successfully completed in %s!\n' % (tqdm.format_interval(elapsed_time))) + g_logger.debug('File transfer successfully completed in %s!\n' % (tqdm.format_interval(elapsed_time))) # Close file handle (if needed). if not g_nspTransferMode: file.close() @@ -774,19 +828,19 @@ def usbHandleSendNspHeader(cmd_block): nsp_header_size = len(cmd_block) - g_Logger.debug('Received SendNspHeader (%02X) command.' % (USB_CMD_SEND_NSP_HEADER)) + g_logger.debug('Received SendNspHeader (%02X) command.' % (USB_CMD_SEND_NSP_HEADER)) # Integrity checks. if not g_nspTransferMode: - g_Logger.error('Received NSP header out of NSP transfer mode!\n') + g_logger.error('Received NSP header out of NSP transfer mode!\n') return USB_STATUS_MALFORMED_CMD if g_nspRemainingSize: - g_Logger.error('Received NSP header before receiving all NSP data! (missing 0x%X byte[s]).\n' % (g_nspRemainingSize)) + g_logger.error('Received NSP header before receiving all NSP data! (missing 0x%X byte[s]).\n' % (g_nspRemainingSize)) return USB_STATUS_MALFORMED_CMD if nsp_header_size != g_nspHeaderSize: - g_Logger.error('NSP header size mismatch! (0x%X != 0x%X).\n' % (nsp_header_size, g_nspHeaderSize)) + g_logger.error('NSP header size mismatch! (0x%X != 0x%X).\n' % (nsp_header_size, g_nspHeaderSize)) return USB_STATUS_MALFORMED_CMD # Write NSP header. @@ -794,7 +848,7 @@ def usbHandleSendNspHeader(cmd_block): g_nspFile.write(cmd_block) g_nspFile.close() - g_Logger.debug('Successfully wrote 0x%X-byte long NSP header to "%s".\n' % (nsp_header_size, g_nspFilePath)) + g_logger.debug('Successfully wrote 0x%X-byte long NSP header to "%s".\n' % (nsp_header_size, g_nspFilePath)) # Disable NSP transfer mode. utilsResetNspInfo() @@ -802,7 +856,7 @@ def usbHandleSendNspHeader(cmd_block): return USB_STATUS_SUCCESS def usbHandleEndSession(cmd_block): - g_Logger.debug('Received EndSession (%02X) command.' % (USB_CMD_END_SESSION)) + g_logger.debug('Received EndSession (%02X) command.' % (USB_CMD_END_SESSION)) return USB_STATUS_SUCCESS def usbCommandHandler(): @@ -816,13 +870,15 @@ def usbCommandHandler(): # Get device endpoints. if not usbGetDeviceEndpoints(): - # Update UI and return. - uiToggleElements(True) + if not g_cliMode: + # Update UI. + uiToggleElements(True) return - # Update UI. - g_tkCanvas.itemconfigure(g_tkTipMessage, state='normal', text=SERVER_STOP_MSG) - g_tkServerButton.configure(state='disabled') + if not g_cliMode: + # Update UI. + g_tkCanvas.itemconfigure(g_tkTipMessage, state='normal', text=SERVER_STOP_MSG) + g_tkServerButton.configure(state='disabled') # Reset NSP info. utilsResetNspInfo() @@ -831,7 +887,7 @@ def usbCommandHandler(): # Read command header. cmd_header = usbRead(USB_CMD_HEADER_SIZE) if (cmd_header is None) or (len(cmd_header) != USB_CMD_HEADER_SIZE): - g_Logger.error('Failed to read 0x%X-byte long command header!' % (USB_CMD_HEADER_SIZE)) + g_logger.error('Failed to read 0x%X-byte long command header!' % (USB_CMD_HEADER_SIZE)) break # Parse command header. @@ -849,19 +905,19 @@ def usbCommandHandler(): cmd_block = usbRead(rd_size, USB_TRANSFER_TIMEOUT) if (cmd_block is None) or (len(cmd_block) != cmd_block_size): - g_Logger.error('Failed to read 0x%X-byte long command block for command ID %02X!' % (cmd_block_size, cmd_id)) + g_logger.error('Failed to read 0x%X-byte long command block for command ID %02X!' % (cmd_block_size, cmd_id)) break # Verify magic word. if magic != USB_MAGIC_WORD: - g_Logger.error('Received command header with invalid magic word!\n') + g_logger.error('Received command header with invalid magic word!\n') usbSendStatus(USB_STATUS_INVALID_MAGIC_WORD) continue # Get command handler function. cmd_func = cmd_dict.get(cmd_id, None) if cmd_func is None: - g_Logger.error('Received command header with unsupported ID %02X.\n' % (cmd_id)) + g_logger.error('Received command header with unsupported ID %02X.\n' % (cmd_id)) usbSendStatus(USB_STATUS_UNSUPPORTED_CMD) continue @@ -869,7 +925,7 @@ def usbCommandHandler(): if (cmd_id == USB_CMD_START_SESSION and cmd_block_size != USB_CMD_BLOCK_SIZE_START_SESSION) or \ (cmd_id == USB_CMD_SEND_FILE_PROPERTIES and cmd_block_size != USB_CMD_BLOCK_SIZE_SEND_FILE_PROPERTIES) or \ (cmd_id == USB_CMD_SEND_NSP_HEADER and not cmd_block_size): - g_Logger.error('Invalid command block size for command ID %02X! (0x%X).\n' % (cmd_id, cmd_block_size)) + g_logger.error('Invalid command block size for command ID %02X! (0x%X).\n' % (cmd_id, cmd_block_size)) usbSendStatus(USB_STATUS_MALFORMED_COMMAND) continue @@ -879,10 +935,11 @@ def usbCommandHandler(): if (status is None) or (not usbSendStatus(status)) or (cmd_id == USB_CMD_END_SESSION) or (status == USB_STATUS_UNSUPPORTED_ABI_VERSION): break - g_Logger.info('\nStopping server.') + g_logger.info('\nStopping server.') - # Update UI. - uiToggleElements(True) + if not g_cliMode: + # Update UI. + uiToggleElements(True) def uiStopServer(): # Signal the shared stop event. @@ -952,14 +1009,16 @@ def uiScaleMeasure(measure): def uiInitialize(): global SCALE global g_tkRoot, g_tkCanvas, g_tkDirText, g_tkChooseDirButton, g_tkServerButton, g_tkTipMessage, g_tkScrolledTextLog - global g_tlb, g_taskbar, g_progressBarWindow + global g_stopEvent, g_tlb, g_taskbar, g_progressBarWindow + + # Setup thread event. + g_stopEvent = threading.Event() # Enable high DPI scaling under Windows (if possible). dpi_aware = False if g_isWindowsVista: try: import ctypes - dpi_aware = (ctypes.windll.user32.SetProcessDPIAware() == 1) if not dpi_aware: dpi_aware = (ctypes.windll.shcore.SetProcessDpiAwareness(1) == 0) except: @@ -989,7 +1048,7 @@ def uiInitialize(): # Create root Tkinter object. g_tkRoot = tk.Tk() - g_tkRoot.title("{} host app v{}".format(USB_DEV_PRODUCT, APP_VERSION)) + g_tkRoot.title(SCRIPT_TITLE) g_tkRoot.protocol('WM_DELETE_WINDOW', uiHandleExitProtocol) g_tkRoot.resizable(False, False) @@ -1026,7 +1085,7 @@ def uiInitialize(): g_tkCanvas.create_text(uiScaleMeasure(60), uiScaleMeasure(30), text='Output directory:', anchor=tk.CENTER) g_tkDirText = tk.Text(g_tkRoot, height=1, width=45, font=font.nametofont('TkDefaultFont'), wrap='none', state='disabled', bg='#F0F0F0') - uiUpdateDirectoryField(DEFAULT_DIR) + uiUpdateDirectoryField(g_outputDir) g_tkCanvas.create_window(uiScaleMeasure(260), uiScaleMeasure(30), window=g_tkDirText, anchor=tk.CENTER) g_tkChooseDirButton = tk.Button(g_tkRoot, text='Choose', width=10, command=uiChooseDirectory) @@ -1046,7 +1105,7 @@ def uiInitialize(): g_tkScrolledTextLog.tag_config('CRITICAL', foreground='red', underline=1) g_tkCanvas.create_window(uiScaleMeasure(WINDOW_WIDTH / 2), uiScaleMeasure(280), window=g_tkScrolledTextLog, anchor=tk.CENTER) - g_tkCanvas.create_text(uiScaleMeasure(5), uiScaleMeasure(WINDOW_HEIGHT - 10), text="Copyright (c) {}, {}".format(COPYRIGHT_YEAR, USB_DEV_MANUFACTURER), anchor=tk.W) + g_tkCanvas.create_text(uiScaleMeasure(5), uiScaleMeasure(WINDOW_HEIGHT - 10), text=COPYRIGHT_TEXT, anchor=tk.W) # Initialize console logger. console = LogConsole(g_tkScrolledTextLog) @@ -1059,22 +1118,38 @@ def uiInitialize(): g_tkRoot.lift() g_tkRoot.mainloop() +def cliInitialize(): + global g_progressBarWindow + + # Initialize console logger. + console = LogConsole() + + # Initialize progress bar window object. + bar_format = '{percentage:.2f}% |{bar}| {n:.2f}/{total:.2f} [{elapsed}<{remaining}, {rate_fmt}]' + g_progressBarWindow = ProgressBarWindow(bar_format) + + # Print info. + g_logger.info('\n' + SCRIPT_TITLE + '. ' + COPYRIGHT_TEXT + '.') + g_logger.info('Output directory: "' + g_outputDir + '".\n') + + # Start USB command handler directly. + usbCommandHandler() + def main(): - global g_Logger, g_stopEvent, g_osType, g_osVersion, g_isWindows, g_isWindowsVista, g_isWindows7 + global g_cliMode, g_outputDir, g_osType, g_osVersion, g_isWindows, g_isWindowsVista, g_isWindows7, g_logger # Disable warnings. warnings.filterwarnings("ignore") - # Setup logging mechanism. - logging.basicConfig(level=logging.INFO) - g_Logger = logging.getLogger() - if len(g_Logger.handlers): - # Remove stderr output handler from logger. - log_stderr = g_Logger.handlers[0] - g_Logger.removeHandler(log_stderr) + # Parse command line arguments. + parser = ArgumentParser(description=SCRIPT_TITLE + '. ' + COPYRIGHT_TEXT + '.') + parser.add_argument('-c', '--cli', required=False, action='store_true', help='Start the script in CLI mode.') + parser.add_argument('-o', '--outdir', required=False, type=str, metavar='DIR', help='Path to output directory. Defaults to "' + DEFAULT_DIR + '".') + args = parser.parse_args() - # Setup thread event. - g_stopEvent = threading.Event() + # Update global flags. + g_cliMode = args.cli + g_outputDir = utilsGetPath(args.outdir, DEFAULT_DIR, False, True) # Get OS information. g_osType = platform.system() @@ -1091,11 +1166,28 @@ def main(): g_isWindowsVista = (win_ver_major >= 6) g_isWindows7 = (True if (win_ver_major > 6) else (win_ver_major == 6 and win_ver_minor > 0)) - # Initialize UI. - uiInitialize() + # Setup logging mechanism. + logging.basicConfig(level=logging.INFO) + g_logger = logging.getLogger() + if len(g_logger.handlers): + # Remove stderr output handler from logger. + log_stderr = g_logger.handlers[0] + g_logger.removeHandler(log_stderr) + + if g_cliMode: + # Initialize CLI. + cliInitialize() + else: + # Initialize UI. + uiInitialize() if __name__ == "__main__": try: main() except KeyboardInterrupt: - pass + if g_cliMode: + print('\nScript interrupted.') + try: + sys.exit(0) + except SystemExit: + os._exit(0)