-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
150 lines (128 loc) · 5.41 KB
/
app.py
File metadata and controls
150 lines (128 loc) · 5.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import re
import time

import librosa  # librosa is often a dependency for audio processing in transformers
import numpy as np
import sounddevice as sd
import streamlit as st
from scipy.io.wavfile import write
from transformers import pipeline
# --- CONFIGURATION ---
# Hugging Face id of an Audio Spectrogram Transformer (AST) checkpoint
# fine-tuned on AudioSet, whose label set covers many everyday and
# environmental sounds. The numeric suffix is part of the published id.
MODEL_NAME = "MIT/ast-finetuned-audioset-10-10-0.4593"
# The model was trained on 16 kHz audio; recordings must use the same rate.
SAMPLE_RATE = 16000
# Length (in seconds) of each clip captured from the microphone before it is
# classified.
RECORD_DURATION = 4  # seconds

# --- SOUND CATEGORIES ---
# Keywords used to map the model's predicted label onto a broader focus
# category. Matching is case-insensitive against the label text.
# Customize these lists based on what you find focus-friendly or distracting.
FOCUS_FRIENDLY_SOUNDS = [
    "Speech", "Music", "Silence", "White noise", "Water",
    "Wind", "Bird", "Rain", "Typing",
]
DISTRACTING_SOUNDS = [
    "Siren", "Dog", "Cat", "Vehicle", "Alarm", "Door",
    "Dishes, pots, and pans", "Crying, sobbing", "Yell", "Screaming",
]
# --- LOAD THE AI MODEL ---
# Only a *successful* load is cached. The cached helper lets exceptions
# propagate; st.cache_resource stores return values only. A failed load is
# therefore retried on the next rerun, instead of a cached None disabling the
# app until the cache is cleared.
@st.cache_resource
def _load_pipeline():
    """Build the Hugging Face audio-classification pipeline for MODEL_NAME.

    Cached with st.cache_resource so the model is loaded once and reused
    across reruns. Any loading error is raised to the caller.
    """
    return pipeline("audio-classification", model=MODEL_NAME)


def load_model():
    """
    Load the audio classification model from Hugging Face.

    Shows an info message while loading, then replaces it with a success or
    error message in the same spot.

    Returns:
        The audio-classification pipeline, or None if loading failed.
    """
    status = st.empty()
    status.info("Loading AI model... This might take a moment on the first run.")
    try:
        classifier = _load_pipeline()
    except Exception as e:  # UI boundary: surface any load failure to the user.
        status.error(f"Error loading model: {e}")
        return None
    status.success("AI model loaded successfully!")
    return classifier
# --- CORE FUNCTIONS ---
def record_audio(duration, sample_rate):
    """Capture a mono clip from the default input device.

    Args:
        duration: Length of the clip in seconds.
        sample_rate: Sampling rate in Hz, passed to sounddevice.

    Returns:
        The captured samples flattened into a 1-D NumPy array.
    """
    st.info("Listening...")
    # sd.rec starts the capture (one channel); sd.wait() blocks until it ends.
    buffer = sd.rec(
        int(duration * sample_rate),
        samplerate=sample_rate,
        channels=1,
    )
    sd.wait()
    st.success("Processing sound...")
    # Collapse the recorded buffer to a flat vector for the classifier.
    mono = buffer.flatten()
    return mono
def classify_audio(classifier, audio_data):
    """Return the single best prediction for *audio_data*.

    Args:
        classifier: The audio-classification pipeline, or None if the model
            failed to load.
        audio_data: Raw samples passed unchanged to the pipeline.

    Returns:
        The first entry of the pipeline's ``top_k=1`` result (a dict with
        "label" and "score"). When no classifier is available, a placeholder
        ``{"label": "Error", "score": 0}``.
    """
    if classifier is not None:
        # top_k=1 limits the pipeline output to the highest-scoring label.
        ranked = classifier(audio_data, top_k=1)
        return ranked[0]
    return {"label": "Error", "score": 0}
def map_to_focus_category(prediction, focus_keywords=None, distracting_keywords=None):
    """
    Map a model prediction label to one of the app's focus categories.

    A keyword matches when it occurs in the label (case-insensitively) starting
    at a word boundary. This stops short keywords from firing inside unrelated
    words: "Rain" does not match "Train", and "Cat" does not match
    "Pizzicato". Multi-word labels such as "Rain on surface" or "Car alarm"
    still match.

    Args:
        prediction: Dict with at least "label" (str) and "score" keys.
        focus_keywords: Keywords treated as focus-friendly. Defaults to
            FOCUS_FRIENDLY_SOUNDS.
        distracting_keywords: Keywords treated as distracting. Defaults to
            DISTRACTING_SOUNDS.

    Returns:
        A ``(category, label, score)`` tuple. Focus-friendly keywords are
        checked first, so a label matching both lists counts as focus-friendly.
        A label matching neither is reported as neutral/unknown.
    """
    if focus_keywords is None:
        focus_keywords = FOCUS_FRIENDLY_SOUNDS
    if distracting_keywords is None:
        distracting_keywords = DISTRACTING_SOUNDS

    label = prediction['label']
    score = prediction['score']
    label_lower = label.lower()

    def _matches(keywords):
        # \b anchors each (escaped) keyword at the start of a word in the label.
        return any(
            re.search(r"\b" + re.escape(keyword.lower()), label_lower)
            for keyword in keywords
        )

    if _matches(focus_keywords):
        return "✅ Focus-Friendly", label, score
    if _matches(distracting_keywords):
        return "❌ Distracting", label, score
    # Sounds in neither list are reported as neutral rather than guessed.
    return "⚠️ Neutral / Unknown", label, score
# --- STREAMLIT USER INTERFACE ---
st.set_page_config(layout="wide")
st.title("🎙️ Real-Time Focus Environment Classifier")
st.write(
    "This app uses an AI model to listen to your environment through your microphone "
    "and determines if it's a good place to focus."
)

# Load the model; load_model() returns None if loading failed.
classifier = load_model()

# Session state persists across reruns and records whether analysis is active.
if 'is_running' not in st.session_state:
    st.session_state.is_running = False

# Start/Stop buttons side by side.
col1, col2 = st.columns(2)
with col1:
    if st.button("🚀 Start Analyzing", type="primary"):
        if classifier is not None:
            st.session_state.is_running = True
        else:
            st.error("Cannot start, AI model failed to load.")
with col2:
    # NOTE: apart from a microphone error, the loop below never clears
    # is_running itself. Stopping relies on this button's click rerunning the
    # script and reaching this line.
    if st.button("🛑 Stop Analyzing"):
        st.session_state.is_running = False
        st.info("Analysis stopped.")

# --- MAIN ANALYSIS LOOP ---
if st.session_state.is_running:
    st.markdown("---")
    st.success("Live analysis is active. I'm listening...")

    # Placeholders overwritten in place on every cycle.
    status_col, sound_col = st.columns(2)
    with status_col:
        focus_status_placeholder = st.empty()
    with sound_col:
        sound_type_placeholder = st.empty()
    # record_audio() writes its own progress messages. Rendering them inside a
    # single st.empty() slot makes each message replace the previous one,
    # instead of appending two new alerts to the page on every cycle.
    activity_placeholder = st.empty()

    while st.session_state.is_running:
        # 1. Record a clip of audio.
        try:
            with activity_placeholder:
                audio_clip = record_audio(RECORD_DURATION, SAMPLE_RATE)
        except sd.PortAudioError as err:
            # No usable input device: report it and stop instead of crashing.
            st.error(f"Could not record from the microphone: {err}")
            st.session_state.is_running = False
            break

        # 2. Classify the audio clip.
        top_prediction = classify_audio(classifier, audio_clip)

        # 3. Map the result to our focus categories.
        focus_category, detected_sound, confidence_score = map_to_focus_category(top_prediction)

        # 4. Update the UI placeholders with the new results.
        focus_status_placeholder.metric("Environment Status", focus_category)
        sound_type_placeholder.metric(
            f"Detected Sound: *{detected_sound}*",
            f"Confidence: {confidence_score:.2%}"
        )

        # Short pause before the next cycle to avoid overwhelming the UI.
        time.sleep(1)