from rich.console import Console
from rich.markup import escape
from rich.panel import Panel
from rich.rule import Rule
from rich.table import Table
from rich.text import Text
from rich.tree import Tree

from .cpp import cpp_binding
from . import dataset
from . import column
from . import selection
from .node import query, result
# ==========================================
# Configured Analysis
# ==========================================
class _custom_analysis:
    """
    An analysis function bundled with the extra arguments it will be called with.

    The function is always invoked as ``func(df, *args, **kwargs)``, where ``df``
    is supplied at application time. There are two ways to apply it:

    - ``df | analysis``: run it for its side effects; the expression evaluates
      to ``df`` itself, so pipes can be chained.
    - ``analysis.output(df)``: run it and hand back its return value untouched.
    """

    def __init__(self, func, args, kwargs):
        # The analysis callable and the arguments captured when it was configured.
        self.func = func
        self.args = args
        self.kwargs = kwargs

    def output(self, df):
        """Invoke the analysis on ``df`` and return exactly what it returns."""
        fn, extra_args, extra_kwargs = self.func, self.args, self.kwargs
        return fn(df, *extra_args, **extra_kwargs)

    def __ror__(self, df):
        """Support ``df | analysis``: evaluate for side effects, then yield ``df``."""
        self.output(df)  # the analysis result is intentionally discarded
        return df
# ==========================================
# Definition Objects (Decorators)
# ==========================================
class _analysis:
    """
    Decorator that turns a function ``func(df, *args, **kwargs)`` into a
    configurable analysis.

    Calling the decorated object does NOT run ``func``; it captures the given
    arguments and returns a ``_custom_analysis`` that can later be applied to a
    dataflow.
    """

    def __init__(self, func):
        import functools

        # Copy __name__, __qualname__, __doc__, __module__ (and set __wrapped__)
        # from the decorated function so help()/introspection show the user's
        # analysis rather than this wrapper class. Done before assigning
        # ``self.func`` because update_wrapper also merges ``func.__dict__``
        # into ours and must not clobber that attribute.
        functools.update_wrapper(self, func)
        self.func = func

    def __call__(self, *args, **kwargs):
        # Bind the configuration arguments; the dataflow is supplied later.
        return _custom_analysis(self.func, args, kwargs)
class dataflow(cpp_binding):
    """
    qtypy layer for ``qty::dataflow``.

    Provides a high-level interface for defining, selecting, and querying columns
    from datasets, with optional multi-threaded execution.

    Parameters
    ----------
    n_threads : int, optional
        Number of threads to use for processing (default is -1, meaning all
        available threads). Has no effect when ``enable_mt`` is False.
    n_rows : int, optional
        Number of rows to process from the dataset (default is -1, meaning all rows).
    enable_mt : bool, optional
        Enable multithreading (default is True).

    Attributes
    ----------
    dataset : object
        Dataset attached to the dataflow. Initially `None`.
    columns : dict
        Mapping of column names to column definitions.
    selections : dict
        Mapping of selection names to selection objects.
    queries : list
        Queries associated with the dataflow; listed by :meth:`explain`, grouped
        by their ``booked_selection_name``.

    Examples
    --------
    >>> df = dataflow()  # use all available threads
    >>> df = dataflow(n_threads=8)  # use (up to) 8 threads
    >>> df = dataflow(enable_mt=False)  # single-threaded
    """

    def __init__(self, *, n_threads: int = -1, n_rows: int = -1, enable_mt: bool = True):
        super().__init__()
        self.enable_mt = enable_mt
        self.n_threads = n_threads
        self.n_rows = n_rows
        self.dataset = None
        self.columns = {}
        self.selections = {}
        # Until a selection view activates one, the dataflow itself stands in
        # as the "current selection", with no selection name.
        self.current_selection = self
        self.current_selection_name = None
        self.queries = []
        self._instantiate()

    @property
    def cpp_initialization(self):
        """C++ expression that constructs the underlying ``qty::dataflow``."""
        # ``enable_mt * n_threads`` is 0 when multithreading is disabled and
        # ``n_threads`` otherwise (True * n == n).
        return f"""qty::dataflow(qty::multithread::enable({self.enable_mt * self.n_threads}),qty::dataset::head({self.n_rows}))"""

    def input(self, ds):
        """
        Connect the dataset ``ds`` to this dataflow (DSL: ``df << ds``).

        Returns
        -------
        self
            Enables method chaining.
        """
        ds._contextualize(self)
        return self

    def compute(self, columns: dict):
        """
        Define additional columns in the dataflow.

        Parameters
        ----------
        columns : dict
            A dictionary mapping column names (strings) to one of the following:

            - ``qtypy.dataset.column``
                Existing quantity in dataset.
            - ``qtypy.column.constant``
                Constant value of any C++ data type.
            - ``qtypy.column.expression``
                JIT-compiled one-line C++ expression.
            - ``qtypy.column.definition``
                Compiled C++ implementation of
                ``qty::column::definition<Ret(Args...)>``.
            - ``str``
                Shorthand for ``qtypy.column.expression(value)``, exactly as
                accepted by ``df[name] = value``.

        Returns
        -------
        self
            Enables method chaining.
        """
        for column_name, column_node in columns.items():
            self._define_column(column_name, column_node)
        return self

    def _define_column(self, column_name, column_node):
        """Book a single column; plain strings are wrapped in ``column.expression``."""
        if isinstance(column_node, str):
            column_node = column.expression(column_node)
        column_node._contextualize(self, column_name)
        column_node._instantiate()

    def _view(self, selection_name):
        """
        Return the object that subsequent DSL operations should act on.

        ``None`` means nothing has been selected (still at the root), so the
        dataflow itself is returned. Wrapping ``None`` in a view would fail on
        first use, because ``self.selections[None]`` does not exist.
        """
        if selection_name is None:
            return self
        return _dataflow_at_selection(self, selection_name)

    def filter(self, cuts: dict):
        """
        Apply one or more cuts, booked in dict order (DSL: ``df & cuts``).

        Parameters
        ----------
        cuts : dict
            Mapping of selection names to cut expressions; each value is wrapped
            in ``qtypy.selection.cut``.

        Returns
        -------
        _dataflow_at_selection or dataflow
            A restricted view at the last cut booked. If ``cuts`` is empty: a
            view at the currently active selection, or the dataflow itself when
            no selection is active.
        """
        last_selection = self.current_selection_name
        for cut_name, cut_expression in cuts.items():
            cut_node = selection.cut(cut_expression)
            cut_node._contextualize(self, cut_name)
            cut_node._instantiate()
            last_selection = cut_name
        return self._view(last_selection)

    def weight(self, weights: dict):
        """
        Apply one or more weights, booked in dict order (DSL: ``df * weights``).

        Parameters
        ----------
        weights : dict
            Mapping of selection names to weight expressions; each value is
            wrapped in ``qtypy.selection.weight``.

        Returns
        -------
        _dataflow_at_selection or dataflow
            A restricted view at the last weight booked. If ``weights`` is
            empty: a view at the currently active selection, or the dataflow
            itself when no selection is active.
        """
        # Start from the active selection, consistent with ``filter``.
        last_selection = self.current_selection_name
        for wgt_name, wgt_expression in weights.items():
            wgt_node = selection.weight(wgt_expression)
            wgt_node._contextualize(self, wgt_name)
            wgt_node._instantiate()
            last_selection = wgt_name
        return self._view(last_selection)

    def at(self, selection_name: str):
        """
        Return a restricted view at an existing selection (DSL: ``df @ name``).

        Raises
        ------
        KeyError
            If no selection named ``selection_name`` exists.
        """
        if selection_name not in self.selections:
            raise KeyError(f"selection '{selection_name}' not found in dataflow.")
        return _dataflow_at_selection(self, selection_name)

    def get(self, query_definition_or_custom_analysis):
        """
        Evaluate an analysis against this dataflow (DSL: ``df >> analysis``).

        Returns whatever ``query_definition_or_custom_analysis.output(self)``
        returns.
        """
        return query_definition_or_custom_analysis.output(self)

    def __setitem__(self, column_name, column_node):
        """DSL: ``df[name] = node`` defines one column (strings become expressions)."""
        self._define_column(column_name, column_node)

    # DSL syntax
    __lshift__ = input    # df << ds
    __matmul__ = at       # df @ "selection"
    __and__ = filter      # df & {name: cut}
    __mul__ = weight      # df * {name: weight}
    __rshift__ = get      # df >> analysis

    # Decorator for user-defined analyses.
    analysis = _analysis

    def output(self, custom_analysis):
        """Run ``custom_analysis`` on this dataflow and return its result."""
        return custom_analysis.output(self)

    def explain(self):
        """
        Pretty-print a summary of the dataflow to the terminal using Rich.

        Panels are printed in this order:

        - ``dataset``: the attached dataset and its dataset-native columns.
        - ``observable``: the remaining columns, excluding those that share a
          name with a selection.
        - ``cutflow``: the selection tree, linked by each selection's
          ``prev_name``.
        - ``query``: queries grouped by selection. Omitted when there are no
          selections; its rows are keyed by selection, so it would be empty.
        """
        console = Console()
        console.print(self._titled_panel(self._explain_dataset(), "dataset"))
        console.print(self._titled_panel(self._explain_observables(), "observable"))
        console.print(self._titled_panel(self._explain_cutflow(), "cutflow"))
        if not self.selections:
            return
        console.print(self._titled_panel(self._explain_queries(), "query"))

    @staticmethod
    def _titled_panel(renderable, title):
        """Wrap ``renderable`` in the compact, left-titled panel used by ``explain``."""
        return Panel(
            renderable,
            title=f"[bold]{title}[/bold]",
            title_align="left",
            expand=False,
        )

    def _explain_dataset(self):
        """Grid with the dataset, a divider, then the dataset-native columns."""
        grid = Table.grid(padding=(0, 1))
        # Interpolated text is escaped so brackets in names/reprs (e.g. ``x[i]``)
        # are shown literally instead of being parsed as Rich markup tags.
        grid.add_row(f"[grey]{escape(f'{self.dataset}')}[/grey]")
        grid.add_row(Rule(style="dim"))
        for name, col in self.columns.items():
            if isinstance(col, dataset.column):
                label = escape(f"{name}{col}")
                grid.add_row(f"[grey]{label}[/grey]")
        return grid

    def _explain_observables(self):
        """Grid of the non-dataset columns, skipping names used by selections."""
        grid = Table.grid(padding=(0, 1))
        for name, col in self.columns.items():
            if isinstance(col, dataset.column) or name in self.selections:
                continue
            grid.add_row(Text(f"{name} := {col}", style="blue"))
        return grid

    def _explain_cutflow(self):
        """Tree of selections under a neutral ``*`` root, nested by ``prev_name``."""
        tree = Tree("[blue]*[/blue]")
        if not self.selections:
            tree.add("[dim]none[/dim]")
            return tree

        # Parent name -> child names, in insertion order. Selections without a
        # ``prev_name`` hang directly off the root.
        children = {}
        roots = []
        for name, sel in self.selections.items():
            prev_name = getattr(sel, "prev_name", None)
            if prev_name is None:
                roots.append(name)
            else:
                children.setdefault(prev_name, []).append(name)

        def attach(parent, name):
            label = escape(f"{name} → {self.selections[name]}")
            branch = parent.add(f"@ {label}")
            for child in children.get(name, []):
                attach(branch, child)

        for root in roots:
            attach(tree, root)
        return tree

    def _explain_queries(self):
        """Grid with one tree per selection that has queries booked at it."""
        grid = Table.grid(padding=(0, 1))
        queries_by_selection = {}
        for q in self.queries:
            sel_name = getattr(q, "booked_selection_name", None)
            queries_by_selection.setdefault(sel_name, []).append(q)

        # Flat listing, one tree per selection (no hierarchy).
        for sel_name in self.selections:
            booked = queries_by_selection.get(sel_name)
            if not booked:
                continue
            sel_tree = Tree(f"[blue]@ {escape(f'{sel_name}')}[/blue]")
            for q in booked:
                # Text is not markup-parsed; relies on the definition's __str__().
                sel_tree.add(Text(f"{q.defn}", style="red"))
            grid.add_row(sel_tree)
        return grid
class _dataflow_at_selection:
    """
    Restricted view of a dataflow at a specific selection.

    Only the following operations are allowed:

    - filter (&)
    - weight (*)
    - get (>>)

    Before each allowed operation is forwarded, the view's selection is written
    into the dataflow's ``current_selection``/``current_selection_name``, so the
    forwarded call sees it as the active selection.
    """

    _allowed_methods = {"filter", "weight", "get"}

    def __init__(self, df, selection):
        self._df = df
        self._selection_name = selection

    def _activate(self):
        """Make this view's selection the dataflow's current selection."""
        self._df.current_selection_name = self._selection_name
        self._df.current_selection = self._df.selections[self._selection_name]

    def __getattr__(self, name):
        # Only reached when normal lookup fails: forward the allowed operations
        # (activating this selection first) and reject everything else.
        if name not in self._allowed_methods:
            raise AttributeError(
                f"Operation '{name}' is not allowed at a specific selection. "
                f"Only {self._allowed_methods} are permitted."
            )
        orig_method = getattr(self._df, name)

        def wrapper(*args, **kwargs):
            self._activate()
            return orig_method(*args, **kwargs)

        return wrapper

    def __and__(self, cuts):
        return self.filter(cuts)

    def __mul__(self, weights):
        return self.weight(weights)

    def __rshift__(self, query_definition):
        return self.get(query_definition)

    def __eq__(self, other):
        # Views are equal when they wrap equal dataflows at the same selection.
        # For unrelated types, defer to Python instead of raising AttributeError
        # on ``other._df`` (e.g. ``view == None``).
        if not isinstance(other, _dataflow_at_selection):
            return NotImplemented
        return (self._df == other._df) and (self._selection_name == other._selection_name)

    # Defining __eq__ already disables the inherited hash; state it explicitly.
    __hash__ = None

    def __repr__(self):
        return f"{type(self).__name__}(selection={self._selection_name!r})"