Tutorial: Expressing code-transformation space

The following example defines a code-transformation space for batched matrix-matrix multiplication kernels, i.e. einsums of the form xik,kj->xij:

import feinsum as f
import loopy as lp
import numpy as np

from feinsum.tuning import IntParameter, einsum_arg, transform_param


@einsum_arg("N", lambda ensm: ensm.arg_shapes[0][1])
@transform_param("n_x_per_wg", lambda ensm: IntParameter(2, 32))
@transform_param("n_wi_per_x", lambda ensm: IntParameter(1, ensm.arg_shapes[0][1]))
def transform(t_unit, N, n_x_per_wg, n_wi_per_x,
              insn_match=None, kernel_name=None):
    ref_einsum = f.einsum("xik,kj->xij",
                          f.array((np.inf, N, N), np.float64),
                          f.array((N, N), np.float64),
                          arg_names=["A", "B"])

    subst_map = f.match_t_unit_to_einsum(t_unit, ref_einsum, insn_match=insn_match)

    vng = t_unit.default_entrypoint.get_var_name_generator()
    x = subst_map["x"]
    i = subst_map["i"]
    j = subst_map["j"]
    k = subst_map["k"]
    B = subst_map["B"]

    Bprftch_k, Bprftch_j = vng("Bprftch_k"), vng("Bprftch_j")
    x_outer = vng(f"{x}_outer")

    # Notation:
    #    - The einsum computes 'out[x, i, j]'.
    #
    # Transform space:
    #   - Each work-group computes the slice:
    #     'out[group_id(0)*n_x_per_wg:(group_id(0)+1)*n_x_per_wg,:,:]'
    #   - Each work-item computes the slice:
    #     'out[group_id(0)*n_x_per_wg+local_id(1),local_id(0)::n_wi_per_x,:]'

    # Loop Transformations: Work Division
    # -----------------------------------
    t_unit = lp.split_iname(t_unit, x, n_x_per_wg,
                            outer_iname=x_outer,
                            inner_tag="l.1", outer_tag="g.0")

    t_unit = lp.split_iname(t_unit, i, n_wi_per_x,
                            inner_tag="l.0")

    # Data Transformations: Precompute 'B'
    # ------------------------------------
    t_unit = lp.precompute(t_unit, B,
                           sweep_inames={k, j},
                           precompute_outer_inames=frozenset([x_outer]),
                           precompute_inames=[Bprftch_k, Bprftch_j],
                           default_tag="l.auto")

    return t_unit
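
The work division produced by the two split_iname calls can be checked without a GPU. The following is a minimal NumPy sketch, not part of feinsum or loopy: the helper names work_item_slice and emulate are hypothetical, and the loops stand in for group_id(0), local_id(1) and local_id(0). It assumes that splitting x by n_x_per_wg gives x = group_id(0)*n_x_per_wg + local_id(1), and that splitting i by n_wi_per_x gives each work-item the rows local_id(0)::n_wi_per_x.

```python
import numpy as np


def work_item_slice(group_id0, local_id0, local_id1, n_x_per_wg, n_wi_per_x):
    """Return the batch index and row slice of 'out' owned by one work-item."""
    x = group_id0 * n_x_per_wg + local_id1
    return x, slice(local_id0, None, n_wi_per_x)


def emulate(A, B, n_x_per_wg, n_wi_per_x):
    """Serially emulate the work division and count writes per output row."""
    n_x, N, _ = A.shape
    out = np.zeros((n_x, N, B.shape[1]))
    writes = np.zeros((n_x, N), dtype=int)
    n_groups = -(-n_x // n_x_per_wg)  # ceiling division
    for g0 in range(n_groups):
        for l1 in range(n_x_per_wg):
            for l0 in range(n_wi_per_x):
                x, rows = work_item_slice(g0, l0, l1, n_x_per_wg, n_wi_per_x)
                if x >= n_x:
                    continue  # the last work-group may be only partially filled
                out[x, rows, :] = A[x, rows, :] @ B
                writes[x, rows] += 1
    return out, writes


rng = np.random.default_rng(0)
A = rng.standard_normal((10, 35, 35))
B = rng.standard_normal((35, 35))

out, writes = emulate(A, B, n_x_per_wg=4, n_wi_per_x=12)
reference = np.einsum("xik,kj->xij", A, B)

print(np.allclose(out, reference))  # every entry matches the einsum
print(bool((writes == 1).all()))    # every output row is written exactly once
```

If the assumed index mapping holds, every row of the output is written exactly once and the result agrees with np.einsum, for any n_x_per_wg and any n_wi_per_x between 1 and N.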

The main parts of the script are:

  • The transform function, which implements the parametric transformations.

  • The einsum_arg decorator, which parametrizes the subclass of Einstein summations that the transformation targets. Here, N is read from the second axis of the first argument’s shape.

  • The transform_param decorator, which prescribes the bounds of each parameter in the transform space.

  • match_t_unit_to_einsum, which reinterprets the relevant expressions in the translation unit in terms of the provided reference batched einsum and returns a substitution mapping from the einsum’s entities to the kernel’s entities:

    • An einsum’s index corresponds to one of the inames in the translation unit.

    • An einsum’s argument corresponds to one of the substitution rules in the translation unit.
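
To get a feel for the size of the space that the two transform_param decorators describe, the candidate configurations can be enumerated in plain Python. This is only an illustrative sketch: transform_space is a hypothetical helper, not a feinsum function, and it assumes that both bounds of IntParameter are inclusive.

```python
from itertools import product


def transform_space(N, n_x_per_wg_bounds=(2, 32)):
    """List every (n_x_per_wg, n_wi_per_x) pair, assuming inclusive bounds."""
    low, high = n_x_per_wg_bounds
    return [
        {"n_x_per_wg": n_x_per_wg, "n_wi_per_x": n_wi_per_x}
        for n_x_per_wg, n_wi_per_x in product(range(low, high + 1),
                                              range(1, N + 1))
    ]


# N = 35 matches an einsum whose first argument has shape (np.inf, 35, 35).
space = transform_space(35)
print(len(space))  # 31 * 35 = 1085 under the inclusive-bounds assumption
```

Because the upper bound of n_wi_per_x depends on the einsum's shape, the size of the space grows with N, which is why the bounds are given as functions of the einsum rather than as constants.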

We ask feinsum to record facts about the transform space as follows:

import pyopencl as cl

cl_ctx = cl.create_some_context()

my_einsum = f.einsum("xik,kj->xij",
                     f.array((np.inf, 35, 35), np.float64),
                     f.array((35, 35), np.float64))
f.autotune(my_einsum, cl_ctx=cl_ctx, module_path=__file__)

With the above script, the kernel for my_einsum is tuned over the previously defined transform space. feinsum leverages OpenTuner to traverse the search space.

Running the script prints messages such as the following to stderr:

$ python main.py

[     2s]    INFO feinsum.tuning: {'n_x_per_wg': 26, 'n_wi_per_x': 12}
[     3s]    INFO feinsum.measure: Statistically verified the soundness of the transformation
[     6s]    INFO feinsum.tuning:
╒═════════╤═════════════════╤═════════════════╕
│ Dtype   │ Measured GOps/s │ Roofline GOps/s │
├─────────┼─────────────────┼─────────────────┤
│ float64 │ 22.3            │ 42.1            │
╘═════════╧═════════════════╧═════════════════╛
[     8s]    INFO feinsum.tuning: {'n_x_per_wg': 8, 'n_wi_per_x': 19}
[     9s]    INFO feinsum.measure: Statistically verified the soundness of the transformation
[    11s]    INFO feinsum.tuning:
╒═════════╤═════════════════╤═════════════════╕
│ Dtype   │ Measured GOps/s │ Roofline GOps/s │
├─────────┼─────────────────┼─────────────────┤
│ float64 │ 29.3            │ 42.1            │
╘═════════╧═════════════════╧═════════════════╛
[    13s]    INFO feinsum.tuning: {'n_x_per_wg': 10, 'n_wi_per_x': 26}
[    14s]    INFO feinsum.measure: Statistically verified the soundness of the transformation
[    17s]    INFO feinsum.tuning:
╒═════════╤═════════════════╤═════════════════╕
│ Dtype   │ Measured GOps/s │ Roofline GOps/s │
├─────────┼─────────────────┼─────────────────┤
│ float64 │ 19.3            │ 42.1            │
╘═════════╧═════════════════╧═════════════════╛