/*-
 * Copyright (c) 2008 Ariff Abdullah <ariff@FreeBSD.org>
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 *
 * $FreeBSD$
 */

#include <err.h>
#include <errno.h>
#include <stdio.h>
#include <sysexits.h>

#include "pcm.h"
#include "waveutil.h"

/*
 * Fallback two-argument minimum, used only if no min() is already defined.
 * The selected argument is evaluated twice, so avoid side effects.
 */
#if !defined(min)
#define min(a, b)       (((a) < (b)) ? (a) : (b))
#endif

int
wave_header_read(struct wave_info *w)
{
        uint32_t sz;
        size_t readsz;
        uint8_t buf[64];

        if (w->fp == NULL) {
                warnx("%s(): no stream assigned", __func__);
                return (EIO);
        }

        w->format = WAVE_FORMAT_UNKNOWN;
        w->bit = 0;
        w->channels = 0;
        w->rate = 0;
        w->endian = WAVE_UNKNOWN_ENDIAN;
        w->header_size = 0;
        w->data_size = 0;
        w->seekable = (ftell(w->fp) == -1) ? 0 : 1;

        for (;;) {
                readsz = fread(buf, 1, 8, w->fp);
                if (readsz != 8) {
                        warnx("%s(): io error [1]", __func__);
                        return (EIO);
                }
                sz = WAVE_READ_32(w, buf + 4);
                w->header_size += 8;
                switch (WAVE_MAGIC_READ(buf)) {
                case WAVE_MAGIC_RIFX:
                        w->endian = WAVE_BIG_ENDIAN;
                case WAVE_MAGIC_RIFF:
                        if (w->endian == WAVE_UNKNOWN_ENDIAN)
                                w->endian = WAVE_LITTLE_ENDIAN;
                        readsz = fread(buf, 1, 4, w->fp);
                        if (readsz != 4) {
                                warnx("%s(): io error [2]", __func__);
                                return (EIO);
                        }
                        if (WAVE_MAGIC_READ(buf) != WAVE_MAGIC_WAVE) {
                                warnx("%s(): not a WAVE file", __func__);
                                return (EFTYPE);
                        }
                        w->header_size += 4;
                        break;
                case WAVE_MAGIC_FMT:
                        if (!(sz == 16 || sz == 18 || sz == 40)) {
                                warnx("%s(): illegal header size=%u",
                                    __func__, sz);
                                return (EFTYPE);
                        }
                        readsz = fread(buf, 1, sz, w->fp);
                        if (readsz != (size_t)sz) {
                                warnx("%s(): io error [3]", __func__);
                                return (EIO);
                        }
                        w->format = WAVE_READ_16(w, buf);
                        switch (w->format) {
                        case WAVE_FORMAT_PCM:
                        case WAVE_FORMAT_ALAW:
                        case WAVE_FORMAT_ULAW:
                        case WAVE_FORMAT_EXT:
                                break;
                        default:
                                warnx("%s(): unsupported format 0x%04x",
                                    __func__, w->format);
                                return (EFTYPE);
                                break;
                        }
                        w->channels = WAVE_READ_16(w, buf + 2);
                        w->rate = WAVE_READ_32(w, buf + 4);
                        w->bit = WAVE_READ_16(w, buf + 14);
                        w->header_size += sz;
                        break;
                case WAVE_MAGIC_DATA:
                        w->data_size = sz;
                        return (0);
                        break;
                default:
                        if (w->header_size == 8) {
                                warnx("%s(): not a RIFF file", __func__);
                                return (EFTYPE);
                        }
                        while (sz > 0) {
                                readsz = fread(buf, 1, min(sz, sizeof(buf)),
                                    w->fp);
                                if (readsz < 1) {
                                        warnx("%s(): io error [4]", __func__);
                                        return (EIO);
                                }
                                w->header_size += readsz;
                                sz -= readsz;
                        }
                        break;
                }
        }

        /* NOTREACHED */
        return (0);
}

int
wave_header_write(struct wave_info *w)
{
        long pos;
        ssize_t total;
        int ext;
        uint8_t buf[4];

        if (w->fp == NULL) {
                warnx("%s(): no stream assigned", __func__);
                return (EIO);
        }

        if (!((w->format == WAVE_FORMAT_PCM ||
            w->format == WAVE_FORMAT_EXT ||
            w->format == WAVE_FORMAT_ULAW ||
            w->format == WAVE_FORMAT_ALAW) &&
            (w->endian == WAVE_LITTLE_ENDIAN ||
            w->endian == WAVE_BIG_ENDIAN))) {
                warnx("%s(): format=0x%08x endian=%u unknown",
                    __func__, w->format, w->endian);
                return (EFTYPE);
        }

        if (w->format == WAVE_FORMAT_EXT)
                ext = 24;
        else if (w->format == WAVE_FORMAT_ULAW ||
            w->format == WAVE_FORMAT_ALAW)
                ext = 2;
        else
                ext = 0;

        if (w->seekable != 0) {
                if ((pos = ftell(w->fp)) == -1 ||
                    fseek(w->fp, 0L, SEEK_SET) == -1)
                        return (EIO);
                pos -= WAVE_HEADER_SIZE + ext + ((ext != 0) ? 12 : 0);
                if (w->data_size != (uint32_t)pos)
                        warnx("%s(): WARNING data_size=%u != pos=%u",
                            __func__, w->data_size, (uint32_t)pos);
        }

        if (w->endian == WAVE_LITTLE_ENDIAN)
                WAVE_MAGIC_WRITE(buf, WAVE_MAGIC_RIFF);
        else
                WAVE_MAGIC_WRITE(buf, WAVE_MAGIC_RIFX);
        total = fwrite(buf, 1, 4, w->fp);
        WAVE_WRITE_32(w, buf, w->data_size + WAVE_HEADER_SIZE + ext +
            (!(w->format == WAVE_FORMAT_PCM || w->format == WAVE_FORMAT_EXT) ?
            12 : 0) - 8);
        total += fwrite(buf, 1, 4, w->fp);
        WAVE_MAGIC_WRITE(buf, WAVE_MAGIC_WAVE);
        total += fwrite(buf, 1, 4, w->fp);
        WAVE_MAGIC_WRITE(buf, WAVE_MAGIC_FMT);
        total += fwrite(buf, 1, 4, w->fp);
        WAVE_WRITE_32(w, buf, 16 + ext);
        total += fwrite(buf, 1, 4, w->fp);
        WAVE_WRITE_16(w, buf, w->format);
        total += fwrite(buf, 1, 2, w->fp);
        WAVE_WRITE_16(w, buf, w->channels);
        total += fwrite(buf, 1, 2, w->fp);
        WAVE_WRITE_32(w, buf, w->rate);
        total += fwrite(buf, 1, 4, w->fp);
        WAVE_WRITE_32(w, buf, w->rate * WAVE_BLOCK_ALIGN(w));
        total += fwrite(buf, 1, 4, w->fp);
        WAVE_WRITE_16(w, buf, WAVE_BLOCK_ALIGN(w));
        total += fwrite(buf, 1, 2, w->fp);
        WAVE_WRITE_16(w, buf, w->bit);
        total += fwrite(buf, 1, 2, w->fp);
        if (ext == 2) {
                WAVE_WRITE_16(w, buf, 0);
                total += fwrite(buf, 1, 2, w->fp);
        } else if (ext == 24) {
                WAVE_WRITE_16(w, buf, 22);
                total += fwrite(buf, 1, 2, w->fp);
                WAVE_WRITE_16(w, buf, w->bit);
                total += fwrite(buf, 1, 2, w->fp);
                WAVE_WRITE_32(w, buf, 0);
                total += fwrite(buf, 1, 4, w->fp);
                WAVE_WRITE_16(w, buf, WAVE_FORMAT_PCM);
                total += fwrite(buf, 1, 2, w->fp);
                total += fwrite("\x00\x00\x00\x00\x10\x00\x80\x00"
                    "\x00\xAA\x00\x38\x9B\x71", 1, 14, w->fp);
        }
        if (ext != 0) {
                WAVE_MAGIC_WRITE(buf, WAVE_MAGIC_FACT);
                total += fwrite(buf, 1, 4, w->fp);
                WAVE_WRITE_32(w, buf, 4);
                total += fwrite(buf, 1, 4, w->fp);
                WAVE_WRITE_32(w, buf,
                    w->data_size / WAVE_BLOCK_ALIGN(w));
                total += fwrite(buf, 1, 4, w->fp);
                ext += 12;
        }
        WAVE_MAGIC_WRITE(buf, WAVE_MAGIC_DATA);
        total += fwrite(buf, 1, 4, w->fp);
        WAVE_WRITE_32(w, buf, w->data_size);
        total += fwrite(buf, 1, 4, w->fp);

        if (total != (WAVE_HEADER_SIZE + ext)) {
                warnx("%s(): WARNING: header write size=%zu != %u",
                    __func__, total, WAVE_HEADER_SIZE + ext);
                return (EIO);
        }

        return (0);
}