| |
| import contextlib |
| from typing import AsyncIterator, Dict, Sequence |
|
|
| import torch |
| from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor, nested_flatten, serialize_torch_tensor |
| from hivemind.moe.server.connection_handler import ConnectionHandler |
| from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE |
| from hivemind.proto import runtime_pb2 |
| from hivemind.utils import as_aiter |
| from hivemind.utils.asyncio import anext |
| from hivemind.utils.streaming import split_for_streaming |
|
|
| from src.data_structures import CHAIN_DELIMITER, ModuleUID |
| from src.server.backend import MAX_LENGTH, TransformerBackend |
|
|
|
|
class TransformerConnectionHandler(ConnectionHandler):
    """Handles forward, backward and forward-incremental (inference) requests, plus streaming forward/backward.

    Every request names a chain of blocks as CHAIN_DELIMITER-separated uids in its header; every block in the
    chain must be served by this peer (see _check_header_str), and the chain is executed in the given order.
    """

    module_backends: Dict[ModuleUID, TransformerBackend]

    def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend]):
        super().__init__(dht, module_backends)
        for module_backend in self.module_backends.values():
            assert isinstance(module_backend, TransformerBackend)

    async def rpc_inference(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        """Compute a single step of inference using attention cache; update attention cache accordingly.

        The first request names the chain of blocks (request.uid). While the client keeps sending requests with
        non-empty tensors, each request's hidden states are run through every block of the chain (each block with
        its own cache) and the result is yielded back. The session ends when the client sends a request without
        tensors or closes the stream; the caches are released when the session ends.
        """
        try:
            print("OPENED RPC_INFERENCE")
            request = await anext(requests)
            requested_uids = self._check_header(request)
            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

            # Single row (cache handle, prefix length); rewritten in place before each block's task is submitted.
            cache_metadata = torch.tensor([[-1, -1]], dtype=torch.int64)
            prefix_length = 0

            async with self._allocate_caches(requested_backends) as cache_handles:
                assert len(cache_handles) == len(requested_backends)
                while request.tensors:  # a request without tensors ends the session
                    hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]

                    for backend, cache_handle in zip(requested_backends, cache_handles):
                        cache_metadata[0, 0], cache_metadata[0, 1] = cache_handle, prefix_length
                        assert (
                            len(hidden_states) == 1 and hidden_states[0].ndim == 3
                        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
                        hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states)
                        assert isinstance(hidden_states, (list, tuple))
                        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3

                    yield runtime_pb2.ExpertResponse(
                        tensors=self._serialize_outputs(hidden_states, requested_backends[-1].outputs_schema)
                    )

                    # Advance by the size of dim 1 of this step's output
                    # (presumably the sequence axis of (batch, seq, hidden) — TODO confirm).
                    prefix_length += hidden_states[0].shape[1]
                    try:
                        request = await anext(requests)
                    except StopAsyncIteration:
                        # The client closed the stream without sending an empty request. End the session
                        # gracefully: letting StopAsyncIteration escape an async generator makes Python
                        # raise RuntimeError instead.
                        break
        finally:
            print("CLOSED RPC_INFERENCE")

    async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
        """Run the request's hidden states through the requested chain of blocks; return the final hidden states."""
        hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        requested_uids = self._check_header(request)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        hidden_states = await self._forward_through_blocks(hidden_states, requested_backends)
        return runtime_pb2.ExpertResponse(
            tensors=self._serialize_outputs(hidden_states, requested_backends[-1].outputs_schema)
        )

    async def rpc_forward_stream(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        """Streaming variant of rpc_forward: inputs are gathered from the stream, outputs are sent in chunks."""
        uids_header, hidden_states = await self._gather_inputs(requests, context)
        requested_uids = self._check_header_str(uids_header)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        hidden_states = await self._forward_through_blocks(hidden_states, requested_backends)
        serialized_output = self._serialize_outputs(hidden_states, requested_backends[-1].outputs_schema)
        for response in self._split_into_responses(serialized_output):
            yield response

    async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
        """Backpropagate grads through the requested chain of blocks; return the gradient w.r.t. the chain inputs.

        The request carries exactly two tensors: the inputs of the first block and the grads w.r.t. the outputs
        of the last block.
        """
        inputs, grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        requested_uids = self._check_header(request)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        grads = await self._backward_through_blocks(inputs, grads, requested_backends)
        return runtime_pb2.ExpertResponse(
            tensors=self._serialize_outputs([grads], requested_backends[0].grad_inputs_schema)
        )

    async def rpc_backward_stream(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        """Streaming variant of rpc_backward: inputs are gathered from the stream, grads are sent in chunks."""
        uids_header, inputs_and_grads = await self._gather_inputs(requests, context)
        inputs, grads = inputs_and_grads
        requested_uids = self._check_header_str(uids_header)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        grads = await self._backward_through_blocks(inputs, grads, requested_backends)
        serialized_grad_inputs = self._serialize_outputs([grads], requested_backends[0].grad_inputs_schema)
        for response in self._split_into_responses(serialized_grad_inputs):
            yield response

    async def _forward_through_blocks(
        self, hidden_states: Sequence[torch.Tensor], requested_backends: Sequence[TransformerBackend]
    ) -> Sequence[torch.Tensor]:
        """Submit hidden states to each backend's forward pool in chain order; return the last block's outputs."""
        for backend in requested_backends:
            assert isinstance(hidden_states, (list, tuple))
            assert (
                len(hidden_states) == 1 and hidden_states[0].ndim == 3
            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
            hidden_states = await backend.forward_pool.submit_task(*hidden_states)

        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
        return hidden_states

    async def _backward_through_blocks(
        self, inputs: torch.Tensor, grads: torch.Tensor, requested_backends: Sequence[TransformerBackend]
    ) -> torch.Tensor:
        """Backpropagate grads from the last block of the chain to the first; return grads w.r.t. chain inputs."""
        # Only the first block's inputs arrive with the request, so the inputs of every later block are
        # recomputed with a forward pass through the preceding blocks.
        inter_inputs = [inputs]
        for backend in requested_backends[:-1]:
            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
            inputs = await backend.forward_pool.submit_task(inputs)
            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
            inputs = inputs[0]
            inter_inputs.append(inputs)

        # Walk the chain backwards, pairing each block with the inputs it saw during the forward pass.
        for inp, backend in zip(reversed(inter_inputs), reversed(requested_backends)):
            grads = await backend.backward_pool.submit_task(inp, grads)
            assert isinstance(grads, (list, tuple)) and len(grads) == 1
            grads = grads[0]
        return grads

    @staticmethod
    def _serialize_outputs(tensors: Sequence[torch.Tensor], schema) -> Sequence[runtime_pb2.Tensor]:
        """Serialize tensors, each with the compression of the matching entry of the flattened schema."""
        return [
            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
            for result, proto in zip(tensors, nested_flatten(schema))
        ]

    @staticmethod
    def _split_into_responses(serialized_tensors: Sequence[runtime_pb2.Tensor]) -> Sequence[runtime_pb2.ExpertResponse]:
        """Split each serialized tensor with split_for_streaming (DEFAULT_MAX_MSG_SIZE) into one response per part."""
        return [
            runtime_pb2.ExpertResponse(tensors=[part])
            for tensor in serialized_tensors
            for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
        ]

    def _check_header(self, request: runtime_pb2.ExpertRequest) -> Sequence[ModuleUID]:
        """Validate the uid chain in request.uid; see _check_header_str."""
        return self._check_header_str(request.uid)

    def _check_header_str(self, header) -> Sequence[ModuleUID]:
        """Parse a CHAIN_DELIMITER-separated uid chain and check that this peer serves every uid in it.

        :param header: the uid chain string (an empty string or None means no uids were given)
        :returns: the uids in chain order
        :raises RuntimeError: if no uids were given or some uid is not served by this peer
        """
        # Check emptiness before splitting: "".split(sep) returns [""], never an empty list.
        if not header:
            raise RuntimeError("User did not provide any uids")
        uids = header.split(CHAIN_DELIMITER)
        for uid in uids:
            if uid not in self.module_backends:
                raise RuntimeError(f"Remote peer does not serve {uid}")
        return tuple(uids)

    @contextlib.asynccontextmanager
    async def _allocate_caches(self, backends: Sequence[TransformerBackend]) -> AsyncIterator[Sequence[int]]:
        """Allocate one attention cache per backend and yield the cache handles, in the same order as backends.

        Each backend.memory_cache.allocate_cache(...) context is entered on an AsyncExitStack, so all of them
        are exited together when this context exits, including on error.
        """
        async with contextlib.AsyncExitStack() as stack:
            handles = []
            for backend in backends:
                num_heads = backend.module.self_attention.num_heads
                head_dim = backend.module.self_attention.head_dim

                # Leading dims (2, 1, MAX_LENGTH) presumably mean (key/value, batch, max sequence) — TODO confirm.
                cache_descriptor = TensorDescriptor(size=(2, 1, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32)
                handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))

            yield handles
|
|