﻿# -*- coding: UTF-8 -*-
from __future__ import print_function, unicode_literals
import io, os, sys, struct, logging
import scx, translator


class Message(object):
	"""One dialogue page from (or destined for) an SC3 message group.

	Instances are plain attribute bags; the functions in this module set:
	  savepoints -- savepoint ids emitted before the message (tuple or list)
	  voice      -- raw bytes of the optional voice expression, or None
	  unk1       -- raw bytes of the always-present expression
	  unk2       -- raw bytes of the optional second expression, or None
	  text       -- message text as produced by input_data.parse_text() or
	                scx.string_to_text()
	"""

	def __repr__(self):
		# Show the decoded text; the other fields are opaque byte blobs.
		readable = scx.text_to_string(self.text)
		return repr(readable)


def parse_messages_group(input_data, pos, limit, expected_text_id):
	"""Parse a run of consecutive message commands starting at `pos`.

	One message is formed by the commands
	{cmd010900 <id>, cmd011101, cmd011102, cmd010B, cmd010C *, cmd010D 0}, i.e.:
	  01 09 00 <savepoint:u16> 01 11 01 01 11 02 01 0B 01 0C <flags:u8>
	  [voice expr if flags & 1] <unk1 expr> [unk2 expr if flags & 2]
	  <text id:u16> 01 0D 00

	Matching continues while a message header starts before `limit`.
	A message lacking the trailing 01 0D 00 ends the run and is not consumed.

	Returns (messages, pos): the parsed Message objects and the offset just
	past the last consumed message (the input `pos` if nothing matched).
	Raises Exception if a text id differs from the running expected id.
	"""
	data = input_data.content
	savepoint_cmd = bytearray((0x01, 0x09, 0x00))
	body_cmds = bytearray((0x01, 0x11, 0x01, 0x01, 0x11, 0x02, 0x01, 0x0B, 0x01, 0x0C))
	end_cmd = bytearray((0x01, 0x0D, 0x00))

	def take_expr(start):
		"""Return (raw expression bytes at `start`, offset just past them)."""
		stop = scx.skip_expr(data, start)
		return data[start:stop], stop

	found = []
	while pos < limit:
		if data[pos:pos + 3] != savepoint_cmd or data[pos + 5:pos + 15] != body_cmds:
			break
		current = Message()
		current.savepoints = struct.unpack_from('<H', data, pos + 3)  # one-item tuple
		option_bits = data[pos + 15]
		cursor = pos + 16

		current.voice = None
		if option_bits & 1:
			current.voice, cursor = take_expr(cursor)
		current.unk1, cursor = take_expr(cursor)
		current.unk2 = None
		if option_bits & 2:
			current.unk2, cursor = take_expr(cursor)

		(text_id,) = struct.unpack_from('<H', data, cursor)
		if text_id != expected_text_id:
			raise Exception("unexpected text id %d instead of %d" % (text_id, expected_text_id))
		expected_text_id += 1
		current.text = input_data.parse_text(text_id)
		cursor += 2

		# Without the closing cmd010D 0 this is not a message we can rewrite.
		if data[cursor:cursor + 3] != end_cmd:
			break
		found.append(current)
		pos = cursor + 3
	return found, pos


def serialize_messages_group(output, messages, pos_base, old_text_count):
	"""Append `messages` to `output` using the command layout parse_messages_group() reads.

	Unlike the parser, a message may carry several savepoints; one
	01 09 00 <id> command is emitted per savepoint, highest id first.

	output          -- builder whose code, savepoints, text_ptrs, pending_texts,
	                   pending_text_index, add_savepoint() and serialize_text()
	                   are used here
	messages        -- objects with savepoints, voice, unk1, unk2 and text
	pos_base        -- added to len(output.code) to get a savepoint position
	old_text_count  -- number of text ids the original group occupied

	Raises Exception when a savepoint id is neither already registered at the
	current position nor the next id in sequence.
	"""
	fixed_part = (0x01, 0x11, 0x01, 0x01, 0x11, 0x02, 0x01, 0x0B, 0x01, 0x0C)
	reused_ids = 0
	for message in messages:
		# Register the message's savepoints at the current output position.
		for sp_id in sorted(message.savepoints):
			here = len(output.code) + pos_base
			known = len(output.savepoints)
			if sp_id < known and output.savepoints[sp_id] == here:
				continue  # already registered at exactly this spot
			if sp_id != known:
				raise Exception("unexpected savepoint id: %d vs %d" % (sp_id, known))
			output.add_savepoint(here)
		for sp_id in sorted(message.savepoints, reverse=True):
			output.code.extend((0x01, 0x09, 0x00, sp_id & 0xFF, sp_id >> 8))

		output.code.extend(fixed_part)
		flags = (1 if message.voice is not None else 0) | (2 if message.unk2 is not None else 0)
		output.code.append(flags)
		if message.voice is not None:
			output.code.extend(message.voice)
		output.code.extend(message.unk1)
		if message.unk2 is not None:
			output.code.extend(message.unk2)

		# Savefiles mark text as read/unread by text id, so keep the original
		# group's ids where we can: the first old_text_count pages reuse them in
		# order. Pages beyond that (the translation is longer than the original)
		# get fresh ids appended after all existing texts so that later ids
		# don't shift; ids left unused (translation shorter) become gaps below.
		if reused_ids >= old_text_count:
			new_id = output.pending_text_index + len(output.pending_texts)
			output.pending_texts.append(message.text)
		else:
			new_id = output.serialize_text(message.text)
			reused_ids += 1
		output.code.extend((new_id & 0xFF, new_id >> 8))
		output.code.extend((0x01, 0x0D, 0x00))

	# Fill the gaps with 0; no remaining code refers to these ids.
	for _ in range(old_text_count - reused_ids):
		output.text_ptrs.append(0)



def replace_translated(messages, translated_key, logger):
	"""Build new Message objects for one translated key.

	messages        -- the original messages covered by the key
	translated_key  -- translated pages joined with '|'
	logger          -- receives warnings about voice mismatches

	The originals' savepoints are handed out one per new page, and any
	leftovers go to the last page. The voice operands (voice, unk1, unk2) are
	copied one-to-one when the page counts match. Otherwise only non-default
	operands are kept, and they are given in order to pages starting with '[s:'.
	Every warning also increments translator.num_warnings.

	Raises Exception when translated_key is empty.
	"""
	if not translated_key:
		raise Exception("deleting the entire group of messages is not supported")
	pages = translated_key.split('|')
	one_to_one = len(messages) == len(pages)
	no_voice = (None, b'\x80\x00\x00', None)

	inherited_savepoints = []
	voice_pool = []
	for original in messages:
		inherited_savepoints.extend(original.savepoints)
		operands = (original.voice, original.unk1, original.unk2)
		if one_to_one or operands != no_voice:
			voice_pool.append(operands)
	# Reversed so that pop() hands voices out in their original order.
	unused_voices = voice_pool[::-1]

	result = []
	for page_no, page in enumerate(pages):
		new_msg = Message()
		if page_no < len(inherited_savepoints):
			new_msg.savepoints = [inherited_savepoints[page_no]]
		else:
			new_msg.savepoints = []
		new_msg.voice, new_msg.unk1, new_msg.unk2 = no_voice
		if one_to_one or page.startswith('[s:'):
			if unused_voices:
				new_msg.voice, new_msg.unk1, new_msg.unk2 = unused_voices.pop()
			else:
				logger.warning("no voice in: " + page)
				translator.num_warnings += 1
		new_msg.text = scx.string_to_text(page + '[output]')
		result.append(new_msg)
	# Savepoints beyond the number of pages all attach to the last page.
	result[-1].savepoints += inherited_savepoints[len(pages):]
	if unused_voices:
		logger.warning("dropped voice in: " + translated_key)
		translator.num_warnings += 1
	return result


def translate_sc3(input_file, translations_file, output_file):
	"""Rewrite one SC3 script file, substituting translated message texts.

	Walks the bytecode of `input_file` from its entry point (input_data.eip)
	up to the strings table:
	  * runs of message commands recognised by parse_messages_group() are
	    matched key by key against the translations loaded from
	    `translations_file` and re-emitted with serialize_messages_group();
	  * other commands are skipped using the operand layouts in scx.commands
	    and copied through unchanged. There are two exceptions: standalone
	    text-pointer operands are re-registered in the output, and cmd0008
	    switch tables are re-aligned to an even output offset.
	Label and savepoint offsets are shifted by `code_delta`, the size change
	accumulated by the rewrites so far. The result is written to `output_file`.

	Warnings are counted in translator.num_warnings. Raises Exception on
	structural mismatches: unexpected text ids, missing [output] tags,
	translation keys spanning a message-group boundary, or bad cmd0008 labels.
	"""
	logger = logging.getLogger(os.path.basename(input_file))
	input_data = scx.ScxData(input_file)
	translations = translator.load_translations(translations_file, fix_quotes=True, use_speech_quotes=True, fix_two_spaces=True)
	pos = input_data.eip
	# Offsets of the next label/savepoint still to be re-emitted; strings_table means "none left".
	next_label_idx = 0
	next_label = input_data.labels[0] if input_data.labels else input_data.strings_table
	next_savepoint_idx = 0
	next_savepoint = input_data.savepoints[0] if input_data.savepoints else input_data.strings_table
	text_pos = 0  # index of the next unconsumed entry in `translations`
	code_delta = 0  # output offset minus input offset for code after the last rewrite
	unmodified_pos = pos  # start of input bytes not yet copied to the output
	text_id = 0  # next text id expected in the input
	output = scx.OutputBuilder(input_data.num_strings())
	content = input_data.content
	while pos < input_data.strings_table:
		while pos >= next_label:
			output.add_label(next_label + code_delta)
			next_label_idx += 1
			next_label = input_data.labels[next_label_idx] if next_label_idx < len(input_data.labels) else input_data.strings_table
		while pos >= next_savepoint:
			output.add_savepoint(next_savepoint + code_delta)
			next_savepoint_idx += 1
			next_savepoint = input_data.savepoints[next_savepoint_idx] if next_savepoint_idx < len(input_data.savepoints) else input_data.strings_table
		messages, next_pos = parse_messages_group(input_data, pos, next_label, text_id)
		if messages:
			#print("found message series at %X" % pos)
			text_id += len(messages)
			while next_pos >= next_savepoint: # drop, serialize_messages_group() recreates these savepoints in proper places
				next_savepoint_idx += 1
				next_savepoint = input_data.savepoints[next_savepoint_idx] if next_savepoint_idx < len(input_data.savepoints) else input_data.strings_table
			# Accumulate pages into '|'-joined keys until a key has as many
			# pages as the current translation entry, then replace them.
			text = ''
			text_count = 0
			source = []
			translated = []
			for msg in messages:
				if text_count:
					text += '|'
				text += scx.text_to_string(msg.text)
				text_count += 1
				if not text.endswith('[output]'):
					raise Exception("no output tag in scx file")
				text = text[:-len('[output]')]
				source.append(msg)
				if text_pos == len(translations):
					# Out of translations: keep the original messages as they are.
					logger.warning("too few keys in translation, should be: " + text)
					translator.num_warnings += 1
					translated.extend(source)
					text = ''
					text_count = 0
					source = []
				elif text_count == translations[text_pos][0].count('|') + 1:
					if translator.normalize_translation_key(text) != translator.normalize_translation_key(translations[text_pos][0]):
						logger.warning("mismatch in translation keys:")
						logger.warning("seen:      " + translations[text_pos][0])
						logger.warning("should be: " + text)
						translator.num_warnings += 1
					translated.extend(replace_translated(source, translations[text_pos][1], logger))
					text_pos += 1
					text = ''
					text_count = 0
					source = []
			if text:
				logger.error("translation key overlaps group boundary:")
				logger.error("seen:      " + translations[text_pos][0])
				logger.error("should be: " + text)
				raise Exception("translation key overlaps group boundary")
			output.add_code(content[unmodified_pos:pos])
			serialize_messages_group(output, translated, input_data.eip, len(messages))
			unmodified_pos = next_pos
			code_delta = output.get_code_pos(input_data.eip) - next_pos
			pos = next_pos
		else:
			# generic parsing
			byte1 = content[pos]
			if byte1 == 0xFE:
				pos += 1
				pos = scx.skip_expr(content, pos)
			else:
				byte2 = content[pos + 1]
				pos += 2
				cmd = (byte1 << 8) | byte2
				optypes = scx.commands.get(cmd)
				if type(optypes) == dict:
					# Some commands take a third opcode byte that selects the operand layout.
					optypes = optypes.get(content[pos])
					cmd = (cmd << 8) | content[pos]
					pos += 1
				if optypes is None:
					# Unknown command: give up on this stretch and resume at the next label;
					# the skipped bytes are still copied verbatim.
					#print("unknown command 0x%X at 0x%X" % (cmd, pos - 2))
					pos = next_label
					continue
				for optype in optypes:
					if optype == scx.OPTYPE_EXPR:
						pos = scx.skip_expr(content, pos)
					elif optype == scx.OPTYPE_LABEL: # note: labels are 1-based
						pos += 2
						if cmd == 0x0008:
							label = struct.unpack_from('<H', content, pos - 2)[0]
							if label - 1 != next_label_idx:
								raise Exception("Parse error [cmd0008]")
							# there can be a padding zero byte between pos and labels[next_label_idx]
							# labels[next_label_idx] ... labels[next_label_idx + 1] is an array of words with switch labels
							output.add_code(content[unmodified_pos:pos])
							if output.get_code_pos(input_data.eip) % 2:
								output.add_code(b'\0')
							unmodified_pos = input_data.labels[next_label_idx]
							pos = input_data.labels[next_label_idx + 1]
							code_delta = output.get_code_pos(input_data.eip) - unmodified_pos
					elif optype == scx.OPTYPE_BYTE:
						pos += 1
					elif optype == scx.OPTYPE_WORD:
						pos += 2
					elif optype == scx.OPTYPE_TEXTPTR:
						# unpack_from returns a tuple; take its single u16 value.
						got_text_id = struct.unpack_from('<H', content, pos)[0]
						if text_id != got_text_id:
							raise Exception("unexpected text id %d instead of %d" % (got_text_id, text_id))
						text_id += 1
						output.add_code(content[unmodified_pos:pos])
						# Look the text up the same way parse_messages_group() does and
						# re-register it in the output, which may assign a different id.
						new_text_id = output.serialize_text(input_data.parse_text(got_text_id))
						output.add_code(bytearray((new_text_id & 0xFF, new_text_id >> 8)))
						pos += 2
						unmodified_pos = pos
					elif optype == scx.OPTYPE_2BOR2EXPR:
						if content[pos] == 99:
							pos += 1
							pos = scx.skip_expr(content, pos)
							pos = scx.skip_expr(content, pos)
						else:
							pos += 2
					elif optype == scx.OPTYPE_CMD0023:
						if content[pos] != 2:
							pos += 1
							pos = scx.skip_expr(content, pos)
							pos = scx.skip_expr(content, pos)
						else:
							pos += 1
					elif optype == scx.OPTYPE_CMD010C:
						flags = content[pos]
						pos += 1
						if flags & 1:
							pos = scx.skip_expr(content, pos)
						pos = scx.skip_expr(content, pos)
						if flags & 2:
							pos = scx.skip_expr(content, pos)
					elif optype == scx.OPTYPE_CMD1004FIRST:
						cmd1004first = struct.unpack_from('<B', content, pos)[0]
						pos += 1
					elif optype == scx.OPTYPE_CMD1004SECOND:
						# Operand presence depends on the preceding CMD1004FIRST byte.
						if cmd1004first >= 4:
							pos = scx.skip_expr(content, pos)
				# cmd 0x00/0x0E followed only by up to 3 zero bytes before the strings table:
				# stop here; that trailing zero padding is not copied.
				if cmd in (0,0xE) and input_data.strings_table - pos <= 3 and content[pos:input_data.strings_table] == bytearray((0,) * (input_data.strings_table - pos)):
					break
	output.add_code(content[unmodified_pos:pos])
	while next_label_idx < len(input_data.labels):
		output.add_label(input_data.labels[next_label_idx] + code_delta)
		next_label_idx += 1
	while next_savepoint_idx < len(input_data.savepoints):
		output.add_savepoint(input_data.savepoints[next_savepoint_idx] + code_delta)
		next_savepoint_idx += 1
	with io.open(output_file, 'wb') as f:
		output.save(f)


if __name__ == "__main__":
	# Command line: <script name> <translations file>
	# Reads script/<script name> and writes translated_script/<script name>.
	if len(sys.argv) < 3:
		print("usage: %s <script name> <translations file>" % sys.argv[0], file=sys.stderr)
		sys.exit(2)
	logging.basicConfig()
	scx.load_characters_data('input_map.bin', translator.enable_cyrillic)
	scx.setup_for_main_text(translator.enable_value_transform)
	#scx.map_fullwidth_to_normal()
	translate_sc3('script/' + sys.argv[1], sys.argv[2], 'translated_script/' + sys.argv[1])
