import pyaudio
from typing import Optional
import numpy as np
import struct
import time

# --- Capture configuration -------------------------------------------------
# Substring looked for in each host API's name when choosing the audio backend.
PREFERRED_API = "WASAPI"
# Substring looked for in each device's name when choosing the microphone.
MICROPHONE_NAME = "Beyond"
# Sample format requested from PyAudio; must stay in sync with the sample
# decoding done in get_mic_data.
PREFERRED_FORMAT = pyaudio.paInt16
# Number of interleaved channels requested when opening the stream.
PREFERRED_CHANNELS = 2
# Sample rate in Hz requested when opening the stream.
PREFERRED_RATE = 48000
# Frames read per call: 1/100 of a second of audio (480 frames at 48 kHz).
PREFERRED_CHUNK = int(PREFERRED_RATE / 100) # about every 10ms

# --- Module-level state ----------------------------------------------------
# PyAudio instance used to enumerate devices and open streams; connect_mic
# replaces it when reconnecting.
mngr = pyaudio.PyAudio()
# Input stream most recently opened by connect_mic, or None when there is none.
audiostream = None

def connect_mic() -> Optional[pyaudio.Stream]:
    """Open an input stream on the configured microphone.

    Any stream previously opened by this function is shut down first and the
    module-level PyAudio instance is recreated, so the function can be called
    again to reconnect.

    The host API is the one whose name contains ``PREFERRED_API`` (host API 0
    when none matches).  The microphone is the first input-capable device on
    that host API whose name contains ``MICROPHONE_NAME``.

    Returns:
        The opened ``pyaudio.Stream``, or ``None`` when no device matched.
        The same value is stored in the module-level ``audiostream``.

    Raises:
        OSError: if PyAudio fails to open the matching device.
    """
    global audiostream, mngr

    if audiostream is not None:
        # Order matters: the stream has to be stopped and closed while its
        # PyAudio instance is still alive; only then may that instance be
        # terminated and replaced.
        try:
            audiostream.stop_stream()
            audiostream.close()
        except OSError:
            # Best effort: the stream may already have been closed elsewhere.
            pass
        audiostream = None
        mngr.terminate()
        mngr = pyaudio.PyAudio()

    # Find our preferred API; index 0 is the fallback when it is not present.
    preferred_api_index = 0
    for api_pos in range(mngr.get_host_api_count()):
        api = mngr.get_host_api_info_by_index(api_pos)
        if PREFERRED_API in api['name']:
            preferred_api_index = api['index']

    # Find our microphone.
    for dev_pos in range(mngr.get_device_count()):
        dev = mngr.get_device_info_by_index(dev_pos)
        if dev['hostApi'] != preferred_api_index:
            continue
        if MICROPHONE_NAME not in dev['name']:
            continue
        if dev['maxInputChannels'] < 1:
            # A device without input channels cannot be opened for recording.
            continue

        audiostream = mngr.open(format=PREFERRED_FORMAT,
                                channels=PREFERRED_CHANNELS,
                                rate=PREFERRED_RATE,
                                input=True,
                                input_device_index=dev['index'],
                                frames_per_buffer=PREFERRED_CHUNK)
        # Stop at the first match so that only one stream is ever opened;
        # opening every match would leak all but the last one.
        break

    return audiostream


def get_mic_data(audiostream: pyaudio.Stream) -> np.ndarray:
    """Read one chunk from the stream and split it into per-channel rows.

    Reads ``PREFERRED_CHUNK`` frames, decodes them as interleaved
    little-endian 16-bit integers and de-interleaves them.

    Args:
        audiostream: an open input stream, as returned by ``connect_mic``.

    Returns:
        A float array of shape ``(PREFERRED_CHANNELS, frames)``; row 0 is the
        first channel, row 1 the second.  When the read fails or the buffer
        cannot be decoded, an empty array of shape ``(PREFERRED_CHANNELS, 0)``
        is returned, so the result can always be stacked horizontally with
        successful chunks.
    """
    try:
        raw = audiostream.read(PREFERRED_CHUNK)  # returns bytes
        interleaved = np.frombuffer(raw, dtype='<i2')
        # One row per frame, one column per channel -> transpose so that each
        # channel becomes a row.
        return interleaved.reshape(-1, PREFERRED_CHANNELS).T.astype(float)
    except (OSError, ValueError):
        # OSError: PyAudio/PortAudio read failure (e.g. input overflow).
        # ValueError: byte count does not form whole frames.
        return np.empty((PREFERRED_CHANNELS, 0), dtype=float)

# reads all available frames in the receive buffer to clear it   
def clear_stream(audiostream: pyaudio.Stream) -> None:
    """Discard every frame currently waiting in the stream's input buffer.

    Args:
        audiostream: an open input stream.
    """
    pending_frames = audiostream.get_read_available()
    if pending_frames > 0:
        # The data is thrown away, so an input overflow that happened while
        # nobody was reading must not abort the drain.
        audiostream.read(pending_frames, exception_on_overflow=False)

# records data from microphone for time_ms milliseconds and 
# returns separate left and right channel values
def get_mic_over_time(audiostream:pyaudio.Stream, time_ms: int) -> Optional[np.ndarray]:
    """Record roughly ``time_ms`` milliseconds of audio from the stream.

    Stale buffered input is discarded first, then chunks are read until at
    least ``time_ms * PREFERRED_RATE / 1000`` frames have been collected or
    ``time_ms + 200`` milliseconds have elapsed.  Chunks that could not be
    read are skipped instead of aborting the recording.

    Args:
        audiostream: an open input stream.
        time_ms: requested recording length in milliseconds.

    Returns:
        A 2-D float array with one row per channel and one column per frame
        (it may hold slightly more frames than requested because reading is
        chunk based), or ``None`` when no chunk could be read in time.
    """
    # get up to NOW by clearing the stream
    clear_stream(audiostream)

    frames_wanted = time_ms * PREFERRED_RATE / 1000

    # Give it a little extra time for the remaining samples; an exact
    # timeout tends to under-record.
    timeout_ms = time_ms + 200
    deadline_ns = time.monotonic_ns() + timeout_ms * 1000000

    chunks = []
    frames_got = 0
    while time.monotonic_ns() < deadline_ns:
        chunk = get_mic_data(audiostream)
        if chunk.ndim != 2 or chunk.shape[1] == 0:
            # Failed read: get_mic_data hands back an empty array.
            continue
        chunks.append(chunk)
        frames_got += chunk.shape[1]
        if frames_got >= frames_wanted:
            break

    if not chunks:
        return None
    # Join once at the end rather than re-copying the recording per chunk.
    return np.hstack(chunks)

if __name__ == '__main__':
    # Interactive self-test: measure the quiet level of both channels, then
    # count loud events ("hits") on both channels until enough were seen or
    # the test time runs out.
    quiet_sample_length = 200       # ms recorded per quiet-level capture
    num_quiet_samples = 5           # number of quiet-level captures
    hit_sample_length = 200         # ms recorded per detection window
    test_length = 10                # seconds available for the clap test
    hit_threshold_multiplier = 10   # a hit must exceed quiet level by this factor
    hits_needed = 5
    frames_per_bin = PREFERRED_RATE // 1000  # 48 frames at 48 kHz

    bmic = connect_mic()
    if bmic is None:
        raise SystemExit(
            'no microphone matching "{}" was found'.format(MICROPHONE_NAME))

    try:
        print("now recording normal level. please remain quiet")
        quiet_peaks_l = []
        quiet_peaks_r = []
        # collect samples at "quiet" level and find the peak squared value
        # in each sample
        for _ in range(num_quiet_samples):
            mdat = get_mic_over_time(bmic, quiet_sample_length)
            if mdat is None or mdat.ndim != 2 or mdat.shape[1] == 0:
                continue  # nothing was captured this round
            quiet_peaks_l.append(np.max(np.power(mdat[0, :], 2)))
            quiet_peaks_r.append(np.max(np.power(mdat[1, :], 2)))
        if not quiet_peaks_l:
            raise SystemExit("no audio could be read from the microphone")
        avg_l = np.mean(quiet_peaks_l)
        avg_r = np.mean(quiet_peaks_r)

        numhits = 0

        print('Left average: {}'.format(avg_l))
        print('Right average: {}'.format(avg_r))

        print("start clapping!")

        stop_time = time.monotonic_ns() + test_length * 1e9
        while time.monotonic_ns() < stop_time:
            # grab a short segment of audio
            mdat = get_mic_over_time(bmic, hit_sample_length)
            if mdat is None or mdat.ndim != 2:
                continue
            # Only whole bins can be reshaped; drop any trailing remainder.
            usable = mdat.shape[1] - mdat.shape[1] % frames_per_bin
            if usable == 0:
                continue
            # downsample by averaging the squared samples of each bin
            lpeaks = np.mean(
                np.power(mdat[0, :usable].reshape(-1, frames_per_bin), 2), 1)
            rpeaks = np.mean(
                np.power(mdat[1, :usable].reshape(-1, frames_per_bin), 2), 1)
            # a hit needs both channels to exceed their own threshold
            if (np.max(lpeaks) > avg_l * hit_threshold_multiplier
                    and np.max(rpeaks) > avg_r * hit_threshold_multiplier):
                numhits = numhits + 1
                print("hit!")

            if numhits >= hits_needed:
                print("success!")
                break
    finally:
        # Release the device even when the test stops early or fails.
        try:
            bmic.stop_stream()
            bmic.close()
        except OSError:
            # Best effort: the stream may already be stopped or closed.
            pass
        mngr.terminate()