#!/usr/bin/env python3
# vim: set tw=120 :
# coding: utf-8
#
# This file is part of alip's chess scripts.
# Copyright (C) 2015, 2016, 2018, 2023 Ali Polatel <alip@exherbo.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import os, sys, socket

import argparse
import signal
import string
import json
import itertools
import collections
import re
import shlex
import subprocess
import tempfile
import time
import io
import chess
import chess.pgn
import chess.polyglot

#
# Configuration
# TODO: use command line arguments.
################################################
# External tools, invoked by name (looked up on $PATH).
JJA = "jja"
PGN_EXTRACT = "pgn-extract"
# Location of a jja.git checkout; misc/eco.json is read from it. Override with $JJA_GIT.
JJA_GIT = os.environ.get("JJA_GIT", os.path.expanduser("~/src/jja"))
# Per-ECO-code first ply from which main-line positions are matched against references.
ECO_PLY = dict(
    A11=8, A15=8, A39h=18,
    B10m=12, B17f=16, B38y=32,
    D02o=6, D12=12, D12p=14, D15=12, D41n=30,
)
# Default first ply: 10 half-moves played, i.e. the position before White's 6th move. (?)
ECO_PLY_DEF = 10
# Exclusive upper bound on the plies whose positions are matched.
ECO_PLY_MAX = 42
################################################

def debugger():
    """Start an interactive pdb session (development aid)."""
    import pdb
    pdb.set_trace()

# Filter PGN file by eco information using pgn-extract.
# Generates list of games.
def filter_eco(fname, eco_info):
    """
    Yield the games of *fname* that pgn-extract selects by ECO code.

    :param fname: Path of the PGN file to filter (passed to pgn-extract as-is).
    :param eco_info: Mapping with an 'eco' entry; only its first three characters
        (e.g. C22a -> C22) are used, as pgn-extract's -Te<code> option.
    :return: Generator of chess.pgn.Game objects parsed from pgn-extract's stdout.
    """
    eco_code = eco_info['eco'][:3] # C22a -> C22
    # Build argv as a list: no quote/split round trip, and the ECO code and the
    # file name each stay exactly one argument. PGN_EXTRACT is split like before
    # so it may still carry extra options.
    argv = shlex.split(PGN_EXTRACT) + ["-Te%s" % eco_code, fname]

    with subprocess.Popen(argv, stdout = subprocess.PIPE, universal_newlines = True) as proc:
        chess_game = chess.pgn.read_game(proc.stdout)
        while chess_game is not None:
            yield chess_game
            chess_game = chess.pgn.read_game(proc.stdout)

# Filter PGN file by eco & position information using pgn-extract
# Generates list of games.
def filter_position(fname, **kwargs):
    """
    Yield games of *fname* selected by pgn-extract, copying each to the reference DB.

    Keyword arguments (all optional):
      eco   -- passed to pgn-extract as "-e" "-Te<eco>".
      ftag  -- path of a tag-criteria file, passed as "-t<ftag>".
      fvars -- passed as "-x<fvars>" (pgn-extract's -x option).
      limit -- stop after yielding this many games; None (default) reads all
               of them, and a value <= 0 yields nothing.

    pgn-extract always runs with --addhashcode and --markmatches MATCH. Before
    each game is yielded it is appended to the file named by $JJA_REFERENCE_DB
    (default ~/jja-ref.pgn).
    """
    # argv as a list: every value stays one argument without a shell-quoting
    # round trip (previously 'eco' was interpolated unquoted).
    argv = shlex.split(PGN_EXTRACT)
    if 'eco' in kwargs:
        argv += ["-e", "-Te%s" % kwargs['eco']]
    if 'ftag' in kwargs:
        argv.append("-t%s" % kwargs['ftag'])
    if 'fvars' in kwargs:
        argv.append("-x%s" % kwargs['fvars'])
    argv += ["--addhashcode", "--markmatches", "MATCH", fname]

    limit = kwargs.get('limit', None)
    if limit is not None and limit <= 0:
        return # Nothing requested; the count check below only runs after a yield.

    ref_path = os.environ.get("JJA_REFERENCE_DB",
            os.path.expanduser('~/jja-ref.pgn'))
    with open(ref_path, 'a', encoding="utf-8-sig", errors="surrogateescape") as ref:
        exporter = chess.pgn.FileExporter(ref)

        with subprocess.Popen(argv, stdout = subprocess.PIPE, universal_newlines = True) as proc:
            count = 0
            chess_game = chess.pgn.read_game(proc.stdout)
            while chess_game is not None:
                chess_game.accept(exporter)
                yield chess_game
                count += 1
                if limit is not None and count >= limit:
                    break
                chess_game = chess.pgn.read_game(proc.stdout)

# ECO lookup table: position hash (as stored in jja's misc/eco.json) -> ECO code.
# Codes are inserted in sorted order, so when two codes share a hash the one
# sorting last wins.
eco_file = os.path.join(JJA_GIT, 'misc/eco.json')
ECO = None
with open(eco_file, encoding="utf-8-sig", errors="surrogateescape") as f:
    data = json.load(f)
    ECO = {data[code]['hash']: code for code in sorted(data)}

def find_eco(chess_game, max_ply=48):
    """
    Return the ECO code of the deepest main-line position found in ECO.

    Replays at most *max_ply* plies of the main line. The polyglot Zobrist hash
    of each position reached is looked up in ECO, deepest position first. The
    game's initial position is not looked up.

    :param chess_game: chess.pgn.Game whose main line is inspected.
    :param max_ply: Number of main-line plies to consider (default 48).
    :return: The matching ECO code (a key of misc/eco.json), or None.
    """
    board = chess_game.board()
    keys = []
    node = chess_game
    for _ in range(max_ply):
        if not node.variations:
            break # End of game.
        node = node.variation(0)
        # Push onto a single board instead of calling node.board() per node,
        # which replays the whole game from the start each time (O(n^2)).
        board.push(node.move)
        keys.append(chess.polyglot.zobrist_hash(board))

    for key in reversed(keys):
        if key in ECO:
            return ECO[key]

    return None

def add_references(my_game, refdb, hash_codes = None, limit = None):
    """
    Graft deviating reference games onto *my_game* as variations.

    The game's ECO code is found with find_eco() and, if found, stored in the
    ECO header. Main-line positions from ply ECO_PLY[eco] (or ECO_PLY_DEF) up to
    ECO_PLY_MAX are written as FENPattern criteria. filter_position() then
    fetches the reference games reaching any of them. For every reference game
    not seen before:
      - its continuation from the first position not among those collected is
        attached to the node of *my_game* holding the last shared position;
      - the continuation gets a "White - Black (Site, Date)" starting comment;
      - the reference result becomes the comment of its final node.

    :param my_game: chess.pgn.Game, modified in place.
    :param refdb: Reference PGN file, or a directory holding "<ECO>.pgn" files.
    :param hash_codes: Set of pgn-extract HashCode values already seen. It is
        updated in place so each reference game is used once. A fresh set is
        created when None.
    :param limit: Maximum number of reference games to read, or None for all.
    :return: The updated *hash_codes* set.
    """
    if hash_codes is None:
        hash_codes = set() # Never share a mutable default between calls.

    eco_code = find_eco(my_game)
    if eco_code is not None:
        my_game.headers['ECO'] = eco_code

    ply_min = ECO_PLY.get(eco_code, ECO_PLY_DEF)
    ply_max = ECO_PLY_MAX

    if os.path.isdir(refdb):
        if eco_code is None:
            # Without an ECO code there is no "<ECO>.pgn" to pick from the directory.
            print("No ECO code, skipping reference directory `%s'" % refdb, file=sys.stderr)
            return hash_codes
        refdb = os.path.join(refdb, "%s.pgn" % eco_code)

    # Piece placement -> first main-line node of my_game with that placement.
    node = my_game
    mdict = collections.OrderedDict()
    for ply in range(ply_max):
        if not node.variations:
            break # End of game, too bad.
        if ply >= ply_min:
            fen_pattern = node.board().fen().split(' ')[0]
            if fen_pattern not in mdict:
                mdict[fen_pattern] = node
        node = node.variation(0)

    if not mdict:
        # Game too short to reach ply_min: no position to search for, so do
        # not run pgn-extract with an empty criteria file.
        return hash_codes

    # Temporary file holding pgn-extract's tag criteria; removed when done.
    tfd, tname = tempfile.mkstemp()
    try:
        with os.fdopen(tfd, 'w') as ftag:
            for fen_pattern in reversed(mdict):
                print("FENPattern \"%s\"" % fen_pattern, file=ftag)

        for refgame in filter_position(refdb, ftag = tname, limit = limit):
            short_date = refgame.headers['Date'].replace('.??.??', '', 1)
            starting_comment = "%s - %s (%s, %s)" % (refgame.headers['White'],
                                                     refgame.headers['Black'],
                                                     refgame.headers['Site'],
                                                     short_date)

            hash_code = refgame.headers['HashCode']
            if hash_code in hash_codes:
                print("Match DUP: %s" % starting_comment, file=sys.stderr)
                continue
            hash_codes.add(hash_code)

            # Find the node pgn-extract marked with the MATCH comment.
            node = refgame
            while node.variations:
                if node.comment == 'MATCH':
                    break
                node = node.variation(0)

            if not node.variations:
                continue # Game ends here, too bad.

            # Follow the reference game while it stays on my_game's positions.
            node_reference     = None
            parent_fen_pattern = None
            while node.variations:
                fen_pattern = node.board().fen().split(' ')[0]
                if fen_pattern not in mdict: # Deviation
                    node_reference = node
                    break
                node = node.variation(0)
                parent_fen_pattern = fen_pattern

            if node_reference is None: # No deviation?
                print("No deviation: %s" % starting_comment, file=sys.stderr)
                continue

            if parent_fen_pattern is None:
                # The MATCH node itself is not one of my_game's positions, so
                # there is no node in my_game to attach the variation to.
                print("No common position: %s" % starting_comment, file=sys.stderr)
                continue

            # TODO, find node with max. halfmove clock and add novelty NAG.
            mygame_node = mdict[parent_fen_pattern]

            node_reference.parent = mygame_node
            node_reference.starting_comment = starting_comment
            node_end = node_reference.end()
            node_end.comment = refgame.headers['Result']

            mygame_node.variations.append(node_reference)
            print("Match: %s" % starting_comment, file=sys.stderr)
    finally:
        os.remove(tname)

    return hash_codes

try: # Fast, shiny
    import gmpy2

    def piece_count(board):
        """Return the number of occupied squares of *board* (gmpy2 popcount)."""
        return gmpy2.popcount(board.occupied)
except ImportError: # Portable fallback
    def piece_count(board):
        """Return the number of occupied squares of *board*."""
        # bin() + str.count run in C, unlike testing 64 bits in a Python loop.
        return bin(board.occupied).count("1")

def probe(orig_board, game_board):
    """
    Probe a board position using jja lookup.
    Assumes SYZYGY_PATH is set.

    Each board of (orig_board, game_board) that is not None is passed to
    ``jja probe --test --fast``. Its stripped output and the elapsed time are
    logged to stderr.

    :param orig_board: Position to probe and to build the returned line from.
    :param game_board: Position to compare with, or None to skip the comparison.
    :return: None if game_board was probed and the outputs did not differ (the
        game move does not change the result). Otherwise the game printed by
        ``jja probe`` for orig_board. Also None when jja printed no game or a
        game without moves.
    """
    results = set()
    for board in (orig_board, game_board):
        if board is None:
            continue
        fen = shlex.quote(board.fen())
        cmd = shlex.split(f"{JJA} probe --test --fast {fen}")
        ts0 = time.time()
        print(f"PROBE_WDL\t{fen}", file=sys.stderr, end='\t')
        with subprocess.Popen(cmd, stdout = subprocess.PIPE, universal_newlines = True) as proc:
            stdout, _ = proc.communicate()
            result = stdout.strip()
            results.add(result)
            ts1 = time.time()
            print(f"{result}\t{ts1 - ts0}sec", file=sys.stderr)
    if game_board is not None and len(results) <= 1:
        # Game move does not change result.
        return None

    # Return the PV from the perspective of orig_board.
    fen = shlex.quote(orig_board.fen())
    cmd = shlex.split(f"{JJA} probe {fen}")
    with subprocess.Popen(cmd, stdout = subprocess.PIPE, universal_newlines = True) as proc:
        game = chess.pgn.read_game(proc.stdout)
        # read_game() returns None on empty output; also avoid returning a
        # game without any moves.
        if game is not None and game.variations:
            return game
    return None

def add_tablebase_probes(my_game, max_pieces=7):
    """
    Annotate main-line moves of *my_game* that change the tablebase result.

    Walks the main line backwards from its last move. A move is probed when:
      - the positions before and after it both have at most *max_pieces* pieces;
      - neither position is game over (claim_draw=True).
    probe() is then called on both positions. When it returns a game, that
    game's first variation is added at the position before the move. The added
    line gets a "syzygy:<Result>" starting comment, NAG_GOOD_MOVE, and a result
    NAG (decisive advantage for 1-0/0-1, drawish otherwise). The game move
    itself gets NAG_BLUNDER.

    :param my_game: chess.pgn.Game, modified in place.
    :param max_pieces: Largest piece count that is probed (default 7).
    """
    node = my_game.end()
    # Position after node's move; on the next step back it is reused instead of
    # being rebuilt from the start of the game.
    game_board = node.board()
    while node is not None and node.parent is not None:
        orig_board = node.parent.board()

        # Cheap piece counts first, so is_game_over() only runs on positions
        # small enough to be probed. All four checks are pure, so the order
        # does not change the result.
        probeable = (piece_count(orig_board) <= max_pieces and
                     piece_count(game_board) <= max_pieces and
                     not orig_board.is_game_over(claim_draw=True) and
                     not game_board.is_game_over(claim_draw=True))

        if probeable:
            probe_game = probe(orig_board, game_board)
            if probe_game is not None:
                pvar = probe_game.variation(0)
                pres = probe_game.headers['Result']
                pvar.parent = node.parent
                pvar.starting_comment = f'syzygy:{pres}'
                pvar.nags.add(chess.pgn.NAG_GOOD_MOVE)
                if pres == '0-1':
                    pvar.nags.add(chess.pgn.NAG_BLACK_DECISIVE_ADVANTAGE)
                elif pres == '1-0':
                    pvar.nags.add(chess.pgn.NAG_WHITE_DECISIVE_ADVANTAGE)
                else: # Draw
                    pvar.nags.add(chess.pgn.NAG_DRAWISH_POSITION)
                node.nags.add(chess.pgn.NAG_BLUNDER)
                node.parent.variations.append(pvar)

        node = node.parent
        game_board = orig_board

if __name__ == '__main__':
    parser = argparse.ArgumentParser(prog="jja-doctor",
        description="Annotates novelties in input PGN.")

    # Repeatable, and each -r requires a path: a bare -r must not put None
    # into the list, since add_references() cannot handle it.
    parser.add_argument("-r", "--reference",
        action='append',
        help="reference PGN files.")

    parser.add_argument("source", type=str, nargs="?", default='-', help="input PGN.")
    parser.add_argument("output",
        type=argparse.FileType("w"),
        nargs="?",
        default=sys.stdout,
        help="The output PGN.")

    args = parser.parse_args()

    if args.source == '-':
        # Copy stdin to a temporary `seekable' file. TemporaryFile is deleted
        # automatically when closed (a mkstemp() file would be left behind).
        src = tempfile.TemporaryFile('w+', encoding="utf-8", errors = "replace")
        src.write(sys.stdin.read())
        src.flush()
        src.seek(0)
    else:
        src = open(args.source, 'r', encoding="utf-8", errors = "replace")

    with src: # Close (and for stdin, remove) the input when done.
        c = 1
        game = chess.pgn.read_game(src)
        while game is not None:
            print("Processing Game %d" % c, file=sys.stderr)
            if not game.variations:
                print("Skipped blank game %d" % c, file=sys.stderr)
            else:
                hash_codes = set()
                for reference in args.reference or tuple():
                    limit = None
                    print("Game %d: Adding references from `%s' (limit: %r)" % (c, reference, limit), file=sys.stderr)
                    hash_codes = add_references(game, reference, hash_codes, limit)
                print("Game %d: Added %d reference games" % (c, len(hash_codes)), file=sys.stderr)

                print("Game %d: Adding tablebase probes" % c, file=sys.stderr)
                add_tablebase_probes(game)

                print(game, file=args.output)
                args.output.write("\n")

            c += 1
            game = chess.pgn.read_game(src)
