#!/usr/bin/env python

# This file is part of Window-Switch.
# Copyright (c) 2009-2013 Antoine Martin <antoine@nagafix.co.uk>
# Window-Switch is released under the terms of the GNU GPL v3

import time

from winswitch.util.simple_logger import Logger
logger = Logger("server_base")
debug_import = logger.get_debug_import()


debug_import("objects.common")
from winswitch.objects.common import ALLOW_FILE_SHARING, ALLOW_PRINTER_SHARING, ModifiedCallbackObject
debug_import("server_command")
from winswitch.objects.server_command import ServerCommand
debug_import("session")
from winswitch.objects.session import Session, do_get_available_session_types
debug_import("crypt_util")
from winswitch.util.crypt_util import recreate_key, generate_key, make_key_fingerprint
debug_import("consts")
from winswitch.consts import TYPE_WORKSTATION, SSH_TYPE, X11_TYPE, WINDOWS_TYPE, OSX_TYPE, LIBVIRT_TYPE, VIRTUALBOX_TYPE, VNC_TYPE, DEFAULT_RDP_PORT, DEFAULT_VNC_PORT, BINASCII
debug_import("common")
from winswitch.util.common import csv_list
debug_import("commands_util")
from winswitch.util.commands_util import SEAMLESSRDP_COMMAND
debug_import("all done")


class ServerBase(ModifiedCallbackObject):
	"""
	Server attributes that are common to ServerConfig and ServerSettings.
	"""

	# Attribute names listed here are presumably the ones written out when the object is persisted.
	# NOTE(review): the entries starting with "#" look like comment/heading lines emitted
	# alongside the values rather than attribute names - confirm against the code that consumes this list.
	PERSIST_COMMON = [
					"# Identity",
					"ID",
					"name", "type",
					"# Require an SSH tunnel to access this server?:",
					"ssh_tunnel",
					"# Protocols that can be tunnelled:",
					"tunnel_fs", "tunnel_sink", "tunnel_source", "tunnel_clone", "tunnel_printer",
					"# Authentication key (public part of the key):",
					"crypto_modulus", "crypto_public_exponent",
					]

	def __init__(self, skip_detection=False):
		"""
		Sets the identity, protocol support flags, ports, tunnelling options,
		permissions and codec/locale lists to their defaults,
		then resets the transient state via init_transients().

		skip_detection is passed on to test_support_protocols():
		when it is true, the supports_* flags are set from it
		and none of the test_supports_* detection methods are called.
		"""
		ModifiedCallbackObject.__init__(self)
		self.ID = ""
		self.name = ""

		self.type = TYPE_WORKSTATION

		self.rdp_seamless_command	= SEAMLESSRDP_COMMAND
		# must run before the tunnel_* defaults below, which read self.supports_sound
		self.test_support_protocols(skip_detection)

		self.rdp_version = ""
		self.rdp_port = DEFAULT_RDP_PORT
		self.winvnc_port = DEFAULT_VNC_PORT

		self.ssh_tunnel = False							#force use of ssh tunnel
		self.tunnel_fs = ALLOW_FILE_SHARING
		# the three sound related tunnels default to whatever sound detection found
		self.tunnel_sink = self.supports_sound
		self.tunnel_source = self.supports_sound
		self.tunnel_clone = self.supports_sound
		self.tunnel_printer = ALLOW_PRINTER_SHARING

		self.clients_can_stop = True
		self.allow_custom_commands = True
		self.allow_file_transfers = True
		self.download_directory = "~/Downloads"

		self.binary_encodings = [BINASCII]
		self.gstaudio_codecs = []
		self.gstvideo_codecs = []

		self.default_locale = ""
		self.locales = []

		self.init_transients()

	def __hash__(self):
		""" The hash of a server is the hash of its ID. """
		return hash(self.ID)

	def __cmp__(self, other):
		"""
		Orders servers by name, then by ID, and finally defers
		to ModifiedCallbackObject.__cmp__ when both are equal.
		An object of a different type always compares as -1.
		"""
		if type(other)!=type(self):
			return -1
		by_name = cmp(self.name, other.name)
		if by_name!=0:
			return by_name
		by_id = cmp(self.ID, other.ID)
		if by_id!=0:
			return by_id
		return ModifiedCallbackObject.__cmp__(self, other)


	def _must_override(self):
		"""
		Called by the placeholder methods that subclasses have to implement: always raises.
		Raises NotImplementedError, which is a subclass of Exception
		so that handlers catching Exception continue to work.
		"""
		raise NotImplementedError("must override in subclass")

	def test_supports_fileopen(self):
		"""
		Default used by test_support_protocols() to set supports_file_open:
		this base implementation always answers False.
		"""
		return	False

	def detect_xpra_encodings(self):
		"""
		Default used by test_support_protocols() to set supports_xpra_encodings:
		this base implementation returns an empty list.
		"""
		return	[]

	def test_support_protocols(self, skip_detection=False):
		"""
		Populates all the supports_* attributes.
		When skip_detection is true, every flag takes that value without calling
		any of the test_supports_* methods and supports_xpra_encodings is left empty.
		Otherwise each flag is the result of the matching test_supports_* method.
		The dependent flags (xpra_desktop, ssh_desktop, rdp_seamless) are only
		tested when their parent protocol (xpra, ssh, rdp) was found to be supported.
		"""
		self.supports_xpra 			= skip_detection or self.test_supports_xpra()
		self.supports_xpra_desktop	= skip_detection or (self.supports_xpra and self.test_supports_xpra_desktop())
		self.supports_xprashadow	= skip_detection or self.test_supports_xprashadow()
		self.supports_xpra_encodings = []
		if not skip_detection:
			self.supports_xpra_encodings = self.detect_xpra_encodings()
		self.supports_nx			= skip_detection or self.test_supports_nx()
		self.supports_ssh			= skip_detection or self.test_supports_ssh()
		self.supports_ssh_desktop	= skip_detection or (self.supports_ssh and self.test_supports_ssh_desktop())
		self.supports_vnc			= skip_detection or self.test_supports_vnc()
		self.supports_vncshadow		= skip_detection or self.test_supports_vncshadow()
		self.supports_rdp			= skip_detection or self.test_supports_rdp()
		self.supports_rdp_seamless	= skip_detection or (self.supports_rdp and self.test_supports_rdp_seamless())
		self.supports_screen		= skip_detection or self.test_supports_screen()
		self.supports_libvirt		= skip_detection or self.test_supports_libvirt()
		self.supports_virtualbox	= skip_detection or self.test_supports_virtualbox()
		self.supports_file_open		= skip_detection or self.test_supports_fileopen()
		self.supports_sound 		= skip_detection or self.test_supports_sound()
		self.supports_gstvideo		= skip_detection or self.test_supports_gstvideo()

	def test_supports_sound(self):
		"""
		Sound needs gstreamer, its tcp plugins and at least one supported audio codec.
		The gstreamer utility module is only imported when this method is called.
		"""
		debug_import("gstreamer_util")
		from winswitch.util.gstreamer_util import has_gst, has_tcp_plugins, supported_gstaudio_codecs
		if not has_gst:
			return	has_gst
		if not has_tcp_plugins:
			return	has_tcp_plugins
		return	len(supported_gstaudio_codecs)>0

	# Detection hooks called by test_support_protocols() to fill in the matching supports_* flags.
	# There is no default implementation here: each one calls _must_override(), which raises,
	# so subclasses must provide their own version
	# (the hooks are only called when detection is not skipped).
	def test_supports_xpra_desktop(self):
		self._must_override()
	def test_supports_xpra(self):
		self._must_override()
	def test_supports_xprashadow(self):
		self._must_override()
	def test_supports_nx(self):
		self._must_override()
	def test_supports_vnc(self):
		self._must_override()
	def test_supports_gstvideo(self):
		self._must_override()
	def test_supports_vncshadow(self):
		self._must_override()
	def test_supports_rdp(self):
		self._must_override()
	def test_supports_rdp_seamless(self):
		self._must_override()
	def test_supports_ssh(self):
		self._must_override()
	def test_supports_ssh_desktop(self):
		self._must_override()
	def test_supports_screen(self):
		self._must_override()
	def test_supports_libvirt(self):
		self._must_override()
	def test_supports_virtualbox(self):
		self._must_override()

	def init_transients(self):
		"""
		Resets the runtime state to empty values: users, sessions, command lists,
		key material and fingerprint, platform information and start time.
		Called at the end of __init__().
		NOTE(review): the name suggests these are the attributes that are not persisted,
		yet crypto_modulus and crypto_public_exponent also appear in PERSIST_COMMON - confirm.
		"""
		self.remote_name = None
		self.users = []
		self.sessions = {}
		self.server_commands = []				# list of commands the server can handle (loaded from .desktop files)
		self.menu_directories = []				# menu categories for the commands above
		self.desktop_commands = []				# commands for starting X-sessions (kde, gnome, ...)
		self.action_commands = []				# main commands shown under "actions"
		# key object, rebuilt on demand by get_key() from the three numbers below
		self.key = None
		self.crypto_modulus = 0l
		self.crypto_public_exponent = 0l
		self.crypto_private_exponent = 0l
		self.ssh_host_public_key = None
		self.key_fingerprint = ""
		self.key_fingerprint_image = None
		self.platform = ""
		self.os_version = ""
		self.start_time = 0

	def check_timeout(self, time_synced):
		"""
		Expires whatever has timed out relative to time_synced:
		* sessions not yet flagged as timed_out whose timedout() check is true are logged,
		* active users whose timedout() check is true are marked inactive,
		* commands (in all four command lists) whose timedout() check is true
		  have remove() called on them and are dropped from their list.
		Calls touch() on self if any of the above happened.
		"""
		self.sdebug(None, time_synced)
		#FIXME: needs locking around these lists to prevent updates whilst we do this
		#if self.last_updated<time_synced:
		#	self.error("(%s) server has not been updated since %s!" % (time_synced, self.last_updated))
		#	return	True
		mod = False
		# timeout the sessions
		for session in self.sessions.values():
			if not session.timed_out and session.timedout(time_synced):
				# the message goes first, like every other slog() call in this class
				self.slog("session %s (%s) has timed out" % (session, session.last_updated), time_synced)
				mod = True
		# timeout the users
		for user in self.users:
			if user.active and user.timedout(time_synced):
				user.active = False
				self.slog("user %s (%s) has timed out" % (user, user.last_updated), time_synced)
				mod = True
		# timeout all the commands/dirs/desktops
		timedout = {}			# list attribute name -> commands removed from that list
		for var in ["server_commands", "menu_directories", "desktop_commands", "action_commands"]:
			kept = []
			for command in getattr(self, var):
				if command.timedout(time_synced):
					command.remove()
					timedout.setdefault(var, []).append(command)
					mod = True
				else:
					kept.append(command)
			# swap in the filtered list rather than removing items during iteration
			setattr(self, var, kept)
		if timedout:
			# wrap the dict in a tuple so that it is formatted as a single value
			self.slog("timed out the following objects: %s" % (timedout,), time_synced)
		if mod:
			self.touch()
	def get_key(self):
		"""
		Returns the key object, building it (and refreshing the fingerprint)
		from the stored modulus and exponents if there is no key yet.
		"""
		if self.key:
			return self.key
		self.key = recreate_key(self.crypto_modulus, self.crypto_public_exponent, self.crypto_private_exponent)
		self.regenerate_key_fingerprint()
		return self.key

	def assign_keys(self):
		"""
		Generates a brand new key, installs it with set_key()
		and records its private exponent.
		"""
		fresh = generate_key()
		self.set_key(fresh)
		self.crypto_private_exponent = fresh.d

	def set_key(self, key):
		"""
		Installs the key, copies its modulus (n) and public exponent (e)
		and recomputes the key fingerprint.
		The private exponent is not touched here.
		"""
		self.key = key
		modulus, public_exponent = self.key.n, self.key.e
		self.crypto_modulus = modulus
		self.crypto_public_exponent = public_exponent
		self.regenerate_key_fingerprint()

	def regenerate_key_fingerprint(self):
		"""
		Recomputes key_fingerprint from the modulus and public exponent,
		stores it and returns it.
		"""
		fingerprint = make_key_fingerprint(self.crypto_modulus, self.crypto_public_exponent)
		self.key_fingerprint = fingerprint
		return	self.key_fingerprint

	def get_command_list_for_type(self, command_type):
		"""
		Maps a ServerCommand type to the list holding the commands of that type.
		Logs an error and returns None if the type is not one of
		COMMAND, CATEGORY, DESKTOP or ACTION.
		"""
		lists_by_type = (
			(ServerCommand.COMMAND,		self.server_commands),
			(ServerCommand.CATEGORY,	self.menu_directories),
			(ServerCommand.DESKTOP,		self.desktop_commands),
			(ServerCommand.ACTION,		self.action_commands),
			)
		for known_type, commands in lists_by_type:
			if command_type == known_type:
				return	commands
		self.serror("unknown command type", command_type)
		return	None

	def add_command(self, command):
		"""
		Stores the command in the list matching its type
		(as given by get_command_list_for_type()).
		Returns False if there is no list for this type,
		otherwise the result of _do_add_command().
		"""
		target = self.get_command_list_for_type(command.type)
		#self.debug("(%s) list(%s)=%s" % (command, command.type, str(command_list)))
		if target is not None:
			return self._do_add_command(target, command)
		return	False

	def _do_add_command(self, command_list, command):
		"""
		Appends the command to command_list unless a command with the same uuid is already there.
		If one is, the first such command is update()d from the new one and False is returned.
		Otherwise the command is appended, touch() is called on self and True is returned.
		"""
		same_uuid = [known for known in command_list if known.uuid == command.uuid]
		if same_uuid:
			#FIXME: copy attributes to update the command and return True if any are modified
			same_uuid[0].update(command)
			return False
		#self.debug("(%s,%s) adding it" % (command_list, command))
		command_list.append(command)
		self.touch()
		return True

	def get_command_by_uuid(self, uuid):
		"""
		Looks for a command with this uuid in the server commands, menu directories,
		desktop commands and action commands (in that order).
		Returns the first match, or None.
		"""
		all_lists = (self.server_commands, self.menu_directories, self.desktop_commands, self.action_commands)
		for commands in all_lists:
			found = [candidate for candidate in commands if candidate.uuid == uuid]
			if found:
				return	found[0]
		return	None

	def get_server_command_by_command(self, cmd, substring=False):
		"""
		Finds the server command whose command line is exactly cmd.
		If there is none and substring is set, falls back to the last
		server command whose command line starts with cmd.
		Returns None when nothing matches.
		"""
		prefix_match = None
		for candidate in self.server_commands:
			if candidate.command == cmd:
				# an exact match always wins, even over an earlier prefix match
				return	candidate
			if not substring:
				continue
			if candidate.command.startswith(cmd):
				prefix_match = candidate
		return	prefix_match

	def add_user(self, user):
		"""
		Adds the user, or refreshes the existing user with the same uuid:
		the existing one is update()d from the new one, marked active and touch()ed.
		Returns True if the user was added, False if an existing one was refreshed.
		"""
		for known in self.users:
			if known.uuid == user.uuid:
				break
		else:
			# not found: this is a new user
			self.users.append(user)
			user.last_updated = time.time()
			self.touch()
			return True
		known.update(user)
		known.active = True
		known.touch()
		return False


	def remove_user_by_uuid(self, uuid):
		"""
		Removes the user with this uuid (if there is one) and calls touch() on self.
		"""
		found = self.get_user_by_uuid(uuid)
		if not found:
			return
		self.users.remove(found)
		self.touch()

	def get_user_by_uuid(self, uuid):
		"""
		Returns the first user with this uuid, or None.
		"""
		found = [user for user in self.users if user.uuid == uuid]
		if found:
			return found[0]
		return None

	def get_users(self):
		"""
		Returns the list of users itself (not a copy).
		"""
		return self.users

	def get_active_users(self):
		"""
		Returns a new list with just the users whose active flag is set.
		"""
		return [user for user in self.users if user.active]

	def get_connected_sessions(self, actor, ignore=[X11_TYPE, WINDOWS_TYPE, OSX_TYPE], or_connecting=False):
		"""
		Returns the sessions the actor is connected to,
		leaving out those whose session type is in the ignore list.
		With or_connecting set, sessions the actor is still connecting to are included too.
		"""
		connected = []
		for session in self.sessions.values():
			if session.session_type in ignore:
				continue
			if session.is_connected_to(actor):
				connected.append(session)
			elif or_connecting and session.is_connecting_to(actor):
				connected.append(session)
		self.sdebug("=%s" % csv_list(connected), actor, csv_list(ignore), or_connecting)
		return	connected

	def add_session(self, session):
		"""
		Registers the session under its ID and returns the session object now on record:
		* the very same object if it was registered already,
		* the existing session (update()d from the new one) if one exists with this ID and is not closed,
		* otherwise the new session, after calling touch() on it and on self.
		The session ID must not be empty.
		"""
		assert len(session.ID)>0
		known = self.get_session(session.ID)
		if known is session:
			#no need to copy anything! this is the same object!
			return known
		replace = not known or known.status == Session.STATUS_CLOSED
		if not replace:
			known.update(session)
			return	known
		session.touch()
		self.sessions[session.ID] = session
		self.touch()
		return	session

	def remove_session_by_ID(self, ID):
		"""
		Marks the session with this ID as closed, forgets it and calls touch() on self.
		Returns the session, or whatever the lookup gave if there was none to remove.
		"""
		found = self.sessions.get(ID)
		if not found:
			return found
		found.set_status(Session.STATUS_CLOSED)
		del self.sessions[ID]
		self.touch()
		return found

	def remove_session_by_display(self, display):
		"""
		Removes the first session found on this display using remove_session_by_ID()
		and returns its result. Returns None if no session uses this display.
		"""
		matching_ids = [session.ID for session in self.sessions.values() if session.display == display]
		if not matching_ids:
			return	None
		return	self.remove_session_by_ID(matching_ids[0])

	def get_session_by_display(self, display):
		"""
		Returns the first session found on this display, or None.
		"""
		on_display = [session for session in self.sessions.values() if session.display == display]
		if on_display:
			return	on_display[0]
		return	None

	def get_all_sessions_for_display(self, display, shadows=None):
		""" Lists the live sessions (preloaded ones excluded, SSH sessions always ignored)
			that are attached to the display supplied.
			shadows=None (default): match on session.display or session.shadowed_display
			shadows=False: match on session.display only
			shadows=True: match on session.shadowed_display only
		"""
		use_display = shadows is not True
		use_shadowed = shadows is not False
		found = []
		for session in self.get_live_sessions(True, ignore=[SSH_TYPE]):
			if use_display and session.display==display:
				found.append(session)
			elif use_shadowed and session.shadowed_display and session.shadowed_display==display:
				found.append(session)
		return found

	def remove_session(self, session):
		"""
		Removes the session using its ID, see remove_session_by_ID().
		Unlike remove_session_by_ID(), this does not return anything.
		"""
		self.remove_session_by_ID(session.ID)

	def get_sessions(self):
		"""
		Returns the dictionary of sessions itself (not a copy), keyed by session ID.
		"""
		return self.sessions

	def get_live_sessions(self, allow_shadow=True, allow_preload=False, ignore=[X11_TYPE, WINDOWS_TYPE]):
		"""
		Returns the sessions that are neither closed nor timed out,
		leaving out those whose session type is in the ignore list.
		Sessions shadowing another display are left out unless allow_shadow is set,
		preloaded sessions are left out unless allow_preload is set.
		"""
		live = []
		for session in self.sessions.values():
			if session.session_type in ignore:
				continue
			if not allow_shadow and session.shadowed_display and len(session.shadowed_display)>0:
				continue
			if session.preload and not allow_preload:
				continue
			if session.status in [Session.STATUS_CLOSED]:
				continue
			if session.timed_out:
				continue
			live.append(session)
		return	live

	def get_sessions_by_status(self, status, allow_preload=False):
		"""
		Returns the sessions whose status equals the one given.
		Preloaded sessions are only included if allow_preload is set.
		"""
		return	[session for session in self.sessions.values()
				if (allow_preload or not session.preload) and session.status==status]

	def get_sessions_by_type(self, session_type):
		"""
		Returns the sessions of the given session type.
		"""
		of_type = []
		for session in self.sessions.values():
			if session.session_type == session_type:
				of_type.append(session)
		return	of_type

	def get_session(self, ID):
		"""
		Returns the session registered under this ID, or None.
		"""
		return	self.sessions.get(ID)

	def get_available_session_types(self, desktop_only=False, hide_suboptimal=False):
		"""
		Hands the supports_* flags of this server to do_get_available_session_types(),
		together with desktop_only and hide_suboptimal, and returns its result.
		"""
		# positional order expected by do_get_available_session_types()
		support_flags = (
			self.supports_xpra, self.supports_xpra_desktop,
			self.supports_nx,
			self.supports_ssh,
			self.supports_ssh_desktop,
			self.supports_vnc,
			self.supports_libvirt,
			self.supports_virtualbox,
			self.supports_rdp_seamless, self.supports_rdp,
			)
		options = (desktop_only, hide_suboptimal)
		return	do_get_available_session_types(*(support_flags + options))

	def get_filtered_session_types(self, from_list, desktop_only, hide_suboptimal):
		"""
		Works out which of the session types in "from_list" can be started on this server.
		libvirt and virtualbox are never included, and one VNC entry is dropped
		when the platform string starts with "win".
		The result is in the order returned by get_available_session_types().
		"""
		never_started = [LIBVIRT_TYPE, VIRTUALBOX_TYPE]
		candidates = []
		for session_type in from_list:
			if session_type not in never_started:
				candidates.append(session_type)
		if self.platform.startswith("win") and VNC_TYPE in candidates:
			""" win32 can't start a VNC session (only shadow) """
			candidates.remove(VNC_TYPE)
		available = self.get_available_session_types(desktop_only, hide_suboptimal)
		#self.sdebug("=%s from list=%s, server_session_types=%s" % (session_types, _list, server_session_types), from_list, desktop_only, hide_suboptimal)
		return	[session_type for session_type in available if session_type in candidates]
