How to summarize text through parallelization
LLMs can summarize and otherwise distill desired information from text, including large volumes of text. In many cases, especially when the amount of text is large compared to the size of the model's context window, it can be helpful (or necessary) to break up the summarization task into smaller components.
Map-reduce represents one class of strategies for accomplishing this. The idea is to break the text into "sub-documents", and first map each sub-document to an individual summary using an LLM. Then, we reduce or consolidate those summaries into a single global summary.
Note that the map step is typically parallelized over the input documents. This strategy is especially effective when understanding of a sub-document does not rely on preceding context, for example when summarizing a corpus of many shorter documents.
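As a rough illustration of the pattern described above, the sketch below runs the map step in a thread pool and then reduces the partial summaries into one result. `fake_summarize` and `fake_consolidate` are hypothetical stand-ins for LLM calls and are not part of any library; a real application would call a chat model in their place.

```python
from concurrent.futures import ThreadPoolExecutor


def fake_summarize(text: str) -> str:
    """Hypothetical stand-in for an LLM call: keep only the first sentence."""
    return text.split(".")[0].strip() + "."


def fake_consolidate(summaries: list[str]) -> str:
    """Hypothetical stand-in for the reduce call: join the partial summaries."""
    return " ".join(summaries)


def map_reduce_summary(sub_documents: list[str]) -> str:
    # Map: each sub-document is summarized independently, so the calls
    # can run concurrently. Executor.map preserves the input order.
    with ThreadPoolExecutor(max_workers=4) as pool:
        partial_summaries = list(pool.map(fake_summarize, sub_documents))
    # Reduce: consolidate the partial summaries into one global summary.
    return fake_consolidate(partial_summaries)


docs = ["Cats sleep a lot. They also purr.", "Dogs bark. They fetch."]
print(map_reduce_summary(docs))
```

Because no sub-document depends on another, the map calls can be issued in any order; only the reduce step needs all of the results.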
LangGraph, built on top of langchain-core, supports map-reduce workflows and is well-suited to this problem:
- LangGraph allows individual steps (such as successive summarizations) to be streamed, allowing for greater control of execution;
- LangGraph's checkpointing supports error recovery, extension with human-in-the-loop workflows, and easier incorporation into conversational applications;
- The LangGraph implementation is straightforward to modify and extend.
Below, we demonstrate how to summarize text via a map-reduce strategy.
Load chat model
Let's first load a chat model:
- OpenAI
- Anthropic
- Azure
- Google
- AWS
- Cohere
- NVIDIA
- FireworksAI
- Groq
- MistralAI
- TogetherAI
- Databricks
pip install -qU langchain-openai
import getpass
import os
os.environ["OPENAI_API_KEY"] = getpass.getpass()
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(model="gpt-4o-mini")
pip install -qU langchain-anthropic
import getpass
import os
os.environ["ANTHROPIC_API_KEY"] = getpass.getpass()
from langchain_anthropic import ChatAnthropic
llm = ChatAnthropic(model="claude-3-5-sonnet-20240620")
pip install -qU langchain-openai
import getpass
import os
os.environ["AZURE_OPENAI_API_KEY"] = getpass.getpass()
from langchain_openai import AzureChatOpenAI
llm = AzureChatOpenAI(
azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
azure_deployment=os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"],
openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
)
pip install -qU langchain-google-vertexai
# Ensure your VertexAI credentials are configured
from langchain_google_vertexai import ChatVertexAI
llm = ChatVertexAI(model="gemini-1.5-flash")
pip install -qU langchain-aws
# Ensure your AWS credentials are configured
from langchain_aws import ChatBedrock
llm = ChatBedrock(model="anthropic.claude-3-5-sonnet-20240620-v1:0",
beta_use_converse_api=True)
pip install -qU langchain-cohere
import getpass
import os
os.environ["COHERE_API_KEY"] = getpass.getpass()
from langchain_cohere import ChatCohere
llm = ChatCohere(model="command-r-plus")
pip install -qU langchain-nvidia-ai-endpoints
import getpass
import os
os.environ["NVIDIA_API_KEY"] = getpass.getpass()
from langchain_nvidia_ai_endpoints import ChatNVIDIA
llm = ChatNVIDIA(model="meta/llama3-70b-instruct")
pip install -qU langchain-fireworks
import getpass
import os
os.environ["FIREWORKS_API_KEY"] = getpass.getpass()
from langchain_fireworks import ChatFireworks
llm = ChatFireworks(model="accounts/fireworks/models/llama-v3p1-70b-instruct")
pip install -qU langchain-groq
import getpass
import os
os.environ["GROQ_API_KEY"] = getpass.getpass()
from langchain_groq import ChatGroq
llm = ChatGroq(model="llama3-8b-8192")
pip install -qU langchain-mistralai
import getpass
import os
os.environ["MISTRAL_API_KEY"] = getpass.getpass()
from langchain_mistralai import ChatMistralAI
llm = ChatMistralAI(model="mistral-large-latest")
pip install -qU langchain-openai
import getpass
import os
os.environ["TOGETHER_API_KEY"] = getpass.getpass()
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(
base_url="https://api.together.xyz/v1",
api_key=os.environ["TOGETHER_API_KEY"],
model="mistralai/Mixtral-8x7B-Instruct-v0.1",
)
pip install -qU databricks-langchain
import getpass
import os
os.environ["DATABRICKS_TOKEN"] = getpass.getpass()
from databricks_langchain import ChatDatabricks
os.environ["DATABRICKS_HOST"] = "https://example.staging.cloud.databricks.com/serving-endpoints"
llm = ChatDatabricks(endpoint="databricks-meta-llama-3-1-70b-instruct")
Load documents
First we load in our documents. We will use WebBaseLoader to load a blog post, and split the documents into smaller sub-documents.
from langchain_community.document_loaders import WebBaseLoader
from langchain_text_splitters import CharacterTextSplitter
text_splitter = CharacterTextSplitter.from_tiktoken_encoder(
chunk_size=1000, chunk_overlap=0
)
loader = WebBaseLoader("https://lilianweng.github.io/posts/2023-06-23-agent/")
docs = loader.load()
split_docs = text_splitter.split_documents(docs)
print(f"Generated {len(split_docs)} documents.")
Created a chunk of size 1003, which is longer than the specified 1000
Generated 14 documents.
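The warning in the output above appears because the splitter cuts only at its separator: a single piece that is longer than the chunk size is kept whole rather than cut further. The following is a minimal sketch of that behavior, not the library's implementation. `count_tokens` and `split_on_separator` are hypothetical names, and whitespace-separated words stand in for tiktoken tokens.

```python
def count_tokens(text: str) -> int:
    """Hypothetical token counter: whitespace-separated words."""
    return len(text.split())


def split_on_separator(
    text: str, chunk_size: int, separator: str = "\n\n"
) -> list[str]:
    chunks: list[str] = []
    current: list[str] = []
    for piece in text.split(separator):
        if not piece.strip():
            continue  # skip empty pieces
        candidate = separator.join(current + [piece])
        if current and count_tokens(candidate) > chunk_size:
            # Adding this piece would exceed the budget: close the chunk.
            chunks.append(separator.join(current))
            current = [piece]
        else:
            current.append(piece)
    if current:
        chunks.append(separator.join(current))
    return chunks


text = "one two\n\nthree four five six seven\n\neight"
for chunk in split_on_separator(text, chunk_size=3):
    print(count_tokens(chunk), repr(chunk))
```

In this toy input the middle piece has five words and no separator inside it, so it is emitted as an oversized chunk even though the budget is three.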
Create graph
Map step
Let's first define the prompt associated with the map step, and associate it with the LLM via a chain:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
map_prompt = ChatPromptTemplate.from_messages(
[("human", "Write a concise summary of the following:\n\n{context}")]
)
map_chain = map_prompt | llm | StrOutputParser()
Reduce step
We also define a chain that takes the document mapping results and reduces them into a single output.
reduce_template = """
The following is a set of summaries:
{docs}
Take these and distill them into a final, consolidated summary
of the main themes.
"""
reduce_prompt = ChatPromptTemplate([("human", reduce_template)])
reduce_chain = reduce_prompt | llm | StrOutputParser()
Orchestration via LangGraph
Below we implement a simple application that maps the summarization step on a list of documents, then reduces them using the above prompts.
Map-reduce flows are particularly useful when texts are long compared to the context window of an LLM. For long texts, we need a mechanism that ensures that the context to be summarized in the reduce step does not exceed the model's context window size. Here we implement a recursive "collapsing" of the summaries: the inputs are partitioned based on a token limit, and a summary is generated for each partition. This step is repeated until the total length of the summaries is within a desired limit, allowing for the summarization of arbitrary-length text.
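The collapsing loop described above can be sketched in plain Python, independent of any library. All names here (`count_tokens`, `partition`, `fake_reduce`, `collapse_until_fits`, `max_rounds`) are hypothetical. Words stand in for tokens, and `fake_reduce` stands in for an LLM call by keeping the first half of the words in a group.

```python
def count_tokens(text: str) -> int:
    """Hypothetical token counter: whitespace-separated words."""
    return len(text.split())


def partition(summaries: list[str], token_max: int) -> list[list[str]]:
    """Group consecutive summaries so each group fits token_max if possible."""
    groups: list[list[str]] = []
    current: list[str] = []
    current_len = 0
    for summary in summaries:
        n = count_tokens(summary)
        if current and current_len + n > token_max:
            groups.append(current)
            current, current_len = [], 0
        current.append(summary)
        current_len += n
    if current:
        groups.append(current)
    return groups


def fake_reduce(group: list[str]) -> str:
    """Stand-in for an LLM summary: keep the first half of the words."""
    words = " ".join(group).split()
    return " ".join(words[: max(1, len(words) // 2)])


def collapse_until_fits(
    summaries: list[str], token_max: int, max_rounds: int = 10
) -> list[str]:
    rounds = 0
    # Repeat partition-and-summarize until the total length fits the limit.
    while sum(count_tokens(s) for s in summaries) > token_max:
        if rounds == max_rounds:
            raise RuntimeError("summaries did not shrink below token_max")
        summaries = [fake_reduce(g) for g in partition(summaries, token_max)]
        rounds += 1
    return summaries


summaries = [
    "alpha beta gamma delta",
    "epsilon zeta eta theta",
    "iota kappa lambda mu",
]
print(collapse_until_fits(summaries, token_max=6))
```

The `max_rounds` guard is there because nothing guarantees that a real model's summaries get shorter on every pass.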
We will need to install langgraph:
pip install -qU langgraph
import operator
from typing import Annotated, List, Literal, TypedDict

from langchain.chains.combine_documents.reduce import (
    acollapse_docs,
    split_list_of_docs,
)
from langchain_core.documents import Document
from langgraph.constants import Send
from langgraph.graph import END, START, StateGraph

token_max = 1000


def length_function(documents: List[Document]) -> int:
    """Get number of tokens for input contents."""
    return sum(llm.get_num_tokens(doc.page_content) for doc in documents)


# This will be the overall state of the main graph.
# It will contain the input document contents, corresponding
# summaries, and a final summary.
class OverallState(TypedDict):
    contents: List[str]
    # Notice here we use operator.add.
    # This is because we want to combine all the summaries we generate
    # from individual nodes back into one list - this is essentially
    # the "reduce" part
    summaries: Annotated[list, operator.add]
    collapsed_summaries: List[Document]
    final_summary: str


# This will be the state of the node that we will "map" all
# documents to in order to generate summaries
class SummaryState(TypedDict):
    content: str


# Here we generate a summary, given a document
async def generate_summary(state: SummaryState):
    response = await map_chain.ainvoke(state["content"])
    return {"summaries": [response]}


# Here we define the logic to map out over the documents
# We will use this as an edge in the graph
def map_summaries(state: OverallState):
    # We will return a list of `Send` objects
    # Each `Send` object consists of the name of a node in the graph
    # as well as the state to send to that node
    return [
        Send("generate_summary", {"content": content}) for content in state["contents"]
    ]


# Wrap each summary in a Document so it can be measured and collapsed
def collect_summaries(state: OverallState):
    return {
        "collapsed_summaries": [Document(summary) for summary in state["summaries"]]
    }


# Add node to collapse summaries
async def collapse_summaries(state: OverallState):
    doc_lists = split_list_of_docs(
        state["collapsed_summaries"], length_function, token_max
    )
    results = []
    for doc_list in doc_lists:
        results.append(await acollapse_docs(doc_list, reduce_chain.ainvoke))
    return {"collapsed_summaries": results}


# This represents a conditional edge in the graph that determines
# if we should collapse the summaries or not
def should_collapse(
    state: OverallState,
) -> Literal["collapse_summaries", "generate_final_summary"]:
    num_tokens = length_function(state["collapsed_summaries"])
    if num_tokens > token_max:
        return "collapse_summaries"
    else:
        return "generate_final_summary"


# Here we will generate the final summary
async def generate_final_summary(state: OverallState):
    response = await reduce_chain.ainvoke(state["collapsed_summaries"])
    return {"final_summary": response}


# Construct the graph
# Nodes:
graph = StateGraph(OverallState)
graph.add_node("generate_summary", generate_summary)
graph.add_node("collect_summaries", collect_summaries)
graph.add_node("collapse_summaries", collapse_summaries)
graph.add_node("generate_final_summary", generate_final_summary)

# Edges:
graph.add_conditional_edges(START, map_summaries, ["generate_summary"])
graph.add_edge("generate_summary", "collect_summaries")
graph.add_conditional_edges("collect_summaries", should_collapse)
graph.add_conditional_edges("collapse_summaries", should_collapse)
graph.add_edge("generate_final_summary", END)

app = graph.compile()
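To make the data flow of the map step concrete, here is a plain-Python sketch (no LangGraph involved) of what the `Send` fan-out and the `operator.add` annotation on `summaries` achieve together: each content item goes to its own task, each task returns a one-element list, and those lists are concatenated into one. `fake_generate_summary` and `run_map_step` are hypothetical names, and the sketch makes no claim about how LangGraph schedules or merges updates internally.

```python
import asyncio
import operator
from functools import reduce


async def fake_generate_summary(state: dict) -> dict:
    """Hypothetical stand-in for the generate_summary node."""
    await asyncio.sleep(0)  # yield control, as a real LLM call would
    return {"summaries": [state["content"].upper()]}


async def run_map_step(contents: list[str]) -> dict:
    # Fan out: one task per content item, like one Send per document.
    updates = await asyncio.gather(
        *(fake_generate_summary({"content": c}) for c in contents)
    )
    # Merge: fold each partial update into one list with operator.add.
    summaries = reduce(operator.add, (u["summaries"] for u in updates), [])
    return {"summaries": summaries}


print(asyncio.run(run_map_step(["a", "b", "c"])))
```

Returning a one-element list from each task is what lets list concatenation act as the merge: without a combining function, each update would overwrite the previous value of the key.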
LangGraph allows the graph structure to be plotted to help visualize its function:
from IPython.display import Image
Image(app.get_graph().draw_mermaid_png())
Invoke graph
When running the application, we can stream the graph to observe its sequence of steps. Below, we will simply print out the name of the step.
Note that because we have a loop in the graph, it can be helpful to specify a recursion_limit on its execution. This will raise a specific error when the specified limit is exceeded.
async for step in app.astream(
    {"contents": [doc.page_content for doc in split_docs]},
    {"recursion_limit": 10},
):
    print(list(step.keys()))
['generate_summary']
['generate_summary']
['generate_summary']
['generate_summary']
['generate_summary']
['generate_summary']
['generate_summary']
['generate_summary']
['generate_summary']
['generate_summary']
['generate_summary']
['generate_summary']
['generate_summary']
['generate_summary']
['collect_summaries']
['collapse_summaries']
['collapse_summaries']
['generate_final_summary']
print(step)
{'generate_final_summary': {'final_summary': 'The consolidated summary of the main themes from the provided documents highlights the advancements and applications of large language models (LLMs) in artificial intelligence, particularly in autonomous agents and software development. Key themes include:\n\n1. **Integration of LLMs**: LLMs play a crucial role in enabling autonomous agents to perform complex tasks through advanced reasoning and decision-making techniques, such as Chain of Thought (CoT) and Tree of Thoughts.\n\n2. **Memory Management**: The categorization of memory into sensory, short-term, and long-term types parallels machine learning concepts, with short-term memory facilitating in-context learning and long-term memory enhanced by external storage solutions.\n\n3. **Tool Use and APIs**: Autonomous agents utilize external APIs to expand their capabilities, demonstrating adaptability and improved problem-solving skills.\n\n4. **Search Algorithms**: Various approximate nearest neighbor search algorithms, including Locality-Sensitive Hashing (LSH) and FAISS, are discussed for enhancing search efficiency in high-dimensional spaces.\n\n5. **Neuro-Symbolic Architectures**: The integration of neuro-symbolic systems, such as the MRKL framework, combines expert modules with LLMs to improve problem-solving, particularly in complex tasks.\n\n6. **Challenges and Innovations**: The documents address challenges like hallucination and inefficient planning in LLMs, alongside innovative methods such as Chain of Hindsight (CoH) and Algorithm Distillation (AD) for performance enhancement.\n\n7. **Software Development Practices**: The use of LLMs in software development is explored, particularly in creating structured applications like a Super Mario game using the model-view-controller (MVC) architecture, emphasizing task management, component organization, and documentation.\n\n8. **Limitations of LLMs**: Constraints such as finite context length and challenges in long-term planning are acknowledged, along with concerns regarding the reliability of natural language as an interface.\n\nOverall, the integration of LLMs and neuro-symbolic architectures signifies a significant evolution in AI, with ongoing research focused on enhancing planning, memory management, and problem-solving capabilities across various applications.'}}
Next steps
Check out the LangGraph documentation for detail on building with LangGraph, including this guide on the details of map-reduce in LangGraph.
See the summarization how-to guides for additional summarization strategies, including those designed for larger volumes of text.
See also this tutorial for more detail on summarization.