
Source code for executorch.exir.lowered_backend_module

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import copy
import operator
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.utils._pytree as pytree
from executorch.exir._serialize import _serialize_pte_binary
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.delegate import executorch_call_delegate, get_lowered_module_name
from executorch.exir.emit import emit_program

from executorch.exir.graph_module import _get_submodule

from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
from executorch.exir.passes.spec_prop_pass import make_spec, SpecPropPass
from executorch.exir.schema import Program

from executorch.exir.tracer import Value

from torch._export.exported_program import ExportedProgram
from torch._subclasses import FakeTensor
from torch.export.exported_program import (
    ExportGraphSignature,
    InputKind,
    InputSpec,
    OutputKind,
    OutputSpec,
    TensorArgument,
)
from torch.fx.passes.utils.fuser_utils import (
    erase_nodes,
    fuse_as_graphmodule,
    insert_subgm,
    legalize_graph,
    NodeList,
    topo_sort,
)


class LoweredBackendModule(torch.nn.Module):
    """
    A subclass of nn.Module that is generated for modules containing
    delegated functions. This can be created by calling `to_backend`.
    """

    _backend_id: str  # The backend's name
    _processed_bytes: bytes  # The delegate blobs created from backend.preprocess
    _compile_specs: List[
        CompileSpec
    ]  # A list of backend-specific objects with static metadata to configure the "compilation" process
    _original_module: ExportedProgram  # The original EXIR module

    def __init__(
        self,
        edge_program: ExportedProgram,
        backend_id: str,
        processed_bytes: bytes,
        compile_specs: List[CompileSpec],
    ) -> None:
        super().__init__()
        self._original_module = edge_program
        self._backend_id = backend_id
        self._processed_bytes = processed_bytes
        self._compile_specs = compile_specs

    @property
    def backend_id(self) -> str:
        """
        Returns the backend's name.
        """
        return self._backend_id

    @property
    def processed_bytes(self) -> bytes:
        """
        Returns the delegate blob created from backend.preprocess.
        """
        return self._processed_bytes

    @property
    def compile_specs(self) -> List[CompileSpec]:
        """
        Returns a list of backend-specific objects with static metadata to
        configure the "compilation" process.
        """
        return self._compile_specs

    @property
    def original_module(self) -> ExportedProgram:
        """
        Returns the original EXIR module.
        """
        return self._original_module

    # TODO(chenlai): consolidate the serialization config with the
    # serialize_to_flatbuffer API
    def buffer(
        self,
        extract_segments: bool = False,
        segment_alignment: int = 4096,
        constant_tensor_alignment: Optional[int] = None,
        delegate_alignment: Optional[int] = None,
    ) -> bytes:
        """
        Returns a buffer containing the serialized ExecuTorch binary.
        """
        out = _serialize_pte_binary(
            program=self.program(),
            extract_segments=extract_segments,
            segment_alignment=segment_alignment,
            constant_tensor_alignment=constant_tensor_alignment,
            delegate_alignment=delegate_alignment,
        )
        return out
    # TODO(chenlai): reconsider recapture instead of manually constructing the
    # program, because the metadata construction is done manually.
    def program(self, emit_stacktrace: bool = False) -> Program:
        """
        Returns the object that represents the ExecuTorch binary before serialization.
        """
        # Creates a new module based on the original module. The original module will
        # look something like the following:
        #
        # opcode         name                 target  args                                        kwargs
        # -------------  -------------------  ------  ------------------------------------------  ------
        # placeholder    arg0_1               arg0_1  ()                                          {}
        # placeholder    arg1_1               arg1_1  ()                                          {}
        # call_function  aten_repeat_default  *       (arg1_1, [4, 1])                            {}
        # call_function  aten_mul_tensor      *       (aten_repeat_default, aten_repeat_default)  {}
        # call_function  aten_add_tensor      *       (arg1_1, arg1_1)                            {}
        # output         output               output  ([aten_mul_tensor, aten_add_tensor],)       {}
        #
        # If the whole module is lowered, the resulting lowered module looks like:
        #
        # opcode         name                      target                       args                                kwargs
        # -------------  ------------------------  ---------------------------  ----------------------------------  ------
        # placeholder    arg0_1                    arg0_1                       ()                                  {}
        # placeholder    arg1_1                    arg1_1                       ()                                  {}
        # get_attr       lowered_module_0          lowered_module_0             ()                                  {}
        # call_function  executorch_call_delegate  executorch_call_delegate     (lowered_module_0, arg0_1, arg1_1)  {}
        # call_function  getitem                   <built-in function getitem>  (executorch_call_delegate, 0)       {}
        # call_function  getitem_1                 <built-in function getitem>  (executorch_call_delegate, 1)       {}
        # output         output_1                  output                       ([getitem, getitem_1],)             {}
        #
        # We'll remove all call_function nodes, insert a call_delegate node, insert
        # getitem nodes to fetch each result of the call_delegate node, and return
        # the list of getitems as the output.
        lowered_exported_program = copy.deepcopy(self.original_module)

        # The real input nodes are the ones that are not buffers or parameters
        all_input_nodes = [
            node
            for node in lowered_exported_program.graph.nodes
            if (
                node.op == "placeholder"
                and node.name
                not in lowered_exported_program.graph_signature.inputs_to_buffers
                and node.name
                not in lowered_exported_program.graph_signature.inputs_to_parameters
            )
        ]

        output_node = [
            node for node in lowered_exported_program.graph.nodes if node.op == "output"
        ]
        assert len(output_node) == 1, "There should be only one output node"

        # Step 1. Clean up the graph before inserting the call_delegate node.

        # Remove the original output node
        lowered_exported_program.graph.erase_node(output_node[0])

        # Remove everything else except the inputs
        for node in reversed(lowered_exported_program.graph.nodes):
            if node.op != "placeholder":
                lowered_exported_program.graph.erase_node(node)

        # Find placeholders that are parameters or buffers and remove them from the
        # main graph
        for node in lowered_exported_program.graph.nodes:
            if node.op == "placeholder" and (
                node.name in lowered_exported_program.graph_signature.inputs_to_buffers
                or node.name
                in lowered_exported_program.graph_signature.inputs_to_parameters
            ):
                lowered_exported_program.graph.erase_node(node)

        # Step 2. Start constructing the graph
        lowered_name = get_lowered_module_name(
            lowered_exported_program.graph_module, self
        )
        # Insert the lowered module into the graph module as an attribute
        lowered_node = lowered_exported_program.graph.get_attr(lowered_name)

        # Insert a call_delegate node into the graph module, with arguments from the
        # arg list
        delegate_node = lowered_exported_program.graph.call_function(
            executorch_call_delegate, (lowered_node, *all_input_nodes)
        )

        # Get the output list. Since the output node is a tuple containing a list,
        # like ([aten_mul_tensor, aten_add_tensor],), we add some handling logic to
        # get the list `[aten_mul_tensor, aten_add_tensor]` properly.
        original_output_nodes = [
            node for node in self.original_module.graph.nodes if node.op == "output"
        ][0].args[0]

        delegate_node.meta["spec"] = tuple(
            [make_spec(node.meta["val"]) for node in original_output_nodes]
        )

        # The getitem nodes that are going to be inserted into the lowered graph module
        getitem_nodes = []
        for i in range(len(original_output_nodes)):
            getitem_node = lowered_exported_program.graph.call_function(
                operator.getitem,
                args=(delegate_node, i),
            )
            getitem_nodes.append(getitem_node)
        lowered_exported_program.graph.output(getitem_nodes)

        lowered_exported_program.graph_module.recompile()
        lowered_exported_program.graph.lint()

        # The user's outputs will be the getitem nodes instead
        output_specs = [
            OutputSpec(
                kind=OutputKind.USER_OUTPUT,
                arg=TensorArgument(name=getitem_node.name),
                target=None,
            )
            for getitem_node in getitem_nodes
        ]
        # All data is consumed by the delegate, so it should be removed from the
        # state dict.
        inputs_to_parameters = (
            lowered_exported_program.graph_signature.inputs_to_parameters
        )
        inputs_to_buffers = lowered_exported_program.graph_signature.inputs_to_buffers
        input_specs = [
            InputSpec(
                kind=InputKind.USER_INPUT,
                arg=TensorArgument(name=user_input),
                target=None,
            )
            for user_input in lowered_exported_program.graph_signature.user_inputs
            if user_input not in inputs_to_parameters
            and user_input not in inputs_to_buffers
        ]

        # Double check that the ExportedProgram data (especially everything except
        # the graph) is good
        exported_program = ExportedProgram(
            root=lowered_exported_program.graph_module,
            graph=lowered_exported_program.graph,
            graph_signature=ExportGraphSignature(
                input_specs=input_specs, output_specs=output_specs
            ),
            # TODO: May need to set lowered_exported_program.call_spec = CallSpec(None, None)
            # somewhere as we should pass it a list of tensors to the lowered module and output a
            # list of tensors. Putting call_spec=lowered_exported_program.call_spec is correct here as the
            # inputs/outputs to the toplevel program will be in the format of the eager module.
            state_dict={},  # Empty because all data is consumed by the delegate
            range_constraints=lowered_exported_program.range_constraints,
            equality_constraints=lowered_exported_program.equality_constraints,
            module_call_graph=lowered_exported_program.module_call_graph,
        )
        exported_program = exported_program._transform(
            SpecPropPass(), MemoryPlanningPass("greedy")
        )
        emitted_program = emit_program(
            exported_program, emit_stacktrace=emit_stacktrace
        ).program
        return emitted_program
    # Used to patch each delegated function with a call_delegate call
    # @staticmethod
    def forward(
        self,
        *args: Value,
        **kwargs: Tuple[Value, ...],
    ) -> Value:
        return executorch_call_delegate(self, *args)
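
For context, a LoweredBackendModule is normally produced by `to_backend` rather than constructed directly. Below is a minimal usage sketch; "ExampleBackend" and `edge_program` are hypothetical stand-ins (the backend id must match a registered backend, and `edge_program` is an edge-dialect ExportedProgram from the standard export flow).

# Minimal usage sketch, not part of this module. Assumes a hypothetical backend
# registered under the id "ExampleBackend" and an edge-dialect ExportedProgram
# `edge_program` produced by an earlier export step.
from executorch.exir.backend.backend_api import to_backend

lowered = to_backend("ExampleBackend", edge_program, [])  # -> LoweredBackendModule
print(lowered.backend_id)       # "ExampleBackend"
blob = lowered.processed_bytes  # opaque bytes from backend.preprocess
pte = lowered.buffer()          # serialized ExecuTorch binary (see buffer() above)
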
# TODO(zhxchen17) Try ExportPass
def _fixup_output_node(gm: torch.fx.GraphModule) -> None:
    for node in reversed(gm.graph.nodes):
        if node.op == "output":
            with gm.graph.inserting_before(node):
                assert len(node.args) == 1
                outputs = node.args[0]
                if isinstance(outputs, torch.fx.Node):
                    val = outputs.meta.get("val")
                    if isinstance(val, list):
                        # If a list is returned, in some cases it is represented as a
                        # single node, like `split_copy_tensor`, but EXIR will return
                        # an opened-up list like `[getitem1, getitem2]`
                        outputs = [
                            torch.fx.Proxy(outputs)[i].node for i in range(len(val))
                        ]

            returns, out_spec = pytree.tree_flatten(outputs)
            node.args = (returns,)
            return


def arrange_graph_placeholders(
    gm: torch.fx.GraphModule, owning_program: ExportedProgram
) -> torch.fx.GraphModule:
    """
    Modifies the graph of the given graph module so that it contains the same nodes
    as the original, but with placeholders in the order:

        (params + buffers) (user inputs)

    This is used by the delegate API, which disturbs the placeholder ordering when
    creating a submodule from partitioned nodes.

    Args:
        gm: The graph module that we want arranged
        owning_program: ExportedProgram that the submodule (gm) belongs to

    Returns:
        The graph module, rearranged in place
    """
    new_graph = torch.fx.Graph()

    node_map = {}  # mapping of nodes from old graph to new graph

    graph_sign = owning_program.graph_signature

    # Add all placeholders into the graph first:
    param_nodes = []
    buffer_nodes = []
    input_nodes = []
    for node in gm.graph.nodes:
        if node.op != "placeholder":
            continue

        if node.name in graph_sign.inputs_to_parameters:
            param_nodes.append(node)
        elif node.name in graph_sign.inputs_to_buffers:
            buffer_nodes.append(node)
        else:
            input_nodes.append(node)

    for param_node in param_nodes:
        new_node = new_graph.node_copy(param_node, lambda x: node_map[x])
        node_map[param_node] = new_node
    for buffer_node in buffer_nodes:
        new_node = new_graph.node_copy(buffer_node, lambda x: node_map[x])
        node_map[buffer_node] = new_node
    for input_node in input_nodes:
        new_node = new_graph.node_copy(input_node, lambda x: node_map[x])
        node_map[input_node] = new_node

    # Now add all the other nodes in order
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            continue

        new_node = new_graph.node_copy(node, lambda x: node_map[x])
        node_map[node] = new_node

    # Lint to ensure correctness
    new_graph.lint()

    new_graph._codegen = gm.graph._codegen
    gm.graph = new_graph

    return gm
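
To make the reordering concrete, here is an illustrative sketch with hypothetical names: it assumes `submodule_gm` came from partitioning and that `owning_program.graph_signature` maps `my_param`/`my_buf` through `inputs_to_parameters`/`inputs_to_buffers`.

# Illustrative sketch (hypothetical node names): placeholders that arrive
# interleaved, e.g. [x, my_param, y, my_buf], come back grouped as
# parameters, then buffers, then user inputs.
gm = arrange_graph_placeholders(submodule_gm, owning_program)
placeholder_names = [n.name for n in gm.graph.nodes if n.op == "placeholder"]
# e.g. ["my_param", "my_buf", "x", "y"]
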
# TODO: Don't regenerate the new signature manually.
def _get_new_signature(
    original_program: ExportedProgram, gm: torch.fx.GraphModule
) -> Tuple[ExportGraphSignature, Dict[str, Union[torch.Tensor, torch.nn.Parameter]]]:
    old_signature = original_program.graph_signature

    input_specs = []
    output_specs = []
    new_signature = ExportGraphSignature(
        input_specs=input_specs, output_specs=output_specs
    )
    new_state_dict = {}

    for node in gm.graph.nodes:
        if node.op == "placeholder":
            if node.name in old_signature.inputs_to_parameters:
                parameter_name = old_signature.inputs_to_parameters[node.name]
                # add param to graph signature
                input_specs.append(
                    InputSpec(
                        kind=InputKind.PARAMETER,
                        arg=TensorArgument(name=node.name),
                        target=parameter_name,
                    )
                )

                # add param to state_dict
                new_state_dict[parameter_name] = original_program.state_dict[
                    parameter_name
                ]
            elif node.name in old_signature.inputs_to_buffers:
                buffer_name = old_signature.inputs_to_buffers[node.name]
                # add buffer to graph signature
                input_specs.append(
                    InputSpec(
                        kind=InputKind.BUFFER,
                        arg=TensorArgument(name=node.name),
                        target=buffer_name,
                    )
                )

                # add buffer to new_state_dict
                new_state_dict[buffer_name] = original_program.state_dict[buffer_name]
            else:
                # neither a param nor a buffer, so it must be a user input
                input_specs.append(
                    InputSpec(
                        kind=InputKind.USER_INPUT,
                        arg=TensorArgument(name=node.name),
                        target=None,
                    )
                )

        if node.op == "output":
            for output in node.all_input_nodes:
                output_specs.append(
                    OutputSpec(
                        kind=OutputKind.USER_OUTPUT,
                        arg=TensorArgument(name=output.name),
                        target=None,
                    )
                )

    return new_signature, new_state_dict


def create_exported_program_from_submodule(
    submodule: torch.fx.GraphModule,
    owning_program: ExportedProgram,
) -> ExportedProgram:
    """
    Creates an ExportedProgram from the given submodule using the parameters and
    buffers from the top-level owning program.

    Args:
        submodule: submodule to create an exported program from
        owning_program: exported program containing the parameters and buffers used
            within the submodule

    Returns:
        The ExportedProgram created from the submodule
    """
    # Arrange the submodule's placeholders in order
    submodule = arrange_graph_placeholders(submodule, owning_program)

    # Get the updated graph signature
    subgraph_signature, subgraph_state_dict = _get_new_signature(
        owning_program, submodule
    )

    return ExportedProgram(
        root=submodule,
        graph=submodule.graph,
        graph_signature=subgraph_signature,
        state_dict=subgraph_state_dict,
        range_constraints=copy.deepcopy(owning_program.range_constraints),
        equality_constraints=[],
        module_call_graph=[],
    )


def create_submodule_from_nodes(
    gm: torch.fx.GraphModule,
    node_list: NodeList,
    tag: str,
    skip_legalize_graph: bool = False,
) -> Tuple[torch.fx.GraphModule, torch.fx.Node]:
    """
    Modifies the given graph module in place to separate out the given nodes into a
    submodule. The given node_list should form a fully connected subgraph.

    Args:
        gm: The graph module that we want to partition
        node_list: A list of nodes that belong in the partition

    Returns:
        The submodule that has been partitioned, and the call_module node in the
        toplevel graph module that calls the submodule
    """
    sorted_nodes = topo_sort(node_list)

    submodule_name = "fused_" + tag
    sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(
        gm, sorted_nodes, submodule_name
    )

    _fixup_output_node(sub_gm)

    gm = insert_subgm(gm, sub_gm, orig_inputs, orig_outputs)
    submodule_node = None
    for node in gm.graph.nodes:
        if node.op == "call_module":
            if node.target == submodule_name:
                submodule_node = node
            else:
                raise RuntimeError(
                    f"The submodule created with nodes {node_list} did not form "
                    "one fully contained subgraph. Check that these nodes form a "
                    f"fully contained graph. Partitioned graph: {gm.graph}."
                )

    if len(orig_outputs) == 1 and isinstance(orig_outputs[0].meta["val"], FakeTensor):
        # If the original output is a single tensor, it has been
        # pytree.tree_flatten-ed to be a singleton list, so we want to replace
        # all uses with a getitem call to the 0th index of the result
        with gm.graph.inserting_after(submodule_node):
            proxy_out = torch.fx.Proxy(submodule_node)[0].node  # type: ignore[index]
            submodule_node.replace_all_uses_with(proxy_out)
            proxy_out.meta["val"] = submodule_node.meta["val"]
            # Reset the args since they were overwritten in the previous line
            proxy_out.args = (submodule_node, 0)
    else:
        # fuse_as_graphmodule will automatically propagate the metadata of the
        # partition's last node to the getitem nodes that appear after the
        # call_module node. However, in the case of delegation we do not want
        # these getitem nodes to contain irrelevant previous metadata
        # (ex. source_fn, nn_module_stack)
        for user_node in submodule_node.users:
            user_node.meta.pop("nn_module_stack", None)
            user_node.meta.pop("source_fn_stack", None)

    erase_nodes(gm, sorted_nodes)

    # Topologically sort the original gm with the newly created sub_gm
    # TODO: T153794167 Get rid of support for skipping legalize graph in
    # create_submodule_from_nodes once we transition to using fuse_by_partitions.
    if not skip_legalize_graph:
        legalize_graph(gm)

    # Get the call_module node
    submodule_node = None
    for node in gm.graph.nodes:
        if node.op == "call_module" and node.target == submodule_name:
            submodule_node = node
        elif node.op == "call_module":
            raise RuntimeError(
                f"The submodule created with nodes {node_list} did not form "
                "one fully contained subgraph. Check that these nodes form a "
                f"fully contained graph. Partitioned graph: {gm.graph}."
            )

    assert (
        submodule_node is not None
    ), f"No submodule was created with the nodes {node_list} in the graph {gm.graph}"

    return sub_gm, submodule_node


def get_lowered_submodules(
    graph_module: torch.fx.GraphModule,
) -> List[Tuple[str, LoweredBackendModule, torch.fx.Node]]:
    """
    Returns a list of lowered modules that are in the given graph (does not look
    into submodules). Specifically, the returned value is a list of tuples of
    (name of the lowered module stored in the graph module, the lowered module
    itself, and the fx node that called this lowered module).
    """
    lowered_submodules = []
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and node.target == executorch_call_delegate:
            name, module, node = _get_submodule(graph_module, node, 0)
            assert isinstance(module, LoweredBackendModule)
            lowered_submodules.append((name, module, node))
    return lowered_submodules


def get_lowered_backend_modules(
    graph_module: torch.fx.GraphModule,
) -> List[LoweredBackendModule]:
    """
    Returns a list of the LoweredBackendModules in the given graph, i.e. the
    programs that were lowered by backend delegates.
    """
    lowered_programs = []
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and node.target == executorch_call_delegate:
            lowered_backend_module = getattr(graph_module, node.args[0].name)
            lowered_programs.append(lowered_backend_module)
    return lowered_programs
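
Putting the helpers above together, a rough end-to-end sketch follows; all names here (`edge_program`, `tagged_nodes`, the tag "backend0") are hypothetical stand-ins for values produced by an earlier export/partitioning step.

# Hypothetical sketch: carve a set of tagged nodes out of an edge program's
# graph module, wrap the partition as its own ExportedProgram, and enumerate
# the delegated submodules once delegation has replaced the call_module nodes
# with executorch_call_delegate calls.
gm = edge_program.graph_module

sub_gm, call_module_node = create_submodule_from_nodes(gm, tagged_nodes, tag="backend0")
submodule_program = create_exported_program_from_submodule(sub_gm, edge_program)

# ... backend preprocessing / to_backend runs here ...

for name, lowered_module, call_node in get_lowered_submodules(gm):
    print(name, lowered_module.backend_id)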
