"""
This file implements YAML calibration, a mechanism for programming multiple automated calibration steps
(and related operations). Essentially this scheme coordinates repeated calls to at.calibrate with different
adjustables and measurables in a pre-defined sequence of automated calibration steps.
"""
import sciris as sc
from pathlib import Path
import shutil
import atomica as at
import numpy as np
import yaml
import time
import re
__all__ = ["build", "run"]
from atomica import ParameterSet
def _get_named_nodes():
    """
    Map step names to the node classes that implement them

    Only direct subclasses of BaseNode that define a non-None ``_name`` are included.

    :return: Dictionary ``{_name: NodeClass}``
    """
    named = {}
    for node_cls in BaseNode.__subclasses__():
        if node_cls._name is not None:
            named[node_cls._name] = node_cls
    return named
[docs]
def build(instructions=None, context=None, name="calibration"):
    """
    Construct nodes representing a calibration

    The type of node is chosen as follows:

    - a dict that (together with ``context``) provides both ``adjustables`` and ``measurables`` becomes a
      ``CalibrationNode``
    - otherwise, if ``name`` matches the ``_name`` of a named node class (e.g. ``save_calibration``), that class is used
    - otherwise a ``SectionNode`` is created

    :param instructions: A dictionary of attributes/settings defined for this node OR a string/Path filename
                         of an existing YAML file that can be loaded to provide instructions
    :param context: A dictionary of attributes/settings inherited from parent nodes
    :param name: The name to assign this node
    :return: A Node subclass instance, the type of which depends on the instructions
    """
    if (sc.isstring(instructions) or isinstance(instructions, Path)) and Path(instructions).exists():
        # Open in binary mode so that PyYAML detects the (UTF-8/UTF-16) encoding itself, rather than the file being
        # decoded with the platform's locale encoding.
        # NOTE(review): FullLoader is not intended for untrusted input; consider yaml.safe_load if calibration
        # files may come from untrusted sources.
        with open(instructions, "rb") as file:
            instructions = yaml.load(file, Loader=yaml.FullLoader)

    named_nodes = _get_named_nodes()

    def provided(key):
        # A setting is available if it is defined on this node or inherited from a parent
        return key in instructions or (context is not None and key in context)

    if isinstance(instructions, dict) and provided("adjustables") and provided("measurables"):
        return CalibrationNode(instructions, context, name)
    elif name in named_nodes:
        return named_nodes[name](instructions, context, name)
    else:
        return SectionNode(instructions, context, name)
[docs]
def run(node, project, parset, savedir=None, save_intermediate=False, log_output: bool = False, *args, **kwargs):
    """
    Run YAML calibration

    This will execute the YAML calibration using the passed-in node (or instructions to build a node), and any
    associated children, in order.

    :param node: Calibration node to execute. If not a node (i.e., a YAML file, or node instructions), it will be
                 converted into a node using ``build``
    :param project: Project to which to apply these instructions
    :param parset: The parset to calibrate. It is looked up via ``project.parset`` and deep-copied before calibrating
    :param savedir: Optionally specify a directory to save the results. Defaults to the current working directory
    :param save_intermediate: Set whether to save intermediate calibrations (defaults to False)
    :param log_output: If True, write the log to ``calibration_log.txt`` in ``savedir`` and, if ``node`` is a YAML
                       filename, save a copy of that file in ``savedir``
    :param args: Additional positional arguments forwarded to each step's ``apply`` method
    :param kwargs: Additional keyword arguments forwarded to each step's ``apply`` method
    :return new_parset: A calibrated `at.ParameterSet` instance
    """
    savedir = Path(".") if savedir is None else Path(savedir)
    savedir.mkdir(exist_ok=True, parents=True)

    if not isinstance(node, BaseNode):
        if log_output and (sc.isstring(node) or isinstance(node, Path)) and Path(node).is_file():
            # Save a copy of the YAML file alongside the log. Skip the copy when the file already is in savedir,
            # because shutil.copyfile raises SameFileError if source and destination are the same file
            src = Path(node)
            dest = savedir / src.name
            if not (dest.exists() and dest.samefile(src)):
                shutil.copyfile(src, dest)
        node = build(node)

    parset = sc.dcp(project.parset(parset))
    nodes = list(node.walk())  # Flat list of (repeat number, node) pairs in execution order
    n_steps = sum(not isinstance(x, SectionNode) for _, x in nodes)
    width = len(str(n_steps))  # Zero-pad intermediate file numbers so they sort correctly

    if log_output:
        at.start_logging(savedir / "calibration_log.txt")
    try:
        t_start = time.perf_counter()
        at.logger.info(f"\nStarting calibration ({n_steps} steps)")
        step = 0
        for n_reps, current in nodes:
            if isinstance(current, SectionNode):
                at.logger.info(f'\nSection "{current.name}" (repeat {n_reps} of {current.repeats})')
                continue
            step += 1
            at.logger.info(f'\nStep {step} of {n_steps}: "{current.name}" (repeat {n_reps} of {current.repeats})')
            parset = current.apply(project, parset, savedir, save_intermediate, *args, **kwargs)
            if save_intermediate and not isinstance(current, SaveCalibrationNode):
                # Number the file with the step that produced it (1..n_steps)
                output = savedir / f'intermediate_calibration_{step:0{width}}_{current.name.replace(" ", "_")}'
                at.logger.info("Saving intermediate calibration...")
                parset.save_calibration(output)
        elapsed = time.perf_counter() - t_start
        at.logger.info(f"\nCalibration completed. Total time elapsed: {round(elapsed, 2)} seconds ({round(elapsed/60, 2)} minutes)")
    finally:
        # Always detach the log file, even if a step fails
        if log_output:
            at.stop_logging()
    return parset
[docs]
class BaseNode:
    """
    Node base class

    The base node class implements basic node features. Typically there should not be any
    instances of this class, only instances of subclasses.

    Attribute lookup (``node[key]``, ``key in node``) checks this node's own instructions first and then the
    context inherited from parent nodes.
    """

    _name = None  # If specified, this key can be used as the name of the step to create a node of this type

    def __init__(self, instructions, context, name):
        self.name = name
        self.instructions = sc.dcp(instructions)
        self.context = context  # Attributes inherited from parent nodes
        self.children = []
        self.validate()

    def walk(self):
        """
        Iterate over this node and its descendants in execution order

        :return: Generator of ``(repeat_number, node)`` tuples. ``repeat_number`` starts at 1. For each repeat of
                 this node, the node itself is yielded followed by all of its children (recursively)
        """
        for n_reps in range(1, self.repeats + 1):
            yield (n_reps, self)
            for child in self.children:
                yield from child.walk()

    @property
    def n_steps(self):
        """
        Number of calibration steps (non-section nodes) executed by this node, including repeats

        A section contributes the steps of its children multiplied by its own number of repeats, so this agrees
        with the count of non-section nodes produced by ``walk()``.
        """
        if isinstance(self, SectionNode):
            return self.repeats * sum(child.n_steps for child in self.children)
        return self.repeats

    def __repr__(self):
        return f'<{self.__class__.__name__} "{self.name}" x{self.repeats}>'

    def __str__(self, indent=0):
        """
        Print a tree representation of this node and all children

        :param indent: Recursively increase the indent for child nodes
        :return: Multi-line string, one node per line
        """
        s = "\t" * indent + self.__repr__()
        for child in self.children:
            s += "\n" + child.__str__(indent=indent + 1)
        return s

    @property
    def attributes(self):
        # Merged view of inherited context and this node's instructions (instructions take precedence)
        return sc.mergedicts(self.context, self.instructions)

    def __getitem__(self, item):
        # Directly index the Node to extract attributes without merging the dictionaries every time.
        # Own instructions take precedence over the inherited context. Either may be None (e.g. the top-level node
        # has no context), in which case it is skipped instead of raising TypeError
        for source in (self.instructions, self.context):
            if source is not None and item in source:
                return source[item]
        raise KeyError(item)

    def __setitem__(self, key, value):
        self.instructions[key] = value

    def __contains__(self, item):
        return any(source is not None and item in source for source in (self.instructions, self.context))

    @property
    def repeats(self):
        # Although repeats may be part of the context, we only repeat a node if the instructions requested a repeat
        # i.e., repeats are not inherited
        if isinstance(self.instructions, dict) and "repeats" in self.instructions:
            return self.instructions["repeats"]
        else:
            return 1

    def validate(self):
        """
        Validate/sanitize contents of this node

        If the node isn't valid, an error should be raised
        """
        return

    def apply(self, project: at.Project, parset: at.ParameterSet, savedir, *args, **kwargs) -> ParameterSet:
        """
        Perform the action associated with this node

        The base implementation does nothing and returns ``parset`` unchanged.
        """
        return parset
 
[docs]
class SectionNode(BaseNode):
    """
    A section node is a special kind of node, that contains other nodes

    Entries of the section's instructions whose value is a dictionary, or whose key is the name of a named node
    (e.g. ``save_calibration``), become child nodes. All remaining entries stay in the instructions and are
    passed down to the children as context.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.children = self.make_children()
        self.validate()

    def make_children(self):
        """
        Build the child nodes of this section

        Keys that define child steps are removed from ``self.instructions`` so that they are not passed down to
        the children as part of the context.

        :return: List of child nodes, in the order in which the steps appear in the instructions
        :raises TypeError: If the instructions are not a dictionary (e.g. a YAML filename that does not exist)
        """
        if self.instructions is None:
            # An empty section (e.g. an empty YAML file) simply has no steps
            self.instructions = {}
        if not hasattr(self.instructions, "items"):
            raise TypeError(f'Section "{self.name}" must be defined by a dictionary of settings and steps (or the name of an existing YAML file), not {self.instructions!r}')

        # Remove any keys in the instructions that correspond to named nodes
        # These should be used as instructions for the child node, rather than
        # forming part of the context that is passed down to all children
        named_nodes = _get_named_nodes()
        step_instructions = {k: v for k, v in self.instructions.items() if isinstance(v, dict) or k in named_nodes}
        for k in step_instructions:
            del self.instructions[k]

        # Create the child nodes. Each child gets its own freshly merged copy of the context
        return [build(instructions, self.attributes, name) for name, instructions in step_instructions.items()]
[docs]
class CalibrationNode(BaseNode):
    """
    Node that performs one automated calibration step with ``at.calibrate``

    The adjustables and measurables may be defined in this node's instructions or inherited from a parent section.
    ``validate`` normalizes them into dictionaries keyed by ``(quantity, pop_name)`` whose values hold only the
    settings that were explicitly provided. ``apply`` fills in the remaining settings from section-level values
    (e.g. a ``lower_bound`` key in a parent section) or the class defaults below, and runs the calibration.
    """

    # Order for list of adjustable parameters and default values
    adj_defaults = {
        "lower_bound": 0.1,
        "upper_bound": 10.0,
        "starting_y_factor": None,
    }
    # Order for list of measurable parameters and default values
    meas_defaults = {
        "weight": 1.0,
        "metric": "fractional",
        "cal_start": -np.inf,
        "cal_end": np.inf,
    }

    @staticmethod
    def parse_list(l, defaults):
        """
        Convert a list of positional settings into a dictionary

        Values are matched to the keys of ``defaults`` in order. Values beyond the number of defaults are ignored.
        Values that can be converted to float (e.g. the string ``"0.5"``) are converted, all others (e.g. ``None``
        or ``"fractional"``) are kept as-is. The input list is not modified.

        :param l: List of settings e.g. ``[0.1, 10]``
        :param defaults: Dictionary whose key order defines the meaning of each list position
        :return: Dictionary of the provided settings e.g. ``{'lower_bound': 0.1, 'upper_bound': 10.0}``
        """

        def to_number(value):
            # convert number strings back to numerical values
            try:
                return float(value)
            except (TypeError, ValueError):
                return value

        return {k: to_number(v) for k, v in zip(defaults, l)}

    def validate(self):
        """
        Pre-parse calibration inputs

        Replaces ``self["adjustables"]`` and ``self["measurables"]`` with normalized dictionaries of the form
        ``{(quantity, pop_name): {setting: value}}`` and checks the types of the provided settings.

        :raises AssertionError: If there are no adjustables/measurables, or a quantity/population/metric is invalid
        :raises TypeError: If a numeric setting (bounds, starting y-factor, weight, calibration start/end) is not a
                           number or None
        """

        def separate_keys(keys_str: str) -> list:
            """
            Separate inputs that have been defined together as one key in the YAML file but actually represent multiple
            parameters.
            :param str keys_str: Unprocessed input key from the YAML file.
            :return list : A list of strings, each of which represents a single key.
            """
            in_brackets = False
            brackets_str = ""
            nobrackets_str = ""
            separated_keys = []
            for ch in keys_str:
                if ch == "(":
                    in_brackets = True
                    continue
                elif ch == ")":
                    in_brackets = False
                    separated_keys.append(brackets_str)
                    brackets_str = ""
                    continue
                if in_brackets:
                    brackets_str += ch
                else:
                    if ch == ",":
                        if nobrackets_str == " " or nobrackets_str == "":
                            nobrackets_str = ""
                            continue
                        else:
                            separated_keys.append(nobrackets_str)
                            nobrackets_str = ""
                    else:
                        nobrackets_str += ch
            if nobrackets_str != "" and nobrackets_str != " ":
                separated_keys.append(nobrackets_str)
            return [x.strip() for x in separated_keys]

        def process_key(key: str) -> tuple:
            """
            Sanitize the key name, separating the parameter codename from the optional population name/s.
            :par str key: Key representing an adjustable or measurable parameter. It can also contain one or two
            population names (two in the case of a transfer), separated from the parameter name by a comma.
            :returns: A tuple of the parameter codename and population name/s. Population defaults to None
            if not specified.
            EXAMPLES:
                INPUT: 'b_rate'
                OUTPUT: (b_rate, None)
                INPUT: 'b_rate, 0-4'
                OUTPUT: (b_rate, 0-4)
                INPUT: 'aging, 0-4, 5-14'
                OUTPUT: (aging, 0-4, 5-14)
            """
            if "," in key:
                # Empty/whitespace-only parts (e.g. from a trailing comma) are ignored
                parts = tuple(x.strip() for x in key.strip("() ").split(",") if x.strip())
                # A single remaining name means that no population was specified
                return (parts[0], None) if len(parts) == 1 else parts
            else:
                return (key.strip(), None)

        def process_list(l: list) -> (tuple, list):
            """
            Process list-format inputs and separate them into a key (par_name, pop_name tuple) and a value (list of
            parameter settings).
            @param l: List representing the settings for one parameter.
            @return:    tuple key: Tuple of the parameter codename and the population (default None).
                        list value: List of calibration settings for this parameter, or None if none were given.
            EXAMPLES:
                INPUT: [b_rate, 0.1, 10]
                OUTPUT: (b_rate, None) ; [0.1, 10]
                INPUT: [(b_rate, 0-4), 0.1, 10, 1.5]
                OUTPUT: (b_rate, 0-4) ; [0.1, 10, 1.5]
                INPUT: ['(b_rate, 0-4)']
                OUTPUT: (b_rate, 0-4) ; None
            """
            if len(l) == 1:
                # Only the quantity (with optional population/s) was given. Split off the population/s here too
                return process_key(l[0].strip("() ")), None
            elif "(" in str(l):
                # In YAML flow lists, "(b_rate, 0-4)" is split at the comma into separate items, so reassemble the
                # list as a string and extract the parenthesis contents as the par/pop/s,
                # then output the key (par, pop tuple) and value
                s = str(l).strip("[] ").replace("'", "")
                s1 = re.findall(r"\(.*?\)", s)
                key = process_key(s1[0].replace("(", "").replace(")", ""))
                # process values/settings
                value = s.replace(s1[0], "").strip(", ").split(",")
                value = [x.strip(", ") for x in value if x]
                return key, value
            else:
                key = process_key(l[0])
                value = l[1:]
                return key, value

        def resolve_key(keyspops: tuple) -> tuple:
            """
            Convert a processed key into a ``(quantity, pop_name)`` pair. A key with two populations is stored under
            the quantity name ``<quantity>_from_<first pop>`` with the second population as ``pop_name``.
            """
            if len(keyspops) == 2:
                return keyspops
            assert len(keyspops) == 3, f"Number of populations must be 0, 1 or 2."
            return f"{keyspops[0]}_from_{keyspops[1]}", keyspops[2]

        def process_inputs(inputs, defaults: dict) -> dict:
            """
            Process adjustables and measurables, which can be specified as a string, list or nested dict representation.
            * In string representation, only the parameter name (with optional population in brackets) is specified,
            and the default settings are used e.g. "(b_rate, 0-4), d_rate".
            * In list representation, the input is a list of lists, where the first item in each list is the parameter
            (with optional population) and the remaining items are the supported arguments for the input type, in the
            order defined by the defaults dictionary.
            * In dict representation, the key is the quantity with optional population, and the value can either be
            a list (in the order defined by the dictionary) or a dictionary explicitly naming the inputs.
            This function returns a flat dictionary with {(quantity, pop_name):{argument:value}} e.g.,
            {('b_rate','0-5'):{'lower_bound':0.5}}. Only explicitly provided arguments are stored; missing arguments are
            filled in by `apply`, so that section-level settings and settings given for all populations
            (pop_name=None) are not overwritten by the class defaults.
            In the dict representation, the key can be a comma separated list of quantities with optional populations
            in brackets e.g., '(b_rate, 0-5), d_rate'.
            In the list representation, multiple quantities are not supported (as a comma is already used to separate
            the arguments), but multiple lists (one for each quantity) can be provided.
            """
            out = {}

            def add(keyspops, value):
                # Merge settings for the same (quantity, pop_name) so that later entries override earlier ones
                key = resolve_key(keyspops)
                out[key] = sc.mergedicts(out.get(key, {}), value)

            if sc.isstring(inputs):
                # Support a comma separated string of quantities (with optional populations in brackets).
                # Use separate_keys rather than a plain split so that bracketed groups stay together
                inputs = separate_keys(inputs)
            if isinstance(inputs, (tuple, list)):
                for l in inputs:
                    keyspops, v = process_list(sc.promotetolist(l))
                    add(keyspops, {} if v is None else self.parse_list(v, defaults))
            elif isinstance(inputs, dict):
                for keys, v in inputs.items():
                    for key in separate_keys(keys):
                        if isinstance(v, (tuple, list)):
                            value = self.parse_list(v, defaults)
                        elif v is None:
                            value = {}
                        else:
                            value = v.copy()
                        add(process_key(key.strip()), value)
            return out

        self["adjustables"] = process_inputs(self["adjustables"], self.adj_defaults)
        self["measurables"] = process_inputs(self["measurables"], self.meas_defaults)

        def check_optional_number(key, v, defaults, kind):
            if key in v and v[key] is not None and not sc.isnumber(v[key], isnan=False):
                raise TypeError(f"{kind} argument {key} needs to be a number or None (defaults to {defaults[key]}). Provided value: {v[key]} ")

        # Validate adjustables
        assert len(self["adjustables"]) > 0, f"Cannot calibrate with no adjustables for calibration section {self.name}"
        for (quantity, pop_name), v in self["adjustables"].items():
            assert "pop_name" not in v, f'Setting the population name through "pop_name: {v["pop_name"]}" is not supported. Instead, the name of the adjustable quantity should include the population name ("({quantity}, {v["pop_name"]})")'
            assert isinstance(quantity, str), f"Adjustable codename {quantity} needs to be a string"
            assert pop_name is None or isinstance(pop_name, str), f"Adjustable population {pop_name} needs to be a string or None (defaults to all populations for that parameter)"
            for key in ("lower_bound", "upper_bound", "starting_y_factor"):
                check_optional_number(key, v, self.adj_defaults, "Adjustable")

        # Validate measurables
        assert len(self["measurables"]) > 0, f"Cannot calibrate with no measurables for calibration section {self.name}"
        for (quantity, pop_name), v in self["measurables"].items():
            assert isinstance(quantity, str), f"Measurable codename {quantity} needs to be a string"
            assert pop_name is None or isinstance(pop_name, str), f"Measurable population {pop_name} needs to be a string or None (defaults to all populations for that parameter)"
            assert "metric" not in v or v["metric"] is None or isinstance(v["metric"], str), f"Measurable metric {v['metric']} needs to be a string or None (defaults to 'fractional')"
            for key in ("weight", "cal_start", "cal_end"):
                check_optional_number(key, v, self.meas_defaults, "Measurable")

    def apply(self, project: at.Project, parset: at.ParameterSet, n, *args, quiet=False, compare_results=False, **kwargs) -> ParameterSet:
        """
        Run ``at.calibrate`` for this node's adjustables and measurables

        :param project: Project to calibrate
        :param parset: ParameterSet to start the calibration from
        :param n: Unused. NOTE: ``run`` passes its save directory in this position
        :param args: Unused extra positional arguments
        :param quiet: If True, run the calibration inside ``at.Quiet(show_warnings=False)``
        :param compare_results: If True, log a comparison of the data with the model before and after calibration
        :param kwargs: Extra arguments for ``at.calibrate``. All other attributes of this node (except adjustables,
                       measurables and repeats) are forwarded too; ``kwargs`` take precedence over them
        :return: The calibrated ParameterSet (``parset`` itself if no adjustable exists in the project)
        """
        step_name = self.name
        attributes = self.attributes
        at.logger.info(f"Calibrating adjustable(s) {set([adj[0] for adj in attributes['adjustables']])} to match measurable(s) {set([mea[0] for mea in attributes['measurables']])}...")

        # Expand adjustables into (par_name, pop_name, lower_bound, upper_bound, starting_y_factor) tuples
        adjustables = {}
        par_names = {x[0] for x in attributes["adjustables"]}.intersection(x.name for x in parset.all_pars())
        pop_names = {x[1] for x in attributes["adjustables"]}.intersection({*parset.pop_names} | {"all", None})
        # Section-level settings (e.g. "lower_bound" defined in a parent section) override the class defaults
        adj_defaults = {k: attributes.get(k, default) for k, default in self.adj_defaults.items()}
        for par_name, pop_name in attributes["adjustables"]:
            if par_name not in par_names:
                at.logger.warning(f"Extra YAML adjustable parameter '{par_name}' does not exist in this project's framework/parset and will be ignored")
                continue
            elif pop_name not in pop_names:
                at.logger.warning(f"Extra YAML adjustable population '{pop_name}' does not exist in this project's databook and will be ignored")
                continue
            pops = parset.pop_names if pop_name is None else sc.promotetolist(pop_name)
            for pop in pops:
                # Population-specific settings override settings given for all populations, which override defaults
                d = sc.mergedicts(adj_defaults, attributes["adjustables"].get((par_name, None), None), attributes["adjustables"].get((par_name, pop), None))
                adjustables[(par_name, pop)] = (d["lower_bound"], d["upper_bound"], d["starting_y_factor"])
        adjustables = [(*k, *v) for k, v in adjustables.items()]

        # Expand measurables into (par_name, pop_name, weight, metric, cal_start, cal_end) tuples
        measurables = {}
        par_names = {x[0] for x in attributes["measurables"]}.intersection(x.name for x in parset.all_pars())  # TODO: This is probably OK for now but will need to validate that pars have databook entries in the future
        pop_names = {x[1] for x in attributes["measurables"]}.intersection({*parset.pop_names} | {None})
        meas_defaults = {k: attributes.get(k, default) for k, default in self.meas_defaults.items()}
        for par_name, pop_name in attributes["measurables"]:
            if par_name not in par_names:
                at.logger.warning(f"Extra YAML measurable variable '{par_name}' does not exist in this project's framework and will be ignored")
                continue
            elif pop_name not in pop_names:
                # A "total" population is allowed if the databook has a total row for this quantity
                if not (pop_name.lower() == "total" and pop_name.lower() in {x.lower() for x in project.data.tdve[par_name].ts.keys()}):
                    at.logger.warning(f"Extra YAML measurable population '{pop_name}' does not exist in this project's databook and will be ignored")
                    continue
            pops = parset.pop_names if pop_name is None else sc.promotetolist(pop_name)
            for pop in pops:
                d = sc.mergedicts(meas_defaults, attributes["measurables"].get((par_name, None), None), attributes["measurables"].get((par_name, pop), None))
                measurables[(par_name, pop)] = (d["weight"], d["metric"], d["cal_start"], d["cal_end"])
        measurables = [(*k, *v) for k, v in measurables.items()]

        # Calibration
        if len(adjustables):
            # note: attributes = instructions + context
            kwargs = sc.mergedicts(attributes, kwargs)
            del kwargs["adjustables"]  # supplied via the adjustables variable
            del kwargs["measurables"]  # supplied via the measurables variable
            kwargs.pop("repeats", None)  # handled by the node, not by at.calibrate
            if quiet:
                with at.Quiet(show_warnings=False):
                    new_cal_parset = at.calibrate(project, parset, adjustables, measurables, **kwargs)
            else:
                new_cal_parset = at.calibrate(project, parset, adjustables, measurables, **kwargs)
        else:
            new_cal_parset = parset

        at.logger.info(f'Completed "{step_name}"...')
        made_changes = False
        for par, pop, *_ in adjustables:
            if pop == "all":
                old = parset.get_par(par).meta_y_factor
                new = new_cal_parset.get_par(par).meta_y_factor
            else:
                old = parset.get_par(par).y_factor[pop]
                new = new_cal_parset.get_par(par).y_factor[pop]
            if new != old:
                at.logger.info(f"...adjusted the y-factor for {par} in {pop} from {old} to {new}")
                made_changes = True
            else:
                at.logger.debug(f"...did NOT adjust the y-factor for {par} in {pop} from {old} to {new}")
        if not made_changes:
            at.logger.info(f"...made no changes!")

        if compare_results:
            self._log_comparison(project, parset, new_cal_parset, measurables)
        return new_cal_parset

    @staticmethod
    def _log_comparison(project, parset, new_cal_parset, measurables):
        """
        Log the data, baseline and calibrated values for each measurable parameter, and the RMS errors

        For every population with data for a measurable parameter, each data point is compared with the model value
        at the same time. Data times that do not appear in the simulation output are skipped (logged at debug level).

        :param project: Project used to run the simulations
        :param parset: Parset before calibration (its parameter time series supply the data values)
        :param new_cal_parset: Calibrated parset
        :param measurables: Expanded measurables, whose first element is the parameter name
        """
        base_res = project.run_sim(parset=parset)
        cal_res = project.run_sim(parset=new_cal_parset)
        for par_name in dict.fromkeys(m[0] for m in measurables):  # Each parameter once, in order
            base_sq_error = 0
            cal_sq_error = 0
            n_points = 0
            for pop in parset.pars[par_name].ts.keys():
                data = parset.pars[par_name].ts[pop]
                base_var = base_res.get_variable(par_name, pop)[0]
                cal_var = cal_res.get_variable(par_name, pop)[0]
                base_t = list(base_var.t)
                cal_t = list(cal_var.t)
                for time_value, data_time_val in zip(data.t, data.vals):
                    if time_value not in base_t or time_value not in cal_t:
                        at.logger.debug(f"...skipping {par_name} in {pop} at time {time_value} as it is not part of the simulation output")
                        continue
                    base_time_val = base_var.vals[base_t.index(time_value)]
                    cal_time_val = cal_var.vals[cal_t.index(time_value)]
                    base_sq_error += (data_time_val - base_time_val) ** 2
                    cal_sq_error += (data_time_val - cal_time_val) ** 2
                    n_points += 1
                    sf = at.get_sigfigs_necessary(base_time_val, cal_time_val)
                    at.logger.info(f"...for parameter {par_name} and population {pop} at time {time_value} the data value was {sc.sigfig(data_time_val, sf)}, the baseline value was {sc.sigfig(base_time_val, sf)}, and the calibrated value was {sc.sigfig(cal_time_val, sf)}.")
            if n_points == 0:
                at.logger.info(f"...no data points for parameter {par_name} coincide with the simulation output, so the RMS error was not computed")
                continue
            base_rms_error = (base_sq_error / n_points) ** 0.5
            cal_rms_error = (cal_sq_error / n_points) ** 0.5
            sf = at.get_sigfigs_necessary(base_rms_error, cal_rms_error)
            at.logger.info(f"...RMS error for parameter {par_name} has changed from baseline {sc.sigfig(base_rms_error, sf)} to calibrated {sc.sigfig(cal_rms_error, sf)}")
 
[docs]
class InitializationNode(BaseNode):
    """
    Node that sets the parset's initialization from a model run ending at ``init_year``

    Accepted instructions (``set_initialization: ...`` in the YAML file):

    - a number: the initialization year
    - a list/tuple: ``[init_year]`` or ``[init_year, constant_parset]``
    - a dict with an ``init_year`` key and an optional ``constant_parset`` key
    """

    _name = "set_initialization"

    def __init__(self, instructions, context, name):
        new_instructions = {}
        if isinstance(instructions, dict):
            new_instructions.update(instructions)
        elif isinstance(instructions, (int, float)) and not isinstance(instructions, bool):
            # A bare number is the initialization year (booleans are not valid years)
            new_instructions["init_year"] = instructions
        elif isinstance(instructions, (tuple, list)) and instructions:
            new_instructions["init_year"] = instructions[0]
            if len(instructions) > 1:
                new_instructions["constant_parset"] = instructions[1]
        # Anything else leaves init_year unset, which validate() reports
        super().__init__(new_instructions, context, name)

    def validate(self):
        assert "init_year" in self, f"Initialization year must be specified"
        assert sc.isnumber(self["init_year"]), f'Initialization year {self["init_year"]} must be numeric.'
        if "constant_parset" in self:
            assert isinstance(self["constant_parset"], int), f'Constant parset (optional) {self["constant_parset"]} must be numeric (boolean, or int to specify constant parset year).'

    def apply(self, project: at.Project, parset: at.ParameterSet, n, *args, **kwargs) -> ParameterSet:
        """
        Run the model up to ``init_year`` and store the resulting initialization in ``parset``

        If ``constant_parset`` is True, the initialization run uses a constant copy of the parset taken at the
        simulation start year. If it is another number, that number is used as the year. Otherwise a copy of
        ``parset`` is used as-is.

        :param project: Project providing the settings and framework for the initialization run
        :param parset: ParameterSet to initialize. It is modified in place and returned
        :param n: Unused. NOTE: ``run`` passes its save directory in this position
        :return: ``parset`` with its initialization set
        """
        constant_parset = self["constant_parset"] if "constant_parset" in self else False
        if constant_parset == False:
            p2 = sc.dcp(parset)
        elif constant_parset == True:
            p2 = parset.make_constant(year=project.settings.sim_start)
        elif sc.isnumber(constant_parset):  # constant parset year was provided
            p2 = parset.make_constant(year=constant_parset)
        else:
            p2 = sc.dcp(parset)

        # Run the model only up to the initialization year
        new_settings = sc.dcp(project.settings)
        new_settings.update_time_vector(end=self["init_year"])
        res = at.run_model(settings=new_settings, framework=project.framework, parset=p2)
        parset.set_initialization(res, self["init_year"])
        return parset
 
[docs]
class ClearInitializationNode(BaseNode):
    """
    Node that removes any initialization stored in the parset (``clear_initialization`` in the YAML file)

    Any instructions given for this step are ignored.
    """

    _name = "clear_initialization"

    def __init__(self, instructions, context, name):
        super().__init__(instructions=None, context=context, name=name)

    def apply(self, project: at.Project, parset: at.ParameterSet, n, *args, **kwargs) -> ParameterSet:
        """
        Set ``parset.initialization`` to None

        :param project: Unused
        :param parset: ParameterSet to modify in place
        :param n: Unused. NOTE: ``run`` passes its save directory in this position
        :return: ``parset`` without an initialization
        """
        parset.initialization = None
        return parset
 
[docs]
class SaveCalibrationNode(BaseNode):
    """
    Block in YAML file with "save calibration: <file name>"

    Saves the current calibration with ``parset.save_calibration`` to ``<savedir>/<file name>``. The file name
    can also be given as a dict with an ``fname`` key.
    """

    _name = "save_calibration"

    def __init__(self, instructions, context, name):
        if not isinstance(instructions, dict):
            instructions = {"fname": instructions}
        super().__init__(instructions, context, name)

    def validate(self):
        # Check presence first so that a missing fname gives the assertion message rather than a KeyError
        assert "fname" in self and self["fname"] is not None, 'A "save calibration" node must have a file name explicitly specified'

    def apply(self, project: at.Project, parset: at.ParameterSet, savedir=None, *args, **kwargs) -> ParameterSet:
        """
        Save the calibration of ``parset``

        :param project: Unused
        :param parset: ParameterSet whose calibration is saved
        :param savedir: Directory to save into (str or Path). Defaults to the current working directory
        :return: ``parset``, unchanged
        """
        savedir = Path(".") if savedir is None else Path(savedir)
        parset.save_calibration(savedir / str(self["fname"]))
        return parset