import os
import re
from typing import Dict, List, Optional
from pydantic import BaseModel, Field
import streamlit as st
from dotenv import load_dotenv
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_groq import ChatGroq
from langgraph.graph import StateGraph, END

# Load environment variables (still useful as a fallback)
load_dotenv()

# Configure the page
st.set_page_config(page_title="AI Blog Generator", layout="wide")
| # API Key handling in sidebar | |
| with st.sidebar: | |
| st.title("Configuration") | |
| # LLM Provider Selection | |
| provider = st.radio("LLM Provider", ["OpenAI", "Groq"]) | |
| if provider == "OpenAI": | |
| openai_api_key = st.text_input("OpenAI API Key", type="password", help="Enter your OpenAI API key here") | |
| model = st.selectbox("Model", ["gpt-3.5-turbo", "gpt-4", "gpt-4o"]) | |
| if openai_api_key: | |
| os.environ["OPENAI_API_KEY"] = openai_api_key | |
| else: # Groq | |
| groq_api_key = st.text_input("Groq API Key", type="password", help="Enter your Groq API key here") | |
| model = st.selectbox("Model", ["llama-3.3-70b-versatile","gemma2-9b-it","qwen-2.5-32b","mistral-saba-24b", "deepseek-r1-distill-qwen-32b"]) | |
| if groq_api_key: | |
| os.environ["GROQ_API_KEY"] = groq_api_key | |
| st.divider() | |
| st.write("## About") | |
| st.write("This app uses LangGraph to generate structured blog posts with a multi-step workflow.") | |
| st.write("Made with ❤️ using LangGraph and Streamlit") | |
# Define the state schema
class BlogGeneratorState(BaseModel):
    topic: str = Field(default="")
    audience: str = Field(default="")
    tone: str = Field(default="")
    word_count: int = Field(default=500)
    outline: List[str] = Field(default_factory=list)
    sections: Dict[str, str] = Field(default_factory=dict)
    final_blog: str = Field(default="")
    error: Optional[str] = Field(default=None)
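# A note on the update pattern used by the nodes below: each node returns a new
# BlogGeneratorState rather than mutating its input. model_dump() (Pydantic v2)
# converts the state to a plain dict so individual fields can be overridden.
# Illustrative only:
#   updated = BlogGeneratorState(**{**state.model_dump(), "outline": ["Intro", "Tips"]})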
# Initialize the LLM based on the selected provider
def get_llm():
    # `provider` and `model` are module-level values set in the sidebar above
    if provider == "OpenAI":
        if not os.environ.get("OPENAI_API_KEY"):
            st.error("Please enter your OpenAI API key in the sidebar")
            st.stop()
        return ChatOpenAI(model=model, temperature=0.7)
    else:  # Groq
        if not os.environ.get("GROQ_API_KEY"):
            st.error("Please enter your Groq API key in the sidebar")
            st.stop()
        return ChatGroq(model=model, temperature=0.7)
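# Note: neither client is passed a key explicitly; ChatOpenAI and ChatGroq pick
# up OPENAI_API_KEY / GROQ_API_KEY from the environment variables set in the
# sidebar (or from a local .env via load_dotenv()).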
# Create prompt templates
outline_prompt = ChatPromptTemplate.from_template(
    """You are a professional blog writer. Create an outline for a blog post about {topic}.
The audience is {audience} and the tone should be {tone}.
The blog should be approximately {word_count} words.
Return ONLY the outline as a list of section headings (without numbers or bullets).
Each heading should be concise and engaging."""
)

section_prompt = ChatPromptTemplate.from_template(
    """Write content for the following section of a blog post about {topic}:
Section: {section}
The audience is {audience} and the tone should be {tone}.
Make this section approximately {section_word_count} words.
Make the content engaging, informative, and valuable to the reader.
Return ONLY the content for this section, without the heading."""
)

final_assembly_prompt = ChatPromptTemplate.from_template(
    """You have a blog post with the following sections:
{sections_content}
Format this into a complete, professional blog post in Markdown format with:
1. An engaging title at the top as an H1 heading
2. A brief introduction before the first section
3. Each section heading as an H2
4. A conclusion at the end
5. Proper spacing between sections
6. 2-3 relevant Markdown formatting elements like bold, italics, blockquotes, or bullet points where appropriate
The blog should maintain the {tone} tone and be targeted at {audience}.
Make it flow naturally between sections."""
)
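# For reference (illustrative): a template is rendered at invoke time, e.g.
#   outline_prompt.format_messages(topic="Remote work", audience="Managers",
#                                  tone="Professional", word_count=800)
# returns the filled-in chat messages that `outline_prompt | llm` would send.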
# Define the nodes for the graph
def get_outline(state: BlogGeneratorState) -> BlogGeneratorState:
    """Generate an outline for the blog post."""
    try:
        llm = get_llm()
        chain = outline_prompt | llm
        response = chain.invoke({
            "topic": state.topic,
            "audience": state.audience,
            "tone": state.tone,
            "word_count": state.word_count
        })
        # Parse the outline into a list, defensively stripping any bullet,
        # hash, or "1."-style markers the model adds despite the prompt
        output_text = response.content
        outline = [re.sub(r"^(?:[#*•-]+\s*|\d+[.)]\s*)", "", line.strip())
                   for line in output_text.split("\n") if line.strip()]
        outline = [line for line in outline if line]
        return BlogGeneratorState(**{**state.model_dump(), "outline": outline})
    except Exception as e:
        st.error(f"Outline Error: {str(e)}")
        return BlogGeneratorState(**{**state.model_dump(), "error": f"Error generating outline: {str(e)}"})
def generate_sections(state: BlogGeneratorState) -> BlogGeneratorState:
    """Generate content for each section in the outline."""
    if state.error:
        return state
    if not state.outline:
        # Guard against an empty outline, which would otherwise cause a
        # division by zero below
        return BlogGeneratorState(**{**state.model_dump(), "error": "No outline was generated"})
    sections = {}
    section_word_count = state.word_count // len(state.outline)
    try:
        llm = get_llm()
        chain = section_prompt | llm

        # Show progress in the UI
        progress_bar = st.progress(0)
        status_text = st.empty()

        for i, section in enumerate(state.outline):
            status_text.text(f"Generating section {i+1}/{len(state.outline)}: {section}")
            response = chain.invoke({
                "topic": state.topic,
                "section": section,
                "audience": state.audience,
                "tone": state.tone,
                "section_word_count": section_word_count
            })
            sections[section] = response.content
            progress_bar.progress((i + 1) / len(state.outline))

        status_text.empty()
        progress_bar.empty()
        return BlogGeneratorState(**{**state.model_dump(), "sections": sections})
    except Exception as e:
        return BlogGeneratorState(**{**state.model_dump(), "error": f"Error generating sections: {str(e)}"})
def assemble_blog(state: BlogGeneratorState) -> BlogGeneratorState:
    """Assemble the final blog post in Markdown format."""
    if state.error:
        return state
    try:
        llm = get_llm()
        chain = final_assembly_prompt | llm
        sections_content = "\n\n".join([f"Section: {heading}\nContent: {content}"
                                        for heading, content in state.sections.items()])
        response = chain.invoke({
            "sections_content": sections_content,
            "tone": state.tone,
            "audience": state.audience
        })
        final_blog = response.content
        return BlogGeneratorState(**{**state.model_dump(), "final_blog": final_blog})
    except Exception as e:
        return BlogGeneratorState(**{**state.model_dump(), "error": f"Error assembling blog: {str(e)}"})
# Define the workflow graph
def create_blog_generator_graph():
    workflow = StateGraph(BlogGeneratorState)

    # Add nodes
    workflow.add_node("get_outline", get_outline)
    workflow.add_node("generate_sections", generate_sections)
    workflow.add_node("assemble_blog", assemble_blog)

    # Add edges: outline -> sections -> assembly -> END
    workflow.add_edge("get_outline", "generate_sections")
    workflow.add_edge("generate_sections", "assemble_blog")
    workflow.add_edge("assemble_blog", END)

    # Set the entry point
    workflow.set_entry_point("get_outline")

    return workflow.compile()
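# Illustrative usage outside Streamlit (assumes an API key is already set in the
# environment; a compiled LangGraph returns the final state as a plain dict):
#   graph = create_blog_generator_graph()
#   result = graph.invoke(BlogGeneratorState(topic="Composting 101",
#                                            audience="Beginners",
#                                            tone="Conversational",
#                                            word_count=600))
#   print(result["final_blog"])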
# Create the Streamlit app's main content
st.title("AI Blog Generator")
st.write("Generate professional blog posts with a structured workflow")

with st.form("blog_generator_form"):
    topic = st.text_input("Blog Topic", placeholder="E.g., Sustainable Living in Urban Environments")

    col1, col2 = st.columns(2)
    with col1:
        audience = st.text_input("Target Audience", placeholder="E.g., Young professionals")
        tone = st.selectbox("Tone", ["Informative", "Conversational", "Professional", "Inspirational", "Technical"])
    with col2:
        word_count = st.slider("Approximate Word Count", min_value=300, max_value=2000, value=800, step=100)

    submit_button = st.form_submit_button("Generate Blog")
if submit_button:
    if (provider == "OpenAI" and not os.environ.get("OPENAI_API_KEY")) or \
       (provider == "Groq" and not os.environ.get("GROQ_API_KEY")):
        st.error(f"Please enter your {provider} API key in the sidebar before generating a blog")
    elif not topic or not audience:
        st.error("Please fill out all required fields.")
    else:
        with st.spinner(f"Initializing blog generation using {provider} {model}..."):
            try:
                # Initialize the graph
                blog_generator = create_blog_generator_graph()

                # Set the initial state
                initial_state = BlogGeneratorState(
                    topic=topic,
                    audience=audience,
                    tone=tone,
                    word_count=word_count
                )

                # Run the graph
                result = blog_generator.invoke(initial_state)

                # Check that the result is a dict with the expected keys
                if isinstance(result, dict):
                    final_blog = result.get("final_blog", "")
                    outline = result.get("outline", [])
                    error = result.get("error")

                    if error:
                        st.error(f"Error: {error}")
                    elif final_blog:
                        # Display the blog post
                        st.success("Blog post generated successfully!")
                        st.subheader("Generated Blog Post")
                        st.markdown(final_blog)

                        # Download button for the blog post
                        st.download_button(
                            label="Download Blog as Markdown",
                            data=final_blog,
                            file_name=f"{topic.replace(' ', '_').lower()}_blog.md",
                            mime="text/markdown",
                        )

                        # Show metadata about the generation
                        st.info(f"Generated using {provider} {model}")

                        # Optionally show the outline
                        with st.expander("View Blog Outline"):
                            for i, section in enumerate(outline, 1):
                                st.write(f"{i}. {section}")
                    else:
                        st.error("Blog generation completed but no content was produced")
                else:
                    st.error(f"Unexpected result type: {type(result)}")
            except Exception as e:
                st.error(f"An error occurred: {str(e)}")
                st.info("Please check your API key and try again.")