"""
TreeFarm class and member functions
"""
#-----------------------------------------------------------------------------
# Copyright (c) ytree development team. All rights reserved.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the file COPYING.txt, distributed with this software.
#-----------------------------------------------------------------------------
import numpy as np
import os
from yt.frontends.ytdata.utilities import \
save_as_dataset
from yt.units.yt_array import \
YTArray
from yt.utilities.parallel_tools.parallel_analysis_interface import \
_get_comm, \
parallel_objects
from treefarm.ancestry_checker import \
ancestry_checker_registry
from treefarm.ancestry_filter import \
ancestry_filter_registry
from treefarm.ancestry_short import \
ancestry_short_registry
from treefarm.halo_selector import \
selector_registry, \
clear_id_cache
from treefarm.utilities.funcs import \
ensure_dir, \
get_output_filename, \
is_sequence
from treefarm.utilities.io import \
yt_load
from treefarm.utilities.logger import \
get_pbar, \
set_parallel_logger, \
treefarmLogger as mylog
class TreeFarm(object):
    r"""
    TreeFarm is the merger-tree creator for Gadget FoF and Subfind
    halo catalogs.

    TreeFarm can be used to create a merger-tree for the full set of
    halos, starting from the first catalog, or can be used to trace the
    ancestry of specific halos, starting from the last catalog.  The
    merger-tree process will create a new set of halo catalogs,
    containing necessary fields (positions, velocities, masses),
    user-requested fields, and descendent IDs for each halo.  These
    halo catalogs can be loaded as yt datasets.

    Parameters
    ----------
    time_series : yt DatasetSeries object
        A yt time-series object containing the datasets over which
        the merger-tree will be calculated.
    setup_function : optional, callable
        A function that accepts a yt Dataset object and performs any
        setup, such as adding derived fields.

    Examples
    --------

    To create a full merger tree:

    >>> import numpy as np
    >>> import yt
    >>> import ytree
    >>> from treefarm import TreeFarm
    >>> ts = yt.DatasetSeries("data/groups_*/fof_subhalo_tab*.0.hdf5")
    >>> my_tree = TreeFarm(ts)
    >>> my_tree.trace_descendents("Group", filename="all_halos/")
    >>> a = ytree.load("all_halos/fof_subhalo_tab_000.0.h5")
    >>> m = a["particle_mass"]
    >>> i = np.argmax(m)
    >>> print(a.trees[i]["prog", "particle_mass"].to("Msun/h"))

    To create a merger tree for a specific halo or set of halos:

    >>> import numpy as np
    >>> import yt
    >>> import ytree
    >>> from treefarm import TreeFarm
    >>> ts = yt.DatasetSeries("data/groups_*/fof_subhalo_tab*.0.hdf5")
    >>> ds = ts[-1]
    >>> i_max = np.argmax(ds.r["Group", "particle_mass"].d)
    >>> my_ids = ds.r["Group", "particle_identifier"][i_max]
    >>> my_tree = TreeFarm(ts)
    >>> my_tree.set_ancestry_filter("most_massive")
    >>> my_tree.set_ancestry_short("above_mass_fraction", 0.5)
    >>> my_tree.trace_ancestors("Group", my_ids, filename="my_halos/")
    >>> a = ytree.load("my_halos/fof_subhalo_tab_025.0.h5")
    >>> print(a[0]["prog", "particle_mass"].to("Msun/h"))
    """

    def __init__(self, time_series, setup_function=None):
        self.ts = time_series
        self.setup_function = setup_function

        # set a default selector
        self.set_selector("all")

        # set a default ancestry checker
        self.set_ancestry_checker("common_ids")

        self.ancestry_filter = None
        self.ancestry_short = None

        self.comm = _get_comm(())
        set_parallel_logger(self.comm)

    def set_selector(self, selector, *args, **kwargs):
        r"""
        Set the method for selecting candidate halos for tracing
        halo ancestry.

        The default selector is "all", i.e., check every halo for a
        possible match.  This can be slow.  The "sphere" selector
        can be used to specify that only halos within some sphere
        be checked.

        Parameters
        ----------
        selector : string
            Name of selector.  Any additional positional and keyword
            arguments are passed to the selector registry lookup.
        """
        self.selector = selector_registry.find(selector, *args, **kwargs)

    def set_ancestry_checker(self, ancestry_checker, *args, **kwargs):
        r"""
        Set the method for determining if a halo is the ancestor of
        another halo.

        The default method defines an ancestor as a halo where at least
        50% of its particles are found in the descendent.

        Parameters
        ----------
        ancestry_checker : string
            Name of checking method.  Any additional positional and
            keyword arguments are passed to the registry lookup.
        """
        self.ancestry_checker = \
            ancestry_checker_registry.find(ancestry_checker, *args, **kwargs)

    def set_ancestry_filter(self, ancestry_filter, *args, **kwargs):
        r"""
        Select a method for determining which ancestors are kept.

        The kept ancestors will have their ancestries tracked.  This
        can be used to speed up merger-trees for targeted halos by
        specifying that only the most massive ancestor be kept.

        Parameters
        ----------
        ancestry_filter : string
            Name of filter method.  Any additional positional and
            keyword arguments are passed to the registry lookup.
        """
        self.ancestry_filter = \
            ancestry_filter_registry.find(ancestry_filter, *args, **kwargs)

    def set_ancestry_short(self, ancestry_short, *args, **kwargs):
        r"""
        Select a method for cutting short the ancestry search.

        This can be used to speed up merger-trees for targeted halos by
        specifying that the search come to an end when an ancestor with
        greater than 50% of the halo's mass has been found, thereby
        ensuring that the most massive halo has already been found.

        Parameters
        ----------
        ancestry_short : string
            Name of short-out method.  Any additional positional and
            keyword arguments are passed to the registry lookup.
        """
        self.ancestry_short = \
            ancestry_short_registry.find(ancestry_short, *args, **kwargs)

    @staticmethod
    def _catalog_filename(filename, fn, rank=0):
        """
        Return the output path of the merger-tree catalog for ``fn``
        (a path or dataset) as written by processor ``rank``.
        """
        return get_output_filename(
            filename, f"{_get_tree_basename(fn)}.{rank}", ".h5")

    def _prepare_output_dir(self, filename):
        """
        Create the directory part of ``filename`` on the root processor,
        if ``filename`` has one.
        """
        # filename is optional (None by default) and os.path.dirname(None)
        # raises TypeError, so a missing filename means "no directory".
        if filename is None:
            return
        output_dir = os.path.dirname(filename)
        if output_dir and self.comm.rank == 0:
            ensure_dir(output_dir)

    def _load_ds(self, filename, **kwargs):
        """
        Load a catalog as a yt dataset and call setup function.
        """
        ds = yt_load(filename, **kwargs)
        if self.setup_function is not None:
            self.setup_function(ds)
        return ds

    def _find_ancestors(self, hc, halo_type, ds2, id_store=None):
        """
        Search ``ds2`` for ancestors of the halo container ``hc``.

        Candidates come from the current selector.  A candidate is kept
        as an ancestor when the ancestry checker accepts
        (hc's member ids, candidate's member ids).  Each ancestor gets
        ``descendent_identifier`` set to ``hc.particle_identifier``.

        Parameters
        ----------
        hc : halo container
            The halo whose ancestors are sought.
        halo_type : string
            The halo type used to query "member_ids".
        ds2 : yt Dataset
            The (earlier) dataset to search.
        id_store : optional, list
            IDs of halos already claimed as ancestors.  Candidates in it
            are skipped, and the IDs of all ancestors found here (before
            filtering) are appended to it.

        Returns
        -------
        list of halo containers, after the ancestry filter (if any).
        """
        if id_store is None:
            id_store = []
        halo_member_ids = hc[halo_type, "member_ids"].d.astype(np.int64)
        candidate_ids = self.selector(hc, ds2)

        ancestors = []
        for candidate_id in candidate_ids:
            if candidate_id in id_store:
                continue

            candidate = ds2.halo(hc.ptype, candidate_id)
            candidate_member_ids = \
                candidate[halo_type, "member_ids"].d.astype(np.int64)
            if self.ancestry_checker(halo_member_ids, candidate_member_ids):
                candidate.descendent_identifier = hc.particle_identifier
                ancestors.append(candidate)
                # Optionally stop searching once this ancestor satisfies
                # the short-out criterion.
                if self.ancestry_short is not None and \
                        self.ancestry_short(hc, candidate):
                    break

        id_store.extend([ancestor.particle_identifier
                         for ancestor in ancestors])
        if self.ancestry_filter is not None:
            ancestors = self.ancestry_filter(hc, ancestors)

        return ancestors

    def _find_descendent(self, hc, halo_type, ds2):
        """
        Search ``ds2`` for the descendent of the halo container ``hc``.

        Sets ``hc.descendent_identifier`` to the ``particle_identifier``
        of the first candidate accepted by the ancestry checker, or to
        -1 if no candidate is accepted.
        """
        halo_member_ids = hc[halo_type, "member_ids"].d.astype(np.int64)
        candidate_ids = self.selector(hc, ds2)

        hc.descendent_identifier = -1
        for candidate_id in candidate_ids:
            candidate = ds2.halo(hc.ptype, candidate_id)
            candidate_member_ids = \
                candidate[halo_type, "member_ids"].d.astype(np.int64)
            # Argument order is reversed relative to _find_ancestors:
            # here the candidate plays the role of the descendent.
            if self.ancestry_checker(candidate_member_ids, halo_member_ids):
                hc.descendent_identifier = candidate.particle_identifier
                break

    def trace_ancestors(self, halo_type, root_ids,
                        fields=None, filename=None):
        """
        Trace the ancestry of a given set of halos.

        A merger-tree for a specific set of halos will be created,
        starting with the last halo catalog and moving backward.

        Parameters
        ----------
        halo_type : string
            The type of halo, typically "FOF" for FoF groups or
            "Subfind" for subhalos.
        root_ids : integer, or list/array of integers
            The halo IDs from the last halo catalog for the
            targeted halos.
        fields : optional, list of strings
            List of additional fields to be saved to halo catalogs.
        filename : optional, string
            Directory in which merger-tree catalogs will be saved.
        """
        self._prepare_output_dir(filename)

        all_outputs = self.ts.outputs[::-1]
        ds1 = None

        for i, fn2 in enumerate(all_outputs[1:]):
            fn1 = all_outputs[i]
            target_filename = self._catalog_filename(filename, fn1)
            catalog_filename = self._catalog_filename(filename, fn2)
            if os.path.exists(catalog_filename):
                # Linked by a previous run.  Drop any loaded dataset so
                # ds1 is reloaded from the correct file next iteration.
                ds1 = None
                continue
            if ds1 is None:
                ds1 = self._load_ds(fn1)
            ds2 = self._load_ds(fn2)

            if self.comm.rank == 0:
                _print_link_info(ds1, ds2)

            if _get_total_halos(ds2, halo_type) == 0:
                mylog.info("%s has no halos of type %s, ending.",
                           ds2, halo_type)
                break

            if i == 0:
                target_ids = root_ids
                if not is_sequence(target_ids):
                    target_ids = [target_ids]
                if isinstance(target_ids, YTArray):
                    target_ids = target_ids.d
                # np.asarray also accepts plain lists/tuples, which have
                # no .dtype attribute of their own.
                target_ids = \
                    np.asarray(target_ids).astype(np.int64, copy=False)
            else:
                mylog.info("Loading target ids from %s.", target_filename)
                ds_target = yt_load(target_filename)
                target_ids = \
                    ds_target.r["halos",
                                "particle_identifier"].d.astype(np.int64)
                del ds_target

            id_store = []
            target_halos = []
            ancestor_halos = []

            njobs = min(self.comm.size, target_ids.size)
            pbar = get_pbar("Linking halos",
                            target_ids.size, parallel=True)
            my_i = 0
            for halo_id in parallel_objects(target_ids, njobs=njobs):
                my_halo = ds1.halo(halo_type, halo_id)

                target_halos.append(my_halo)
                my_ancestors = self._find_ancestors(my_halo, halo_type, ds2,
                                                    id_store=id_store)
                ancestor_halos.extend(my_ancestors)
                my_i += njobs
                pbar.update(my_i)
            pbar.finish()

            if i == 0:
                # The root halos have no descendents.
                for halo in target_halos:
                    halo.descendent_identifier = -1
                self._save_catalog(filename, ds1, target_halos,
                                   halo_type, fields)
            self._save_catalog(filename, ds2, ancestor_halos,
                               halo_type, fields)

            # NOTE(review): ancestor_halos only holds this processor's
            # results, so in parallel runs processors may break at
            # different iterations -- confirm this cannot desynchronize
            # any collective operations.
            if len(ancestor_halos) == 0:
                break

            ds1 = ds2
            clear_id_cache()

    def trace_descendents(self, halo_type,
                          fields=None, filename=None):
        """
        Trace the descendents of all halos.

        A merger-tree for all halos will be created, starting
        with the first halo catalog and moving forward.

        Parameters
        ----------
        halo_type : string
            The type of halo, typically "FOF" for FoF groups or
            "Subfind" for subhalos.
        fields : optional, list of strings
            List of additional fields to be saved to halo catalogs.
        filename : optional, string
            Directory in which merger-tree catalogs will be saved.
        """
        self._prepare_output_dir(filename)

        all_outputs = self.ts.outputs[:]
        if len(all_outputs) == 0:
            return
        ds1 = ds2 = None

        for i, fn2 in enumerate(all_outputs[1:]):
            fn1 = all_outputs[i]
            target_filename = self._catalog_filename(filename, fn1)
            if os.path.exists(target_filename):
                # Written by a previous run.  Forget any loaded datasets
                # so they are reloaded from the correct files.
                ds1 = ds2 = None
                continue
            if ds1 is None:
                ds1 = self._load_ds(fn1)
            ds2 = self._load_ds(fn2)

            if self.comm.rank == 0:
                _print_link_info(ds1, ds2)

            target_halos = []
            if _get_total_halos(ds1, halo_type) == 0:
                # Still write an (empty) catalog for this output.
                self._save_catalog(filename, ds1, target_halos,
                                   halo_type, fields)
                ds1 = ds2
                continue

            target_ids = \
                ds1.r[halo_type, "particle_identifier"].d.astype(np.int64)

            njobs = min(self.comm.size, target_ids.size)
            pbar = get_pbar("Linking halos",
                            target_ids.size, parallel=True)
            my_i = 0
            for halo_id in parallel_objects(target_ids, njobs=njobs):
                my_halo = ds1.halo(halo_type, halo_id)

                target_halos.append(my_halo)
                self._find_descendent(my_halo, halo_type, ds2)
                my_i += njobs
                pbar.update(my_i)
            pbar.finish()

            self._save_catalog(filename, ds1, target_halos,
                               halo_type, fields)
            ds1 = ds2
            clear_id_cache()

        # The last output has no later output to link to.  Computing its
        # name here (rather than reusing loop variables) also works when
        # the series holds a single output.
        last_output = all_outputs[-1]
        catalog_filename = self._catalog_filename(filename, last_output)
        if os.path.exists(catalog_filename):
            return

        if ds2 is None:
            ds2 = self._load_ds(last_output)
        if self.comm.rank == 0:
            # Passing halo_type as ``halos`` saves every halo of that type
            # with descendent_identifier = -1.  halo_type is also passed
            # explicitly so that ``fields`` reaches the right parameter.
            self._save_catalog(filename, ds2, halo_type, halo_type, fields)

    def _save_catalog(self, filename, ds, halos, halo_type, fields=None):
        """
        Save halo catalog with descendent information.

        Parameters
        ----------
        filename : string or None
            Output name passed to get_output_filename.  The file is named
            after the dataset's tree basename and this processor's rank.
        ds : yt Dataset
            The dataset the halos belong to.
        halos : list of halo containers, or string
            If a list, field values are gathered from each container.
            Otherwise it is used as a halo type, and every halo of that
            type in ``ds`` is saved with descendent_identifier = -1.
        halo_type : string
            The halo type used to query fields from halo containers.
        fields : optional, list of strings
            Fields to save in addition to the default fields.
        """
        if self.comm is None:
            rank = 0
        else:
            rank = self.comm.rank
        filename = self._catalog_filename(filename, ds, rank)

        # list() copies the caller's fields (appending below must not
        # modify them) and also accepts tuples.
        my_fields = [] if fields is None else list(fields)
        default_fields = \
            ["particle_identifier",
             "descendent_identifier",
             "particle_mass"] + \
            [f"particle_position_{ax}" for ax in "xyz"] + \
            [f"particle_velocity_{ax}" for ax in "xyz"]
        for field in default_fields:
            if field not in my_fields:
                my_fields.append(field)

        if isinstance(halos, list):
            num_halos = len(halos)
            data = self._create_halo_data_lists(halos, halo_type, my_fields)
        else:
            num_halos = _get_total_halos(ds, halos)
            data = {field: ds.r[halos, field].in_base()
                    for field in my_fields
                    if field != "descendent_identifier"}
            data["descendent_identifier"] = -1 * np.ones(num_halos)

        ftypes = {field: "." for field in data}
        extra_attrs = {"num_halos": num_halos,
                       "data_type": "halo_catalog"}

        mylog.info(f"Saving catalog with {num_halos} halos to {filename}.")
        save_as_dataset(ds, filename, data, field_types=ftypes,
                        extra_attrs=extra_attrs)

    def _create_halo_data_lists(self, halos, halo_type, fields):
        """
        Given a list of halo containers, return a dictionary
        of field values for all halos.

        Values with units become a YTArray in base units; others become
        a numpy array.  A trailing length-1 axis is removed.
        """
        data = {hp: [] for hp in fields}
        if len(halos) > 0:
            # Tolerate comm being None, as _save_catalog does.
            nprocs = 1 if self.comm is None else self.comm.size
            pbar = get_pbar("Gathering field data from halos",
                            nprocs * len(halos), parallel=True)
            my_i = 0
            for halo in halos:
                for hp in fields:
                    data[hp].append(_get_halo_property(halo, halo_type, hp))
                my_i += nprocs
                pbar.update(my_i)
            pbar.finish()

        for hp in fields:
            if data[hp] and hasattr(data[hp][0], "units"):
                data[hp] = YTArray(data[hp]).in_base()
            else:
                data[hp] = np.array(data[hp])
            # Squeeze a trailing length-1 axis, e.g., (N, 1) -> (N,).
            shape = data[hp].shape
            if len(shape) > 1 and shape[-1] == 1:
                data[hp] = np.reshape(data[hp], shape[:-1])
        return data
def _get_total_halos(ds, halo_type):
    """
    Return the number of halos of type ``halo_type`` in ``ds``, summed
    over every data file in the dataset's index.
    """
    total = 0
    for data_file in ds.index.data_files:
        total += data_file.total_particles[halo_type]
    return total
def _get_tree_basename(fn):
    """
    Return the stem used to name merger-tree output files.

    ``fn`` may be a path string or an object with a ``basename``
    attribute (such as a dataset).  Leading directories and everything
    from the first "." onward are removed.
    """
    name = fn.basename if hasattr(fn, "basename") else fn
    stem, _, _ = os.path.basename(name).partition(".")
    return stem
def _get_halo_property(halo, halo_type, halo_property):
    """
    Return ``halo_property`` for a halo container.

    A non-None attribute named ``halo_property`` on ``halo`` takes
    precedence; otherwise the value is looked up as the field
    ``(halo_type, halo_property)``.
    """
    attr_value = getattr(halo, halo_property, None)
    if attr_value is not None:
        return attr_value
    return halo[halo_type, halo_property]
def _print_link_info(ds1, ds2):
    """
    Print information about linking datasets.

    Logs one line each for the basename, current time (in Gyr), and
    current redshift of ``ds1`` and ``ds2``.  The datasets themselves
    are not modified.
    """
    units = {"current_time": "Gyr"}
    for attr in ["basename", "current_time", "current_redshift"]:
        v1 = getattr(ds1, attr)
        v2 = getattr(ds2, attr)
        if attr in units:
            # Convert copies: convert_to_units would change the dataset's
            # own current_time in place as a side effect of logging.
            v1 = v1.to(units[attr])
            v2 = v2.to(units[attr])
        s = "Linking: %-20s = %-28s - %-28s" % (attr, v1, v2)
        mylog.info(s)