I followed the instructions in the README on the meta-llama GitHub repository to download the llama2-7b-chat model using the CLI. I now have three files: params.json, consolidated.pth (I believe that's the weights file), and tokenizer.model (a MODEL file).
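For anyone trying to reproduce this, a quick way to confirm what the download folder contains is sketched below. It only uses the standard library; `describe_checkpoint_dir` is a hypothetical helper name (it is not part of Meta's repository or any library), and the folder path you pass in is whatever location the CLI downloaded to.

```python
import json
from pathlib import Path


def describe_checkpoint_dir(model_dir):
    """Hypothetical helper: summarize a downloaded Llama checkpoint folder.

    Returns the parsed params.json plus the names of any .pth weight files
    and whether tokenizer.model is present. Nothing is loaded into memory
    apart from the small JSON file.
    """
    model_dir = Path(model_dir)
    # params.json holds the model hyperparameters as plain JSON.
    params = json.loads((model_dir / "params.json").read_text())
    # Weight shards are stored as one or more .pth files.
    weight_files = sorted(p.name for p in model_dir.glob("*.pth"))
    has_tokenizer = (model_dir / "tokenizer.model").exists()
    return {
        "params": params,
        "weight_files": weight_files,
        "has_tokenizer": has_tokenizer,
    }
```

Calling `describe_checkpoint_dir(...)` on the download folder and printing `info["params"].keys()` shows which hyperparameter names the file uses. If they are names such as `dim`, `n_layers` and `n_heads` (which, as far as I know, is what Meta's native format uses), they will not line up with Hugging Face field names such as `hidden_size`.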
My goal is to integrate the model into a Streamlit application, but I can't figure out how. Can anyone help me understand how to get it running in Streamlit?
This is the first time I am integrating an LLM into Streamlit, so any relevant links or code would be appreciated.
NOTE: I have not used Hugging Face anywhere, and I don't want a Hugging Face-based solution.
import json

import streamlit as st
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig


@st.cache_resource
def load_model_and_tokenizer():
    # Step 1: Load configuration locally
    config_path = r"path\\params.json"
    with open(config_path, "r") as f:
        config_data = json.load(f)
    config = LlamaConfig(**config_data)

    # Step 2: Initialize model architecture
    model = LlamaForCausalLM(config)

    # Step 3: Load weights
    weights_path = r"path\\consolidated.pth"
    state_dict = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()

    # Step 4: Load tokenizer
    tokenizer_path = r"path\\tokenizer.model"
    tokenizer = LlamaTokenizer(tokenizer_path)

    return model, tokenizer


# Load model and tokenizer
model, tokenizer = load_model_and_tokenizer()

# Streamlit app
st.title("LLaMA Chatbot")
user_input = st.text_input("Enter your message:")

if user_input:
    with st.spinner("Generating response..."):
        inputs = tokenizer(user_input, return_tensors="pt")
        outputs = model.generate(inputs.input_ids, max_length=200)
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    st.write("### Response:")
    st.write(response)
This code was provided by ChatGPT, and it shows an error for LlamaForCausalLM. I can't quite get my head around this code. Is this even the correct approach?
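One related detail, separate from the loading error: the code above passes the raw text from `st.text_input` straight to the tokenizer, whereas the chat variants of Llama 2 are trained on prompts wrapped in `[INST]` and `<<SYS>>` markers. Below is a minimal sketch of that wrapping for a single user turn. The delimiter strings follow Meta's reference code as I understand it; `build_chat_prompt` is a hypothetical helper name, and the beginning-of-sequence token is assumed to be added by the tokenizer rather than written into the string.

```python
# Delimiters used for Llama 2 chat prompts (taken from Meta's reference code).
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"


def build_chat_prompt(user_message, system_prompt=None):
    """Hypothetical helper: wrap one user turn in Llama 2 chat markers.

    The optional system prompt is folded into the user turn between the
    <<SYS>> markers. Multi-turn history is not handled in this sketch.
    """
    content = user_message.strip()
    if system_prompt:
        content = B_SYS + system_prompt.strip() + E_SYS + content
    return f"{B_INST} {content} {E_INST}"
```

In the Streamlit app, the string returned by `build_chat_prompt(user_input)` would be what gets tokenized, instead of `user_input` itself.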