Source code for pycalphad.core.workspace

import warnings
from collections import OrderedDict, Counter, defaultdict
from copy import copy
from pycalphad.property_framework.computed_property import JanssonDerivative
import pycalphad.variables as v
from pycalphad.core.utils import unpack_species, unpack_condition, unpack_phases, filter_phases, instantiate_models
from pycalphad import calculate
from pycalphad.core.starting_point import starting_point
from pycalphad.codegen.phase_record_factory import PhaseRecordFactory
from pycalphad.core.eqsolver import _solve_eq_at_conditions
from pycalphad.core.composition_set import CompositionSet
from pycalphad.core.solver import Solver, SolverBase
from pycalphad.core.light_dataset import LightDataset
from pycalphad.model import Model
import numpy as np
import numpy.typing as npt
from typing import Optional, Tuple, Type
from pycalphad.io.database import Database
from pycalphad.variables import Species, StateVariable
from pycalphad.core.conditions import Conditions, ConditionError
from pycalphad.property_framework import ComputableProperty, as_property
from pycalphad.property_framework.units import unit_conversion_context, ureg, as_quantity, Q_
from runtype import isa
from runtype.pytypes import Dict, List, Sequence, SumType, Mapping, NoneType
from typing import TypeVar



def _adjust_conditions(conds) -> OrderedDict[StateVariable, List[float]]:
    "Adjust conditions values to be in the implementation units of the quantity, and within the numerical limit of the solver."
    new_conds = OrderedDict()
    minimum_composition = 1e-10
    for key, value in sorted(conds.items(), key=str):
        key = as_property(key)
        # If conditions have units, convert to impl units and strip them
        if isinstance(value, Q_):
            value = value.to(key.implementation_units).magnitude
        if isinstance(key, v.MoleFraction):
            vals = unpack_condition(value)
            # "Zero" composition is a common pattern. Do not warn for that case.
            if np.any(np.logical_and(np.asarray(vals) < minimum_composition, np.asarray(vals) > 0)):
                warnings.warn(
                    f"Some specified compositions are below the minimum allowed composition of {minimum_composition}.")
            new_conds[key] = [min(max(val, minimum_composition), 1-minimum_composition) for val in vals]
        else:
            new_conds[key] = unpack_condition(value)
        if getattr(key, 'display_units', '') != '':
            new_conds[key] = Q_(new_conds[key], units=key.display_units).to(key.implementation_units)
    return new_conds
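
# Illustrative sketch (not part of the original module): _adjust_conditions
# clamps nominally "zero" mole fractions to the solver minimum and converts
# unit-bearing conditions to implementation units. The condition values below
# are hypothetical.
def _example_adjust_conditions():
    conds = {v.X('ZN'): 0.0, v.T: 600.0, v.P: 101325.0, v.N: 1.0}
    adjusted = _adjust_conditions(conds)
    # X(ZN) comes back as ~1e-10 rather than exactly 0, so the solver never
    # sees a strictly zero composition; T, P, and N are returned in the
    # implementation units of each state variable.
    return adjusted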

class ComponentList:
    @classmethod
    def cast_from(cls, s: Sequence[SumType([str, v.Component])]) -> "ComponentList":
        # no Database, so we don't get Species lookup support here, it's implemented in ComponentsField
        return v.unpack_components(s)


class ConstituentsList:
    @classmethod
    def cast_from(cls, s: Sequence) -> "ConstituentsList":
        return sorted(Species.cast_from(x) for x in s)


class PhaseList:
    @classmethod
    def cast_from(cls, s: SumType([str, Sequence[str]])) -> "PhaseList":
        if isinstance(s, str):
            s = [s]
        return sorted(PhaseName.cast_from(x) for x in s)


class PhaseName:
    @classmethod
    def cast_from(cls, s: str) -> "PhaseName":
        return s.upper()


class ConditionValue:
    @classmethod
    def cast_from(cls, value: SumType([float, Sequence[float]])) -> "ConditionValue":
        return unpack_condition(value)


class ConditionKey:
    @classmethod
    def cast_from(cls, key: SumType([str, StateVariable])) -> "ConditionKey":
        return as_property(key)
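

# Illustrative sketch (not part of the original module): these small classes
# exist only to give TypedField (defined below) a ``cast_from`` hook for type
# coercion. For example, PhaseList normalizes a single phase name or a
# sequence of names into a sorted, upper-case list.
def _example_cast_from():
    single = PhaseList.cast_from('liquid')                # ['LIQUID']
    several = PhaseList.cast_from(['fcc_a1', 'LIQUID'])   # ['FCC_A1', 'LIQUID']
    return single, several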


class TypedField:
    """
    A descriptor for managing attributes with specific types in a class,
    supporting automatic type coercion and default values.

    This class is designed to be used in scenarios (like `Workspace`) where one
    needs to implement an observer pattern. It enables the tracking of changes
    in attribute values and notifies dependent attributes of any updates.
    """
    def __init__(self, default_factory=None, depends_on=None):
        """
        Attributes
        ----------
        default_factory : callable, optional
            A callable that returns the default value of the attribute when no
            initial value is provided.
        depends_on : list of str, optional
            A list of attribute names, from the parent object, that the current
            attribute depends on. Changes to these attributes will trigger
            updates to the current attribute.
        """
        self.default_factory = default_factory
        self.depends_on = depends_on

    def __set_name__(self, owner, name):
        "Initializes the attribute, determining its private and public names and registering dependency callbacks if necessary."
        self.type = owner.__annotations__.get(name, None)
        self.public_name = name
        self.private_name = '_' + name
        if self.depends_on is not None:
            for dependency in self.depends_on:
                owner._callbacks[dependency].append(self.on_dependency_update)

    def __set__(self, obj, value):
        "Sets the value of the attribute on an object, handling type coercion via the `cast_from` method if direct assignment isn't possible. Raises `TypeError` if coercion fails."
        if (self.type != NoneType) and not isa(value, self.type) and value is not None:
            value = self.type.cast_from(value)
        elif value is None and self.default_factory is not None:
            value = self.default_factory(obj)
        oldval = getattr(obj, self.private_name, None)
        setattr(obj, self.private_name, value)
        for cb in obj._callbacks[self.public_name]:
            cb(obj, self.public_name, oldval, value)

    def __get__(self, obj, objtype=None):
        "Retrieves the value of the attribute, initializing it with default_factory if it hasn't been set before."
        if not hasattr(obj, self.private_name):
            if self.default_factory is not None:
                default_value = self.default_factory(obj)
                setattr(obj, self.private_name, default_value)
        return getattr(obj, self.private_name)
    def on_dependency_update(self, obj, updated_attribute, old_val, new_val):
        "A callback method that can be overridden to define custom behavior when a dependent attribute is updated."
        if obj._suspend_dependency_updates:
            return
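

# Illustrative sketch (not part of the original module): how a TypedField
# coerces assigned values through the annotated type's ``cast_from`` and fills
# in a default lazily on first access. ``_ExampleHolder`` is hypothetical and
# exists only for illustration.
class _ExampleHolder:
    _callbacks = defaultdict(lambda: [])
    _suspend_dependency_updates = False

    phase: PhaseName = TypedField(default_factory=lambda obj: 'LIQUID')

# holder = _ExampleHolder()
# holder.phase             # -> 'LIQUID' (from default_factory)
# holder.phase = 'fcc_a1'  # not a PhaseName -> PhaseName.cast_from -> 'FCC_A1'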


class ComponentsField(TypedField):
    def __init__(self, depends_on=None):
        get_pure_element_components = lambda obj: v.unpack_components(sorted(x for x in obj.database.elements if x != '/-'), obj.database)
        super().__init__(default_factory=get_pure_element_components, depends_on=depends_on)

    def __set__(self, obj, value):
        comps = sorted(v.unpack_components(value, obj.database))
        super().__set__(obj, comps)

    def __get__(self, obj, objtype=None):
        getobj = super().__get__(obj, objtype=objtype)
        return sorted(v.unpack_components(getobj, obj.database))

    def on_dependency_update(self, obj, updated_attribute, old_val, new_val):
        if obj._suspend_dependency_updates:
            return
        self.__set__(obj, self.default_factory(obj))


class ConstituentsField(TypedField):
    def __init__(self, depends_on=None):
        super().__init__(default_factory=lambda obj: unpack_species(obj.database,
                                                                    sorted(x.name for x in obj.database.species if x.name != '/-')),
                         depends_on=depends_on)

    def __set__(self, obj, value):
        constituents = sorted(unpack_species(obj.database, value))
        super().__set__(obj, constituents)

    def __get__(self, obj, objtype=None):
        getobj = super().__get__(obj, objtype=objtype)
        return sorted(unpack_species(obj.database, getobj))

    def on_dependency_update(self, obj, updated_attribute, old_val, new_val):
        if obj._suspend_dependency_updates:
            return
        self.__set__(obj, unpack_species(obj.database, obj.components))


class PhasesField(TypedField):
    def __init__(self, depends_on=None):
        super().__init__(default_factory=lambda obj: filter_phases(obj.database, obj.components), depends_on=depends_on)

    def __set__(self, obj, value):
        phases = sorted(unpack_phases(value))
        super().__set__(obj, phases)

    def __get__(self, obj, objtype=None):
        getobj = super().__get__(obj, objtype=objtype)
        return filter_phases(obj.database, obj.components, getobj)


class DictField(TypedField):
    def get_proxy(self, obj):
        class DictProxy:
            @staticmethod
            def unwrap():
                return TypedField.__get__(self, obj)

            def __getattr__(pxy, name):
                getobj = TypedField.__get__(self, obj)
                if getobj == pxy:
                    raise ValueError('Proxy object points to itself')
                return getattr(getobj, name)

            def __getitem__(pxy, item):
                return TypedField.__get__(self, obj).get(item)

            def __iter__(pxy):
                return TypedField.__get__(self, obj).__iter__()

            def __setitem__(pxy, item, value):
                # we are careful not to mutate the original, by making a copy; if we don't, it will break cache invalidation
                conds = copy(TypedField.__get__(self, obj))
                conds[item] = value
                self.__set__(obj, conds)

            def update(pxy, new_conds):
                # we are careful not to mutate the original, by making a copy; if we don't, it will break cache invalidation
                conds = copy(TypedField.__get__(self, obj))
                conds.update(new_conds)
                self.__set__(obj, conds)

            def __delitem__(pxy, item):
                # we are careful not to mutate the original, by making a copy; if we don't, it will break cache invalidation
                conds = copy(TypedField.__get__(self, obj))
                del conds[item]
                self.__set__(obj, conds)

            def __len__(pxy):
                return len(TypedField.__get__(self, obj))

            def __str__(pxy):
                return str(TypedField.__get__(self, obj))

            def __repr__(pxy):
                return repr(TypedField.__get__(self, obj))
        return DictProxy()

    def __get__(self, obj, objtype=None):
        return self.get_proxy(obj)
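

# Illustrative sketch (not part of the original module): the proxy returned by
# DictField copies the underlying dict on every mutation, so each change is
# routed back through ``__set__`` and fires the owner's dependency callbacks,
# which is what keeps cached results consistent. ``_ExampleParams`` is
# hypothetical; it mirrors the ``parameters``/``calc_opts`` declarations below.
class _ExampleParams:
    _callbacks = defaultdict(lambda: [])
    _suspend_dependency_updates = False

    options: SumType([NoneType, Dict]) = DictField(lambda _: OrderedDict())

# holder = _ExampleParams()
# holder.options['pdens'] = 500   # proxy copies, updates, and re-assigns the dict
# dict(holder.options.unwrap())   # -> {'pdens': 500}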


class ConditionsField(DictField):
    def __set__(self, obj, value):
        conds = Conditions(obj)
        for k, v in value.items():
            conds[k] = v
        super().__set__(obj, conds)


class ModelsField(DictField):
    def __init__(self, depends_on=None):
        super().__init__(default_factory=lambda obj: instantiate_models(obj.database, obj.components, obj.phases,
                                                                        model=None, parameters=obj.parameters),
                         depends_on=depends_on)

    def __set__(self, obj, value):
        # Unwrap proxy objects before being stored
        if hasattr(value, 'unwrap'):
            value = value.unwrap()
        try:
            # Expand specified Model type into a dict of instances
            value = instantiate_models(obj.database, obj.components, obj.phases,
                                       model=value, parameters=obj.parameters)
            super().__set__(obj, value)
        except AttributeError:
            super().__set__(obj, None)

    def on_dependency_update(self, obj, updated_attribute, old_val, new_val):
        if obj._suspend_dependency_updates:
            return
        self.__set__(obj, self.default_factory(obj))


class PRFField(TypedField):
    def __init__(self, depends_on=None):
        def make_prf(obj):
            try:
                prf = PhaseRecordFactory(obj.database, obj.components, obj.conditions,
                                         obj.models.unwrap() if hasattr(obj.models, 'unwrap') else obj.models,
                                         parameters=obj.parameters)
                return prf
            except AttributeError:
                return None
        super().__init__(default_factory=make_prf, depends_on=depends_on)

    def on_dependency_update(self, obj, updated_attribute, old_val, new_val):
        if obj._suspend_dependency_updates:
            return
        # changes in conditions values (as opposed to keys) do not affect the PhaseRecordFactory
        if updated_attribute == 'conditions' and (old_val is not None) and \
                (list(old_val.keys()) == list(new_val.keys())):
            return
        self.__set__(obj, self.default_factory(obj))


class SolverField(TypedField):
    def on_dependency_update(self, obj, updated_attribute, old_val, new_val):
        if obj._suspend_dependency_updates:
            return
        self.__set__(obj, self.default_factory(obj))
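

# Minimal sketch (not part of the original module) of the observer wiring these
# fields provide: assigning to an attribute that another field ``depends_on``
# causes that field to be rebuilt from its default_factory. ``_ExampleObserved``
# is hypothetical; it mirrors how Workspace declares its ``solver`` field.
class _ExampleObserved:
    _callbacks = defaultdict(lambda: [])
    _suspend_dependency_updates = False

    verbose: bool = TypedField(default_factory=lambda obj: False)
    solver: SolverBase = SolverField(default_factory=lambda obj: Solver(verbose=obj.verbose),
                                     depends_on=['verbose'])

# obj = _ExampleObserved()
# obj.solver          # Solver built with verbose=False
# obj.verbose = True  # SolverField.on_dependency_update rebuilds obj.solver
#                     # with verbose=True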


class EquilibriumCalculationField(TypedField):
    def __get__(self, obj, objtype=None):
        if (not hasattr(obj, self.private_name)) or (getattr(obj, self.private_name) is None):
            try:
                default_value = obj.recompute()
            except AttributeError:
                default_value = None
            setattr(obj, self.private_name, default_value)
        return getattr(obj, self.private_name)

    def on_dependency_update(self, obj, updated_attribute, old_val, new_val):
        if obj._suspend_dependency_updates:
            return
        self.__set__(obj, None)


# Defined to allow type checking for Model or its subclasses
ModelType = TypeVar('ModelType', bound=Model)

# TODO: enable converting v.X conditions for v.Component objects into
# LinearCombination conditions for components with more than one constituent


class Workspace:
    _callbacks = defaultdict(lambda: [])
    database: Database = TypedField(lambda _: None)
    components: ComponentList = ComponentsField(depends_on=['database'])
    constituents: ConstituentsList = ConstituentsField(depends_on=['database', 'components'])
    phases: PhaseList = PhasesField(depends_on=['database', 'components'])
    conditions: Conditions = ConditionsField(lambda wks: Conditions(wks), depends_on=['components', 'models'])
    verbose: bool = TypedField(lambda _: False)
    models: Mapping[PhaseName, ModelType] = ModelsField(depends_on=['phases', 'parameters'])
    parameters: SumType([NoneType, Dict]) = DictField(lambda _: OrderedDict())
    phase_record_factory: Optional[PhaseRecordFactory] = PRFField(depends_on=['phases', 'conditions', 'models', 'parameters'])
    calc_opts: SumType([NoneType, Dict]) = DictField(lambda _: OrderedDict())
    solver: SolverBase = SolverField(lambda obj: Solver(verbose=obj.verbose), depends_on=['verbose'])
    # eq is set by a callback in the EquilibriumCalculationField (TypedField)
    eq: Optional[LightDataset] = EquilibriumCalculationField(depends_on=['phase_record_factory', 'conditions', 'calc_opts', 'solver'])

    def __init__(self, *args, **kwargs):
        self._suspend_dependency_updates = True
        self._eq = None  # manually initialized since we don't initialize the public name 'eq' (see below)
        # Assume positional arguments are specified in class typed-attribute definition order
        for arg, attrname in zip(args, ['database', 'components', 'phases', 'conditions']):
            kwargs[attrname] = arg
        attributes = list(self.__annotations__.keys())
        # avoid unnecessary work by initializing in a graph-optimal order (least-dependent first)
        # don't include 'eq' in the init order to avoid an expensive code path in the partially initialized case
        init_order = ['database', 'verbose', 'parameters', 'calc_opts', 'solver', 'components',
                      'constituents', 'phases', 'models', 'conditions', 'phase_record_factory']
        for kwarg_name in init_order:
            if kwarg_name in kwargs.keys():
                setattr(self, kwarg_name, kwargs[kwarg_name])
            else:
                # trigger default constructor (which is allowed to fail)
                getattr(self, kwarg_name)
        for kwarg_name, kwarg_val in kwargs.items():
            if kwarg_name in init_order:
                continue
            if kwarg_name not in attributes:
                raise ValueError(f'{kwarg_name} is not a Workspace attribute')
            setattr(self, kwarg_name, kwarg_val)
        self._suspend_dependency_updates = False

    def recompute(self):
        # Assumes implementation units from this point
        unitless_conds = OrderedDict((key, as_quantity(key, value).to(key.implementation_units).magnitude)
                                     for key, value in self.conditions.items())
        str_conds = OrderedDict((str(key), value) for key, value in unitless_conds.items())
        local_conds = {key: as_quantity(key, value).to(key.implementation_units).magnitude
                       for key, value in self.conditions.items()
                       if getattr(key, 'phase_name', None) is not None}
        state_variables = self.phase_record_factory.state_variables
        self.phase_record_factory.update_parameters(self.parameters.unwrap())

        # 'calculate' accepts conditions through its keyword arguments
        grid_opts = self.calc_opts.copy()
        statevar_strings = [str(x) for x in state_variables]
        grid_opts.update({key: value for key, value in str_conds.items() if key in statevar_strings})
        if 'pdens' not in grid_opts:
            grid_opts['pdens'] = 60

        grid = calculate(self.database, self.components, self.phases, model=self.models.unwrap(),
                         fake_points=True, phase_records=self.phase_record_factory, output='GM',
                         parameters=self.parameters.unwrap(), to_xarray=False, conditions=local_conds,
                         **grid_opts)
        properties = starting_point(unitless_conds, state_variables, self.phase_record_factory, grid)
        return _solve_eq_at_conditions(properties, self.phase_record_factory, grid,
                                       list(unitless_conds.keys()), state_variables,
                                       self.verbose, solver=self.solver)

    def _detect_phase_multiplicity(self):
        multiplicity = {k: 0 for k in sorted(self.phase_record_factory.keys())}
        prop_GM_values = self.eq.GM
        prop_Phase_values = self.eq.Phase
        for index in np.ndindex(prop_GM_values.shape):
            cur_multiplicity = Counter()
            for phase_name in prop_Phase_values[index]:
                if phase_name == '' or phase_name == '_FAKE_':
                    continue
                cur_multiplicity[phase_name] += 1
            for key, value in cur_multiplicity.items():
                multiplicity[key] = max(multiplicity[key], value)
        return multiplicity

    def _expand_property_arguments(self, args: Sequence[ComputableProperty]):
        "Mutates args"
        multiplicity = self._detect_phase_multiplicity()
        indices_to_delete = []
        i = 0
        while i < len(args):
            if hasattr(args[i], 'phase_name') and args[i].phase_name == '*':
                indices_to_delete.append(i)
                phase_names = sorted(self.phase_record_factory.keys())
                additional_args = args[i].expand_wildcard(phase_names=phase_names)
                args.extend(additional_args)
            elif hasattr(args[i], 'sublattice_index') and args[i].sublattice_index == '*':
                # We need to resolve sublattice_index before species to ensure we
                # get the correct set of phase constituents for each sublattice
                indices_to_delete.append(i)
                sublattice_indices = sorted(set([x.sublattice_index
                                                 for x in self.phase_record_factory[args[i].phase_name].variables]))
                additional_args = args[i].expand_wildcard(sublattice_indices=sublattice_indices)
                args.extend(additional_args)
            elif hasattr(args[i], 'species') and args[i].species.name == '*':
                indices_to_delete.append(i)
                internal_to_phase = hasattr(args[i], 'sublattice_index')
                if internal_to_phase:
                    components = [x.species for x in self.phase_record_factory[args[i].phase_name].variables
                                  if x.sublattice_index == args[i].sublattice_index]
                else:
                    components = [comp for comp in self.components
                                  if comp != v.Component('VA')]  # TODO: Special case for vacancy
                additional_args = args[i].expand_wildcard(components=components)
                args.extend(additional_args)
            elif isinstance(args[i], JanssonDerivative):
                numerator_args = [args[i].numerator]
                self._expand_property_arguments(numerator_args)
                denominator_args = [args[i].denominator]
                self._expand_property_arguments(denominator_args)
                if (len(numerator_args) > 1) or (len(denominator_args) > 1):
                    for n_arg in numerator_args:
                        for d_arg in denominator_args:
                            args.append(JanssonDerivative(n_arg, d_arg))
                    indices_to_delete.append(i)
            else:
                # This is a concrete ComputableProperty
                if hasattr(args[i], 'phase_name') and (args[i].phase_name is not None) \
                        and not ('#' in args[i].phase_name) and multiplicity[args[i].phase_name] > 1:
                    # Miscibility gap detected; expand property into multiple composition sets
                    additional_phase_names = [args[i].phase_name + '#' + str(multi_idx + 1)
                                              for multi_idx in range(multiplicity[args[i].phase_name])]
                    indices_to_delete.append(i)
                    additional_args = args[i].expand_wildcard(phase_names=additional_phase_names)
                    args.extend(additional_args)
            i += 1

        # Watch deletion order! Indices will change as items are deleted
        for deletion_index in reversed(indices_to_delete):
            del args[deletion_index]

    @property
    def ndim(self) -> int:
        _ndim = 0
        for cond_val in self.conditions.values():
            if len(cond_val) > 1:
                _ndim += 1
        return _ndim

    def enumerate_composition_sets(self):
        if self.eq is None:
            return
        prop_GM_values = self.eq.GM
        prop_Y_values = self.eq.Y
        prop_NP_values = self.eq.NP
        prop_Phase_values = self.eq.Phase
        conds_keys = [str(k) for k in self.eq.coords.keys() if k not in ('vertex', 'component', 'internal_dof')]
        state_variables = list(self.phase_record_factory.values())[0].state_variables
        str_state_variables = [str(k) for k in state_variables]

        for index in np.ndindex(prop_GM_values.shape):
            cur_conds = OrderedDict(zip(conds_keys,
                                        [np.asarray(self.eq.coords[b][a], dtype=np.float64)
                                         for a, b in zip(index, conds_keys)]))
            state_variable_values = [cur_conds[key] for key in str_state_variables]
            state_variable_values = np.array(state_variable_values)
            composition_sets = []
            for phase_idx, phase_name in enumerate(prop_Phase_values[index]):
                if phase_name == '' or phase_name == '_FAKE_':
                    continue
                # phase_name can be a numpy.str_, which is different from the builtin str
                phase_record = self.phase_record_factory[str(phase_name)]
                sfx = prop_Y_values[index + np.index_exp[phase_idx, :phase_record.phase_dof]]
                phase_amt = prop_NP_values[index + np.index_exp[phase_idx]]
                compset = CompositionSet(phase_record)
                compset.update(sfx, phase_amt, state_variable_values)
                composition_sets.append(compset)
            yield index, composition_sets

    def get_composition_sets(self):
        if self.ndim != 0:
            raise ConditionError('get_composition_sets() can only be used for point (0-D) calculations. Use enumerate_composition_sets() instead.')
        return next(self.enumerate_composition_sets())[1]

    @property
    def condition_axis_order(self):
        str_conds_keys = [str(k) for k in self.eq.coords.keys() if k not in ('vertex', 'component', 'internal_dof')]
        conds_keys = [None] * len(str_conds_keys)
        for k in self.conditions.keys():
            cond_idx = str_conds_keys.index(str(k))
            # unit-length dimensions will be 'squeezed' out
            if len(self.eq.coords[str(k)]) > 1:
                conds_keys[cond_idx] = k
        return [c for c in conds_keys if c is not None]

    def get_dict(self, *args: Tuple[ComputableProperty]):
        args = list(map(as_property, args))
        self._expand_property_arguments(args)
        arg_units = {arg: (ureg.Unit(getattr(arg, 'implementation_units', '')),
                           ureg.Unit(getattr(arg, 'display_units', '')))
                     for arg in args}

        arr_size = self.eq.GM.size
        results = dict()
        prop_MU_values = self.eq.MU

        str_conds_keys = [str(k) for k in self.eq.coords.keys() if k not in ('vertex', 'component', 'internal_dof')]
        conds_keys = [None] * len(str_conds_keys)
        for k in self.conditions.keys():
            cond_idx = str_conds_keys.index(str(k))
            conds_keys[cond_idx] = k

        local_index = 0
        for index, composition_sets in self.enumerate_composition_sets():
            cur_conds = OrderedDict(zip(conds_keys,
                                        [np.asarray(self.eq.coords[b][a], dtype=np.float64)
                                         for a, b in zip(index, str_conds_keys)]))
            chemical_potentials = prop_MU_values[index]
            for arg in args:
                prop_implementation_units, prop_display_units = arg_units[arg]
                context = unit_conversion_context(composition_sets, arg)
                if results.get(arg, None) is None:
                    results[arg] = np.zeros((arr_size,) + arg.shape)
                results[arg][local_index, ...] = Q_(arg.compute_property(composition_sets, cur_conds, chemical_potentials),
                                                    prop_implementation_units).to(prop_display_units, context).magnitude
            local_index += 1

        # roll the dimensions of the property arrays back up
        conds_shape = tuple(len(self.eq.coords[str(b)]) for b in self.condition_axis_order)
        for arg in results.keys():
            results[arg] = results[arg].reshape(conds_shape + arg.shape)
        return results

    def get(self, *args: Tuple[ComputableProperty]):
        result = list(self.get_dict(*args).values())
        if len(result) != 1:
            return result
        else:
            # For single properties, just return the result without wrapping in a list
            return result[0]

    def copy(self):
        return copy(self)
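

# Illustrative end-to-end sketch (not part of the original module). The TDB
# file name, components, phases, and conditions below are hypothetical; the
# property calls (v.NP, v.MU) assume the properties exposed by the pycalphad
# property framework.
def _example_workspace_usage():
    dbf = Database('alzn_mey.tdb')  # hypothetical database file
    wks = Workspace(dbf, ['AL', 'ZN', 'VA'], ['FCC_A1', 'HCP_A3', 'LIQUID'],
                    {v.X('ZN'): 0.3, v.T: (300, 1000, 50), v.P: 101325, v.N: 1})
    # Reading computed properties triggers recompute() lazily through the
    # EquilibriumCalculationField; wildcard properties expand per stable phase.
    phase_fracs = wks.get(v.NP('*'))
    mu_zn = wks.get(v.MU('ZN'))
    # Mutating a condition through the ConditionsField proxy invalidates the
    # cached equilibrium. A single temperature makes this a point (0-D)
    # calculation, so composition sets can be pulled out directly.
    wks.conditions[v.T] = 600
    composition_sets = wks.get_composition_sets()
    return phase_fracs, mu_zn, composition_sets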