Tutorial: Expressing code-transformation space
The following script defines a code-transformation space for batched matrix-matrix multiplication kernels, i.e. einsums of the form "xik,kj->xij":
import feinsum as f
import loopy as lp
import numpy as np
from feinsum.tuning import IntParameter, einsum_arg, transform_param


@einsum_arg("N", lambda ensm: ensm.arg_shapes[0][1])
@transform_param("n_x_per_wg", lambda ensm: IntParameter(2, 32))
@transform_param("n_wi_per_x", lambda ensm: IntParameter(1, ensm.arg_shapes[0][1]))
def transform(t_unit, N, n_x_per_wg, n_wi_per_x,
              insn_match=None, kernel_name=None):
    ref_einsum = f.einsum("xik,kj->xij",
                          f.array((np.inf, N, N), np.float64),
                          f.array((N, N), np.float64),
                          arg_names=["A", "B"])
    subst_map = f.match_t_unit_to_einsum(t_unit, ref_einsum, insn_match=insn_match)
    vng = t_unit.default_entrypoint.get_var_name_generator()
    x = subst_map["x"]
    i = subst_map["i"]
    j = subst_map["j"]
    k = subst_map["k"]
    B = subst_map["B"]
    Bprftch_k, Bprftch_j = vng("Bprftch_k"), vng("Bprftch_j")
    x_outer = vng(f"{x}_outer")
    # Notation:
    # - The einsum computes 'out[x, i, j]'.
    #
    # Transform space:
    # - Each work-group computes the slice:
    #   'out[group_id(0)*n_x_per_wg:(group_id(0)+1)*n_x_per_wg, :, :]'
    # - Each work-item computes the slice:
    #   'out[group_id(0)*n_x_per_wg+local_id(1), local_id(0)::n_wi_per_x, :]'
    # Loop Transformations: Work Division
    # -----------------------------------
    t_unit = lp.split_iname(t_unit, x, n_x_per_wg,
                            outer_iname=x_outer,
                            inner_tag="l.1", outer_tag="g.0")
    t_unit = lp.split_iname(t_unit, i, n_wi_per_x,
                            inner_tag="l.0")

    # Data Transformations: Precompute 'B'
    # ------------------------------------
    t_unit = lp.precompute(t_unit, B,
                           sweep_inames={k, j},
                           precompute_outer_inames=frozenset([x_outer]),
                           precompute_inames=[Bprftch_k, Bprftch_j],
                           default_tag="l.auto")

    return t_unit
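The work division described in the comments of transform can be modelled in plain Python. The sketch below is only an illustration of the index arithmetic implied by the two split_iname calls (x = x_outer*n_x_per_wg + x_inner, i = i_outer*n_wi_per_x + i_inner); it does not call loopy, and the helper name work_item_slices and the small sizes are hypothetical.

```python
from itertools import product


def work_item_slices(n_x, n_i, n_x_per_wg, n_wi_per_x):
    """Map (group_id(0), local_id(1), local_id(0)) to the (x, i) pairs it computes."""
    n_groups = -(-n_x // n_x_per_wg)  # ceiling division
    assignment = {}
    for g0, l1, l0 in product(range(n_groups),
                              range(n_x_per_wg),
                              range(n_wi_per_x)):
        # Split of x: outer part -> g.0, inner part -> l.1
        x = g0 * n_x_per_wg + l1
        if x >= n_x:
            continue  # work-items past the end of the batch have nothing to do
        # Split of i: inner part -> l.0, outer part stays a sequential loop
        assignment[(g0, l1, l0)] = [(x, i) for i in range(l0, n_i, n_wi_per_x)]
    return assignment


assignment = work_item_slices(n_x=10, n_i=7, n_x_per_wg=4, n_wi_per_x=3)

# Every (x, i) pair is computed by exactly one work-item.
covered = sorted(pair for pairs in assignment.values() for pair in pairs)
assert covered == sorted(product(range(10), range(7)))

print(assignment[(1, 2, 0)])  # [(6, 0), (6, 3), (6, 6)]
```

Each work-item owns one batch entry x and a strided subset of the rows i, which matches the slice 'local_id(0)::n_wi_per_x' in the comments above.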
The main parts of the script include:
- transform function that implements the parametric transformations.
- einsum_arg decorator that parametrizes the sub-class of Einstein summations it targets.
- transform_param decorator that prescribes the bounds of the transform space.
- match_t_unit_to_einsum re-interprets the relevant expressions in the translation unit as per the provided reference batched einsum and provides a substitution mapping from the einsum's entities to the kernel's entities as:
  - An einsum's index corresponds to one of the inames in the translation unit
  - An einsum's argument corresponds to one of the substitution rules in the translation unit
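To make the roles of the indices x, i, j, k and the arguments A, B of the reference einsum concrete, here is a plain-Python spelling of what "xik,kj->xij" computes. This is a reference sketch only, with small made-up inputs; the function name batched_matmul is hypothetical and not part of feinsum.

```python
def batched_matmul(A, B):
    """Reference semantics of "xik,kj->xij":
    out[x][i][j] = sum over k of A[x][i][k] * B[k][j]
    """
    n_k, n_j = len(B), len(B[0])
    return [[[sum(A_x[i][k] * B[k][j] for k in range(n_k))
              for j in range(n_j)]
             for i in range(len(A_x))]
            for A_x in A]


# Two batch entries (index x) of 2x2 matrices, sharing a single B.
A = [[[1.0, 2.0], [3.0, 4.0]],
     [[0.0, 1.0], [1.0, 0.0]]]
B = [[1.0, 0.0],
     [2.0, 1.0]]

out = batched_matmul(A, B)
print(out[0])  # [[5.0, 2.0], [11.0, 4.0]]
print(out[1])  # [[2.0, 1.0], [1.0, 0.0]]
```

Note that B carries no batch index x: the same matrix is reused by every batch entry, which is why transform precomputes it once per work-group.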
We ask feinsum to record facts about the transform space as:
import pyopencl as cl

cl_ctx = cl.create_some_context()

# 'B' is indexed by "kj" only, so it is two-dimensional.
my_einsum = f.einsum("xik,kj->xij",
                     f.array((np.inf, 35, 35), np.float64),
                     f.array((35, 35), np.float64))

f.autotune(my_einsum, cl_ctx=cl_ctx, module_path=__file__)
The above script tunes the kernel for my_einsum over the previously defined transform space. feinsum leverages OpenTuner to traverse the search space.
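To get a feel for the size of this search space, the candidate parameter values for my_einsum can be enumerated in plain Python. This is a sketch under the assumption that both bounds of IntParameter are inclusive; it does not use feinsum or OpenTuner, and a tuner would typically sample this space rather than visit every point.

```python
from itertools import product

N = 35  # length of the second axis of the first operand of my_einsum

# Assumption: IntParameter(low, high) includes both endpoints.
n_x_per_wg_choices = range(2, 32 + 1)   # IntParameter(2, 32)
n_wi_per_x_choices = range(1, N + 1)    # IntParameter(1, N)

search_space = [
    {"n_x_per_wg": n_x_per_wg, "n_wi_per_x": n_wi_per_x}
    for n_x_per_wg, n_wi_per_x in product(n_x_per_wg_choices, n_wi_per_x_choices)
]

print(len(search_space))  # 1085
```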
On running the script, we should see messages like the following printed to stderr:
$ python main.py
[ 2s] INFO feinsum.tuning: {'n_x_per_wg': 26, 'n_wi_per_x': 12}
[ 3s] INFO feinsum.measure: Statistically verified the soundness of the transformation
[ 6s] INFO feinsum.tuning:
╒═════════╤═════════════════╤═════════════════╕
│ Dtype   │ Measured GOps/s │ Roofline GOps/s │
├─────────┼─────────────────┼─────────────────┤
│ float64 │ 22.3            │ 42.1            │
╘═════════╧═════════════════╧═════════════════╛
[ 8s] INFO feinsum.tuning: {'n_x_per_wg': 8, 'n_wi_per_x': 19}
[ 9s] INFO feinsum.measure: Statistically verified the soundness of the transformation
[ 11s] INFO feinsum.tuning:
╒═════════╤═════════════════╤═════════════════╕
│ Dtype   │ Measured GOps/s │ Roofline GOps/s │
├─────────┼─────────────────┼─────────────────┤
│ float64 │ 29.3            │ 42.1            │
╘═════════╧═════════════════╧═════════════════╛
[ 13s] INFO feinsum.tuning: {'n_x_per_wg': 10, 'n_wi_per_x': 26}
[ 14s] INFO feinsum.measure: Statistically verified the soundness of the transformation
[ 17s] INFO feinsum.tuning:
╒═════════╤═════════════════╤═════════════════╕
│ Dtype   │ Measured GOps/s │ Roofline GOps/s │
├─────────┼─────────────────┼─────────────────┤
│ float64 │ 19.3            │ 42.1            │
╘═════════╧═════════════════╧═════════════════╛
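As a small worked example of reading this log, the three measurements can be compared against the roofline figure. The numbers below are copied from the output above; the script is only an illustration of the arithmetic and is not part of feinsum.

```python
# (parameters, measured GOps/s) pairs copied from the log above.
results = [
    ({"n_x_per_wg": 26, "n_wi_per_x": 12}, 22.3),
    ({"n_x_per_wg": 8, "n_wi_per_x": 19}, 29.3),
    ({"n_x_per_wg": 10, "n_wi_per_x": 26}, 19.3),
]
ROOFLINE_GOPS = 42.1  # roofline GOps/s reported in every table


def roofline_fraction(measured_gops, roofline_gops=ROOFLINE_GOPS):
    """Fraction of the roofline throughput that was actually achieved."""
    return measured_gops / roofline_gops


best_params, best_gops = max(results, key=lambda result: result[1])
print(best_params, f"{roofline_fraction(best_gops):.1%}")
# {'n_x_per_wg': 8, 'n_wi_per_x': 19} 69.6%
```

Among these three samples, the second configuration comes closest to the roofline; further tuning iterations may find better points.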